@@ -65,8 +65,9 @@
 }
 
 
-def get_specs():
-    gpu_name = torch.cuda.get_device_name(0)
+def get_specs(gpu_name: Optional[str] = None):
+    if gpu_name is None:
+        gpu_name = torch.cuda.get_device_name(0)
     return gpu_name_to_specs[gpu_name]
 
 
@@ -214,10 +215,15 @@ def get_tensor_memory_traffic_ovhd_s(
 
 
 def get_individual_gemm_time_sympy(
-    M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
+    M: sympy.Symbol,
+    K: sympy.Symbol,
+    N: sympy.Symbol,
+    dtype,
+    mx_recipe_name,
+    gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
-    specs = get_specs()
+    specs = get_specs(gpu_name)
     gemm_ops = 2 * M * K * N
     if dtype is torch.bfloat16:
         peak_tops = specs["bf16_peak_tops"]
@@ -265,6 +271,7 @@ def get_gemm_time_sympy(
     dtype,
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
+    gpu_name: Optional[str],
 ):
     # next: add rowwise_with_gw_hp here
     # note: this function is currently not super accurate for small shapes:
@@ -279,13 +286,13 @@ def get_gemm_time_sympy(
     gemm_dtype_grad_weight = torch.bfloat16
 
     gemm_output_time_s = get_individual_gemm_time_sympy(
-        M, K, N, gemm_dtype_input, mx_recipe_name
+        M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
     )
     gemm_grad_input_time_s = get_individual_gemm_time_sympy(
-        M, N, K, gemm_dtype_grad_input, mx_recipe_name
+        M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
     )
     gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
-        K, M, N, gemm_dtype_grad_weight, mx_recipe_name
+        K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
     )
     total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
     return total
@@ -298,8 +305,9 @@ def get_float8_mem_sympy(
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
     enable_fusion_modeling: bool,
+    gpu_name: Optional[str] = None,
 ):
-    specs = get_specs()
+    specs = get_specs(gpu_name)
 
     # there are three gemms in the fwd/bwd of a linear:
     #
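Taken together, the diff threads an optional gpu_name through the roofline helpers so that specs can be looked up for a named GPU instead of only the locally attached device. A minimal usage sketch (not part of this diff): it assumes get_gemm_time_sympy's leading parameters are the M, K, N sympy symbols seen in get_individual_gemm_time_sympy, that the helpers are importable from the module this diff modifies (path shown is illustrative), and that the gpu_name string matches a key of gpu_name_to_specs, i.e. what torch.cuda.get_device_name(0) would return on that device.

    import sympy
    import torch

    # illustrative import path; use the module this diff modifies
    from roofline_utils import get_gemm_time_sympy

    M, K, N = sympy.symbols("M K N")

    # Estimate bf16 gemm time for a named GPU; no local device of that
    # type is required. "NVIDIA H100" is a placeholder key; pass whatever
    # key exists in gpu_name_to_specs.
    gemm_time_s = get_gemm_time_sympy(
        M,
        K,
        N,
        torch.bfloat16,
        float8_recipe_name=None,
        mx_recipe_name=None,
        gpu_name="NVIDIA H100",
    )

    # Substitute concrete shapes to get estimated seconds for the three
    # gemms (output, grad_input, grad_weight) of one linear layer.
    print(gemm_time_s.subs({M: 4096, K: 4096, N: 4096}))

Because gpu_name defaults to None in get_specs, get_individual_gemm_time_sympy, and get_float8_mem_sympy, existing callers keep the old behavior (querying the local device); only get_gemm_time_sympy makes the new argument required.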