67 | 67 |     get_float8_mem_sympy,
68 | 68 |     get_gemm_time_sympy,
69 | 69 | )
   | 70 | +from torchao.utils import is_MI300
70 | 71 |
71 | 72 |
72 | 73 | class LNLinearSigmoid(torch.nn.Module):
@@ -161,7 +162,10 @@ def get_gemm_times(
161 | 162 |     if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
162 | 163 |         f8_time_s = bf16_time_s
163 | 164 |     else:
164 |     | -        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
    | 165 | +        e4m3_dtype = torch.float8_e4m3fn
    | 166 | +        if torch.version.hip and torch.cuda.is_available() and is_MI300():
    | 167 | +            e4m3_dtype = torch.float8_e4m3fnuz
    | 168 | +        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
165 | 169 |         A = torch.zeros(M, K, device=device, dtype=d1)
166 | 170 |         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
167 | 171 |         if float8_recipe_name == "tensorwise":
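Context for the dtype switch above: ROCm builds of PyTorch on MI300-class GPUs use the fnuz variant of the e4m3 float8 format (no negative zero, a different NaN encoding, exponent bias 8) rather than the torch.float8_e4m3fn used on NVIDIA hardware, so the benchmark now picks the dtype at runtime. A minimal sketch of the same selection logic as a standalone helper; the function name pick_e4m3_dtype is illustrative and not part of the patch:

```python
import torch

from torchao.utils import is_MI300  # same helper the patch imports


def pick_e4m3_dtype() -> torch.dtype:
    # Illustrative helper (not in the patch): ROCm builds expose a version
    # string in torch.version.hip, and MI300-class GPUs implement the fnuz
    # float8 formats, so prefer float8_e4m3fnuz there and float8_e4m3fn
    # everywhere else.
    if torch.version.hip and torch.cuda.is_available() and is_MI300():
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn
```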
@@ -236,9 +240,11 @@ def run(
236 | 240 |         mx_recipe_name,
237 | 241 |         enable_fusion_modeling,
238 | 242 |     )
239 |     | -    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
    | 243 | +    bf16_gemm_time_sympy = get_gemm_time_sympy(
    | 244 | +        M, K, N, torch.bfloat16, None, None, None
    | 245 | +    )
240 | 246 |     fp8_gemm_time_sympy = get_gemm_time_sympy(
241 |     | -        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
    | 247 | +        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None
242 | 248 |     )
243 | 249 |     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
244 | 250 |     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
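The calls above also pick up a new trailing parameter on get_gemm_time_sympy; its name is not visible in this diff, and both call sites simply pass None for it. As the names and print statements suggest, the function returns a symbolic time estimate in the sympy symbols M, K, N, which can later be evaluated for concrete shapes. A toy, self-contained sketch of that symbolic-roofline pattern (the formula and peak-throughput constant below are illustrative placeholders, not torchao's actual model):

```python
import sympy

# Toy stand-in for the symbolic roofline used above: build a GEMM time
# estimate as an expression in M, K, N, then substitute concrete shapes.
M, K, N = sympy.symbols("M K N", positive=True)

PEAK_BF16_FLOPS = 989e12  # assumed peak throughput in FLOP/s (placeholder)

# A GEMM of (M, K) @ (K, N) performs 2 * M * K * N floating-point operations,
# so a compute-bound lower bound on its runtime is FLOPs / peak throughput.
gemm_time_s = 2 * M * K * N / PEAK_BF16_FLOPS

print("symbolic estimate:", gemm_time_s)
print("M = K = N = 4096:", gemm_time_s.subs({M: 4096, K: 4096, N: 4096}))
```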
|