Skip to content

Commit b2ea221

Browse files
authored
Add gpu_name as a parameter in roofline estimate utils
Differential Revision: D79415350 Pull Request resolved: #2657
1 parent 2f79364 commit b2ea221

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

torchao/testing/training/roofline_utils.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,9 @@
6565
}
6666

6767

68-
def get_specs():
69-
gpu_name = torch.cuda.get_device_name(0)
68+
def get_specs(gpu_name: Optional[str] = None):
69+
if gpu_name is None:
70+
gpu_name = torch.cuda.get_device_name(0)
7071
return gpu_name_to_specs[gpu_name]
7172

7273

@@ -214,10 +215,15 @@ def get_tensor_memory_traffic_ovhd_s(
214215

215216

216217
def get_individual_gemm_time_sympy(
217-
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
218+
M: sympy.Symbol,
219+
K: sympy.Symbol,
220+
N: sympy.Symbol,
221+
dtype,
222+
mx_recipe_name,
223+
gpu_name: Optional[str] = None,
218224
) -> sympy.Symbol:
219225
# compute bound
220-
specs = get_specs()
226+
specs = get_specs(gpu_name)
221227
gemm_ops = 2 * M * K * N
222228
if dtype is torch.bfloat16:
223229
peak_tops = specs["bf16_peak_tops"]
@@ -265,6 +271,7 @@ def get_gemm_time_sympy(
265271
dtype,
266272
float8_recipe_name: Optional[str],
267273
mx_recipe_name: Optional[str],
274+
gpu_name: Optional[str],
268275
):
269276
# next: add rowwise_with_gw_hp here
270277
# note: this function is currently not super accurate for small shapes:
@@ -279,13 +286,13 @@ def get_gemm_time_sympy(
279286
gemm_dtype_grad_weight = torch.bfloat16
280287

281288
gemm_output_time_s = get_individual_gemm_time_sympy(
282-
M, K, N, gemm_dtype_input, mx_recipe_name
289+
M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
283290
)
284291
gemm_grad_input_time_s = get_individual_gemm_time_sympy(
285-
M, N, K, gemm_dtype_grad_input, mx_recipe_name
292+
M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
286293
)
287294
gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
288-
K, M, N, gemm_dtype_grad_weight, mx_recipe_name
295+
K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
289296
)
290297
total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
291298
return total
@@ -298,8 +305,9 @@ def get_float8_mem_sympy(
298305
float8_recipe_name: Optional[str],
299306
mx_recipe_name: Optional[str],
300307
enable_fusion_modeling: bool,
308+
gpu_name: Optional[str] = None,
301309
):
302-
specs = get_specs()
310+
specs = get_specs(gpu_name)
303311

304312
# there are three gemms in the fwd/bwd of a linear:
305313
#

0 commit comments

Comments (0)