test/prototype/mx_formats/test_kernels.py: 14 changes (8 additions, 6 deletions)

@@ -492,8 +492,9 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dim0_randn(M, K, orig_dtype):
+    x = torch.randn(M, K, dtype=orig_dtype, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
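
For reference, a minimal standalone sketch of the newly parametrized float32 path outside of pytest. It assumes the same imports the test module uses (triton_to_mxfp8_dim0 and triton_to_mxfp8_dim0_reference from torchao/prototype/mx_formats/kernels.py) and an available CUDA device; it is illustrative rather than authoritative.

import torch

from torchao.prototype.mx_formats.kernels import (  # assumed import path
    triton_to_mxfp8_dim0,
    triton_to_mxfp8_dim0_reference,
)

# float32 input, newly covered by the orig_dtype parametrization above
x = torch.randn(256, 256, dtype=torch.float32, device="cuda")

# Triton kernel under test: returns e4m3 data plus one e8m0 scale per block
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)

# eager reference used for the bitwise comparison
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)

torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
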
@@ -521,18 +522,19 @@ def test_triton_mxfp8_dim0_zeros():
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
         x_data,
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
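
As a companion to this test, here is a hedged pure-PyTorch sketch of the block dequantization being checked. It assumes e8m0_scales arrives as raw biased-exponent bytes (uint8) with one scale per contiguous 32-element block along the last dimension, and K divisible by the block size; the helper name mxfp8_dequant_sketch is hypothetical.

import torch

def mxfp8_dequant_sketch(
    e4m3_data: torch.Tensor,   # (M, K), torch.float8_e4m3fn
    e8m0_scales: torch.Tensor, # (M, K // block_size), uint8 biased exponents (assumed layout)
    out_dtype: torch.dtype,
    block_size: int = 32,
) -> torch.Tensor:
    m, k = e4m3_data.shape
    exp = e8m0_scales.to(torch.int16)
    # e8m0 decode: 2 ** (byte - 127); the all-ones byte 255 is reserved for NaN
    scales = torch.exp2((exp - 127).to(torch.float32))
    scales = torch.where(exp == 255, torch.full_like(scales, float("nan")), scales)
    # broadcast one scale over each block of `block_size` values in a row
    data = e4m3_data.to(torch.float32).reshape(m, k // block_size, block_size)
    out = data * scales.reshape(m, k // block_size, 1)
    return out.reshape(m, k).to(out_dtype)
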


torchao/prototype/mx_formats/kernels.py: 7 changes (4 additions, 3 deletions)

@@ -1279,12 +1279,13 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1.unsqueeze(-1),
     )

+@triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
 def triton_mxfp8_dequant_dim0(
     e4m3_data: torch.Tensor,
     e8m0_scales: torch.Tensor,
     out_dtype: torch.dtype,
     scale_block_size: int = 32,
-) -> None:
+) -> torch.Tensor:
     assert scale_block_size == 32, "scale_block_size must be 32 for now"
     assert out_dtype in (torch.bfloat16, torch.float32), (
         "out_dtype must be bf16 or fp32"
@@ -1300,7 +1301,7 @@ def triton_mxfp8_dequant_dim0(
         triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
         triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
     )
-    _dequant_mxfp8_kernel[grid](
+    wrap_triton(_dequant_mxfp8_kernel)[grid](
         e4m3_data,
         e8m0_scales.to(torch.uint8),
         out_buffer,
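
For readers unfamiliar with the decorator and wrapper introduced in this hunk, below is a minimal, self-contained sketch of the torch.library.triton_op / wrap_triton pattern using a toy copy kernel. The op name mylib::copy and the kernel are placeholders, not the torchao code, and the APIs assume a recent PyTorch build where they are available.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)


@triton_op("mylib::copy", mutates_args={})
def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK_SIZE"]),)
    # wrap_triton lets torch.compile trace the launch instead of treating it
    # as an opaque call
    wrap_triton(_copy_kernel)[grid](x, out, n, BLOCK_SIZE=1024)
    return out

That is the same motivation as the change above from _dequant_mxfp8_kernel[grid](...) to wrap_triton(_dequant_mxfp8_kernel)[grid](...), with the launcher registered as the custom op torchao::triton_mxfp8_dequant_dim0.
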
@@ -1371,8 +1372,8 @@ def _dequant_mxfp8_kernel(

 @triton.jit
 def _e8m0_to_fp32(scale_e8m0):
-    e8m0_exponent_bias = 127
     e8m0_nan_val = 255
+    e8m0_exponent_bias = 127
     s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
     s_fp = tl.exp2(s_offset.to(tl.float32))
     s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
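
The reordered constants above do not change the decode rule, since both are assigned before use. As a quick sanity sketch, the same rule can be mirrored in plain PyTorch; the function name e8m0_to_fp32_ref is hypothetical.

import torch

def e8m0_to_fp32_ref(scale_e8m0: torch.Tensor) -> torch.Tensor:
    e8m0_nan_val = 255
    e8m0_exponent_bias = 127
    # e8m0 byte b decodes to 2 ** (b - 127); b == 255 encodes NaN
    s_offset = scale_e8m0.to(torch.int16) - e8m0_exponent_bias
    s_fp = torch.exp2(s_offset.to(torch.float32))
    return torch.where(
        scale_e8m0 != e8m0_nan_val, s_fp, torch.full_like(s_fp, float("nan"))
    )

b = torch.tensor([0, 126, 127, 128, 254, 255], dtype=torch.uint8)
print(e8m0_to_fp32_ref(b))  # 2**-127, 0.5, 1.0, 2.0, 2**127, nan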