
Commit c88ebe8

fix float8 training benchmarks on AMD (#2737)
Summary: Small fixes to make the float8 training rowwise benchmarks work properly on AMD GPUs, ensuring the right float8 flavor (`torch.float8_e4m3fnuz` instead of `torch.float8_e4m3fn`) is used on MI300-class hardware.

Test Plan:

```bash
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250811_amd_mi300x_rowwise_with_gw_hp.csv --float8_recipe_name rowwise_with_gw_hp --shape_gen_name pow2_extended
```

MI300x results: https://gist.github.com/vkuzo/586af24b4c9a90f107590ba5e96dd7eb

H100 results: https://gist.github.com/vkuzo/586af24b4c9a90f107590ba5e96dd7eb
1 parent d7f7bf2 commit c88ebe8
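
The essence of the fix is selecting the float8 "flavor" the hardware actually supports: NVIDIA GPUs use `torch.float8_e4m3fn`, while AMD MI300-series GPUs use the `fnuz` variant. The sketch below pulls the check added in both benchmark files into a standalone helper; the helper name `pick_e4m3_dtype` is hypothetical, but `torch.version.hip`, `torch.cuda.is_available()`, `is_MI300` from `torchao.utils`, and the two dtypes are exactly what the patch uses.

```python
import torch

from torchao.utils import is_MI300


def pick_e4m3_dtype() -> torch.dtype:
    # Hypothetical helper illustrating the check added in both benchmark files:
    # default to the standard e4m3 format, and switch to the fnuz variant when
    # running under ROCm on an MI300-class GPU.
    e4m3_dtype = torch.float8_e4m3fn
    if torch.version.hip and torch.cuda.is_available() and is_MI300():
        e4m3_dtype = torch.float8_e4m3fnuz
    return e4m3_dtype
```

In the commit itself this check is inlined in each file rather than shared as a helper; the sketch just makes the branch easier to read.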

File tree

2 files changed: +16 -5 lines

benchmarks/float8/bench_matmul.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -18,6 +18,7 @@
 from torchao.ops import mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.training.roofline_utils import get_specs
+from torchao.utils import is_MI300


 @torch.inference_mode()
@@ -46,6 +47,7 @@ def run(
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
     fp4_peak_tops = specs.get("fp4_peak_tops", 0.0)  # only on sm120
+    print(f"recipe: {recipe}")
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
     print(
         f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
@@ -56,8 +58,8 @@ def run(
         "M",
         "K",
         "N",
+        "ref_time_s",
         "time_s",
-        "speedup",
         "fp8_speedup",
     )
     results = []
@@ -106,7 +108,10 @@ def run(
         else:
             # raw float8 matmul (upper bound for what we can achive in eager mode)
             # TODO(future): add e5m2
-            d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+            e4m3_dtype = torch.float8_e4m3fn
+            if torch.version.hip and torch.cuda.is_available() and is_MI300():
+                e4m3_dtype = torch.float8_e4m3fnuz
+            d1, d2, d3 = e4m3_dtype, e4m3_dtype, dtype
             A = A_hp.to(d1)
             B = B_hp_t.to(d2).contiguous().T
             peak_tops = fp8_peak_tops
```
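
For context, the `else:` branch above times a raw float8 matmul as an eager-mode upper bound. A rough sketch of such a measurement on a recent PyTorch build is shown below, using the platform-appropriate e4m3 dtype; the shapes, unit scales, and the direct `torch._scaled_mm` call are illustrative assumptions, not necessarily the benchmark's exact code path.

```python
import torch

M, K, N = 4096, 4096, 4096
device = "cuda"

# high-precision operands, as in the benchmark
A_hp = torch.randn(M, K, device=device, dtype=torch.bfloat16)
B_hp_t = torch.randn(N, K, device=device, dtype=torch.bfloat16)

# pick the e4m3 flavor for this GPU (see the earlier sketch); hardcoded here
d1 = d2 = torch.float8_e4m3fn  # torch.float8_e4m3fnuz on MI300

A = A_hp.to(d1)                   # row-major M x K float8 operand
B = B_hp_t.to(d2).contiguous().T  # column-major K x N, as _scaled_mm expects

# unit per-tensor scales keep the sketch simple; a real run would derive the
# scales from the tensors' absolute maxima
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(A, B, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([4096, 4096])
```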

benchmarks/float8/float8_roofline.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -67,6 +67,7 @@
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
+from torchao.utils import is_MI300


 class LNLinearSigmoid(torch.nn.Module):
@@ -161,7 +162,10 @@ def get_gemm_times(
     if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
         f8_time_s = bf16_time_s
     else:
-        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
         if float8_recipe_name == "tensorwise":
@@ -236,9 +240,11 @@ def run(
         mx_recipe_name,
         enable_fusion_modeling,
     )
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
+    bf16_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.bfloat16, None, None, None
+    )
     fp8_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None
     )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
```
