
Commit 5d99ce4

extend the MX cast benchmark to include casting to mxfp4 (#2693)
Update [ghstack-poisoned]
1 parent 418593c commit 5d99ce4
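
In this change, the benchmark's to_mx_dim0_reference and to_mx_dim1_reference helpers gain a target_dtype argument (defaulting to torch.float8_e4m3fn), a new dim0_mxfp4_floor mode casts to torch.float4_e2m1fn_x2, and the existing mx modes are renamed from *_mx_* to *_mxfp8_*. As a rough illustration of the cast the new mode measures, here is a minimal sketch outside the benchmark harness; the to_mx import path and the block size of 32 are assumptions not shown in this diff, and per the TODO in the new branch the fp4 data currently comes back packed in uint8 rather than as a float4 tensor.

    # Minimal sketch of an mxfp4 dim0 cast, mirroring the updated
    # to_mx_dim0_reference helper in this diff.
    # Assumptions (not confirmed by the diff): the to_mx import path below,
    # BLOCK_SIZE = 32, and the example tensor shape.
    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed path

    BLOCK_SIZE = 32  # assumed MX block size
    x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

    # to_mx returns (scale, data); with target dtype torch.float4_e2m1fn_x2
    # the data is currently uint8-packed fp4 and the scale is float8_e8m0fnu,
    # matching the asserts in the new dim0_mxfp4_floor branch.
    scale_d0, data_d0 = to_mx(x, torch.float4_e2m1fn_x2, BLOCK_SIZE)
    assert data_d0.dtype == torch.uint8
    assert scale_d0.dtype == torch.float8_e8m0fnu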

1 file changed: +54 −22 lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 54 additions & 22 deletions
@@ -54,18 +54,24 @@ def scale_dim0_dim1_reference(
     return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1
 
 
-def to_mx_dim0_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
-    scale_d0, data_d0 = to_mx(
-        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
-    )
+def to_mx_dim0_reference(
+    x_hp,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
+    target_dtype=torch.float8_e4m3fn,
+):
+    scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
     return data_d0, scale_d0
 
 
-def to_mx_dim1_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
+def to_mx_dim1_reference(
+    x_hp,
+    block_size,
+    scaling_mode=ScaleCalculationMode.FLOOR,
+    target_dtype=torch.float8_e4m3fn,
+):
     x_hp = x_hp.t().contiguous()
-    scale_d1, data_d1 = to_mx(
-        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
-    )
+    scale_d1, data_d1 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode)
     return data_d1.t(), scale_d1
 
 
@@ -88,13 +94,14 @@ def run(
         "dim0",
         "dim1",
         "dim0_dim1",
-        "dim0_mx_floor",
-        "dim0_mx_rceil",
-        "dim1_mx_floor",
-        "dim1_mx_rceil",
-        "dim1_mx_triton_floor",
-        "dim1_mx_cuda_floor",
-        "dim1_mx_cuda_rceil",
+        "dim0_mxfp8_floor",
+        "dim0_mxfp4_floor",
+        "dim0_mxfp8_rceil",
+        "dim1_mxfp8_floor",
+        "dim1_mxfp8_rceil",
+        "dim1_mxfp8_triton_floor",
+        "dim1_mxfp8_cuda_floor",
+        "dim1_mxfp8_cuda_rceil",
     )
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
@@ -154,7 +161,7 @@ def run(
         )
         bps = bytes_rw / (time_us / 1e6)
 
-    elif mode == "dim0_mx_floor":
+    elif mode == "dim0_mxfp8_floor":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)
 
@@ -172,7 +179,32 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim0_mx_rceil":
+    elif mode == "dim0_mxfp4_floor":
+        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
+        y_d0, s_d0 = to_mx_dim0_reference_c(
+            x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+        )
+
+        for _ in range(2):
+            __ = to_mx_dim0_reference_c(
+                x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim0_reference_c(
+                x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+
+        # TODO(future PR): make to_mx return float4 directly
+        assert y_d0.dtype == torch.uint8
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim0_mxfp8_rceil":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
 
@@ -190,7 +222,7 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx_floor":
+    elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
 
@@ -208,7 +240,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx_rceil":
+    elif mode == "dim1_mxfp8_rceil":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
 
@@ -226,7 +258,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx_triton_floor":
+    elif mode == "dim1_mxfp8_triton_floor":
         y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
 
         for _ in range(2):
@@ -243,7 +275,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx_cuda_floor":
+    elif mode == "dim1_mxfp8_cuda_floor":
        from torchao.prototype import mxfp8_cuda
 
        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
@@ -269,7 +301,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx_cuda_rceil":
+    elif mode == "dim1_mxfp8_cuda_rceil":
        from torchao.prototype import mxfp8_cuda
 
        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
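
For each of the mx modes, the reported figure is achieved memory bandwidth: bytes read (the bf16 input) plus bytes written (the quantized data and the e8m0 scales), divided by the measured time. The sketch below mirrors that bytes_r / bytes_w / bps arithmetic for the new dim0_mxfp4_floor branch; the M, K, BLOCK_SIZE, and time_us values are example assumptions, and the two-fp4-values-per-byte packing is inferred from the uint8 assert and the TODO rather than stated in the diff.

    # Rough traffic model mirroring the bandwidth accounting in the new
    # dim0_mxfp4_floor branch. All concrete numbers here are example
    # assumptions, not values taken from the benchmark.
    M, K = 16384, 16384
    BLOCK_SIZE = 32
    bytes_per_el_bf16 = 2
    bytes_per_el_fp8 = 1

    bytes_r = M * K * bytes_per_el_bf16            # bf16 input read
    packed_fp4_elems = M * K // 2                  # two fp4 values per uint8 (inferred)
    scale_elems = M * K // BLOCK_SIZE              # one e8m0 scale per block
    bytes_w = (packed_fp4_elems + scale_elems) * bytes_per_el_fp8

    time_us = 1000.0                               # placeholder measured kernel time
    bps = (bytes_r + bytes_w) / (time_us / 1e6)    # bytes per second
    print(f"achieved bandwidth: {bps / 1e12:.2f} TB/s")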
