Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions benchmarks/float8/float8_inference_roofline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@


@torch.no_grad()
def get_gpu_kernel_time(m, x):
def get_gpu_kernel_time(m, x, trace_filename=None):
# warm up
for _ in range(2):
__ = m(x)
Expand All @@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x):
for _ in range(n_iter):
__ = m(x)
torch.cuda.synchronize()

# save a trace, if requested
if trace_filename is not None:
print(f"exporting trace to {trace_filename}")
prof.export_chrome_trace(trace_filename)

# get the gpu kernel time and aggregate it
num_leaf_tensors = 1 + len(list(m.parameters()))
ref_times = profiler_output_to_filtered_time_by_kernel_name(
Expand Down Expand Up @@ -161,13 +167,15 @@ def run(
do_benchmarks: bool = True,
shape_gen_name: str = "pow2",
n_limit: Optional[int] = None,
save_profile_traces: bool = False,
):
"""
Args:
* `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
# `save_profile_traces (optional)`: if True, saves profiling traces
"""
config_table = [
["GPU", torch.cuda.get_device_name(0)],
Expand Down Expand Up @@ -289,7 +297,11 @@ def run(
# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)

bf16_trace_filename = None
if save_profile_traces:
bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json"
b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename)

# get the float8 dynamic scaling gpu kernel time
torch._dynamo.reset()
Expand Down Expand Up @@ -325,7 +337,11 @@ def run(
quantize_(m_fp8_dyn, config)

m_fp8_dyn = torch.compile(m_fp8_dyn)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)

fp8_trace_filename = None
if save_profile_traces:
fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json"
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename)

r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)

Expand Down
125 changes: 123 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import pytest
import torch
from torch._inductor.utils import run_and_get_code
Expand All @@ -22,6 +24,7 @@
ScaleCalculationMode,
to_dtype,
)
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
from torchao.quantization.utils import compute_error
from torchao.utils import (
is_sm_at_least_89,
Expand Down Expand Up @@ -388,6 +391,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
MXGemmKernelChoice.EMULATED,
pack_fp6,
None,
False,
)
tensor_hp = tensor_mx.dequantize(torch.float)
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
Expand Down Expand Up @@ -645,8 +649,6 @@ def to_f8(x):
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

rows, cols = shape
device = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -662,3 +664,122 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
rtol=0.0,
msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize(
"shape",
(
(128, 64),
(1, 128, 64),
),
)
def test_scale_shape_matches_qdata(transpose, shape):
if len(shape) == 3 and transpose:
pytest.skip("transpose not yet implemented for 3D MXTensor")

block_size = 32

x_hp = torch.randn(*shape, device="cuda")
x = MXTensor.to_mx(
x_hp,
torch.float8_e4m3fn,
block_size,
ScaleCalculationMode.FLOOR,
)

if len(shape) == 2:
m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0
else:
assert len(shape) == 3, "unsupported"
m_dim, k_dim = 1, 2
if transpose:
x_hp = x_hp.transpose(-2, -1)
x = x.transpose(-2, -1)
m_dim, k_dim = 2, 1

orig_m = x_hp.shape[m_dim]
expected_padded_m = orig_m
actual_padded_m = x.scale.shape[m_dim]
assert expected_padded_m == actual_padded_m, (
f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
)

orig_k = x_hp.shape[k_dim]
expected_padded_k = orig_k // block_size
actual_padded_k = x.scale.shape[k_dim]

assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize(
"shape",
(
(128, 64),
(1, 128, 64),
),
)
def test_swizzle(elem_dtype, transpose, shape):
if len(shape) == 3 and transpose:
pytest.skip("transpose not yet implemented for 3D MXTensor")

block_size = 32

x_hp = torch.randn(*shape, device="cuda")
x = MXTensor.to_mx(
x_hp,
elem_dtype,
block_size,
ScaleCalculationMode.FLOOR,
)

xs = MXTensor.to_mx(
x_hp,
elem_dtype,
block_size,
ScaleCalculationMode.FLOOR,
is_swizzled_scales=True,
)

if transpose:
x = x.t()
xs = xs.t()

torch.testing.assert_close(x.qdata, xs.qdata, atol=0, rtol=0)

if transpose:
leading_dims, M, K = x.shape[:-2], x.shape[-1], x.shape[-2]
xs_scale_unblocked = from_blocked(
xs.scale.t(), math.prod(leading_dims) * M, K // block_size
)
xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
xs_scale_unblocked = xs_scale_unblocked.t()
else:
leading_dims, M, K = x.shape[:-2], x.shape[-2], x.shape[-1]
xs_scale_unblocked = from_blocked(
xs.scale, math.prod(leading_dims) * M, K // block_size
)
xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)

torch.testing.assert_close(
x.scale,
xs_scale_unblocked,
atol=0,
rtol=0,
)

x_dq = x.dequantize(x.dtype)
xs_dq = xs.dequantize(xs.dtype)
torch.testing.assert_close(x_dq, xs_dq, atol=0, rtol=0)
2 changes: 2 additions & 0 deletions torchao/prototype/mx_formats/inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _mx_inference_linear_transform(
block_size=config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
pack_fp6=False,
is_swizzled_scales=True,
)

# Convert weight to MX Tensor
Expand All @@ -121,6 +122,7 @@ def _mx_inference_linear_transform(
gemm_kernel_choice=config.gemm_kernel_choice,
pack_fp6=False, # TODO
act_quant_kwargs=act_quant_kwargs,
is_swizzled_scales=True,
)

module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
Expand Down
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def triton_to_mxfp8_dim1(

return (
output_col_major.t(),
col_scale.view(torch.float8_e8m0fnu),
col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
)

@register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
Expand Down Expand Up @@ -1274,7 +1274,7 @@ def triton_to_mxfp8_dim1_reference(
scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
return (
x_hp_d1_normalized.t(),
scale_e8m0_dim1.unsqueeze(-1),
scale_e8m0_dim1,
)

@triton.jit
Expand Down
Loading
Loading