@@ -86,21 +86,6 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
    barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
    )
-    #print("LOOK HERE", (barrier_size,))
-    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
-
-    # use_2cta_instrs = mma_tiler_mn[0] == 256
-    # cta_tile_shape_mn = (
-    # mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
-    # mma_tiler_mn[1],
-    # )
-    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
-    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
-    # num_tiles = num_tiles_per_batch * l
-    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
-    # +num_sms for final barrier
-    # num_tiles + num_sms
-
    barrier_flag = symm_mem.empty((barrier_size,), device="cuda", dtype=torch.int32)

    barrier_flag.fill_(0)
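Note: the comment block deleted in this hunk spelled out the tile/SM counting behind the barrier size. As a reference only, here is a minimal sketch of that arithmetic under the assumptions stated in the deleted comment; `compute_barrier_flag_size` remains the source of truth, and the helper name `estimate_barrier_flag_size` below is hypothetical, not part of the kernel API.

```python
# Sketch only: reconstructs the counting from the deleted comment, not the
# kernel's actual compute_barrier_flag_size implementation.
import torch


def estimate_barrier_flag_size(m, n, l, mma_tiler_mn):
    # A 256-wide MMA M tile implies 2-CTA instructions, halving the per-CTA M tile.
    use_2cta_instrs = mma_tiler_mn[0] == 256
    cta_tile_m = mma_tiler_mn[0] // (2 if use_2cta_instrs else 1)
    cta_tile_n = mma_tiler_mn[1]
    # Output tiles per batch, then across all l batches.
    num_tiles = (m // cta_tile_m) * (n // cta_tile_n) * l
    # Plus one flag per SM for the final barrier, per the deleted note.
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    return num_tiles + num_sms
```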
@@ -158,8 +143,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    l, m = lm
    k, n = kn

-    #print(f"device: {device}")
-
    if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
        get_cutlass_dtype(ab_dtype),
        get_cutlass_dtype(sf_dtype),
@@ -201,7 +184,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        init_type=cutlass_torch.TensorInitType.SCALAR,
        init_config=cutlass_torch.ScalarInitConfig(value=0.0),
    )
-    #print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
        a_ref,
        get_cutlass_dtype(ab_dtype),
@@ -214,21 +196,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        is_dynamic_layout=True,
        assumed_align=16,
    )
-    # c_tensor, c_torch = cutlass_torch.cute_tensor_like(
-    # c_ref,
-    # get_cutlass_dtype(c_dtype),
-    # is_dynamic_layout=True,
-    # assumed_align=16,
-    # )
    c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor(
        c_ref,
        get_cutlass_dtype(c_dtype),
        # (1 if c_major == "n" else 0),
        is_dynamic_layout=True,
    )
-    # print(
-    # f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
-    # )
    alpha_tensor = (
        torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
    )
@@ -271,15 +244,11 @@ def run_blockscaled_gemm_all_reduce_python_interface(
    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
        l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )
-    # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
    if rank == 0:
        masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
    else:
        masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
    torch.distributed.broadcast(masked_m_tensor, src=0)
-    # to hack and test:
-    # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    # print(f"Rank {rank}: masked_m = {masked_m_tensor}")
    for _ in range(iterations):
        dst_signals = (
            torch.zeros((l,), dtype=torch.uint32, device="cuda")
@@ -328,18 +297,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        )
        # Convert c back to f32 for comparison.
        ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
-        # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-        # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-        # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
        cute.testing.convert(
            c_tensor,
            from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic(
                leading_dim=(1 if c_major == "n" else 0)
            ),
        )
-        # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-        # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-        # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
        if c_dtype in ("float32", "float16", "bfloat16"):
            for i in range(l):
                # skip testing c_ref & ref
@@ -481,23 +444,23 @@ def multi_process_parallel(
@pytest.mark.parametrize(
    "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
    [
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
        # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
    ],
)
@pytest.mark.parametrize("a_major", ["k"])
@@ -538,7 +501,6 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
        pytest.skip(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
-    #device = torch.device("cuda", rank)
    major, minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
    if not (major == 10 and minor == 0):
        pytest.skip("Cute-dsl backend is only supported on SM100.")