@@ -158,7 +158,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     l, m = lm
     k, n = kn

-    print(f"device: {device}")
+    # print(f"device: {device}")

     if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
         get_cutlass_dtype(ab_dtype),
@@ -201,7 +201,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         init_type=cutlass_torch.TensorInitType.SCALAR,
         init_config=cutlass_torch.ScalarInitConfig(value=0.0),
     )
-    print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
+    # print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
     a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
@@ -226,9 +226,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         # (1 if c_major == "n" else 0),
         is_dynamic_layout=True,
     )
-    print(
-        f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
-    )
+    # print(
+    #     f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
+    # )
     alpha_tensor = (
         torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
     )
@@ -279,7 +279,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     torch.distributed.broadcast(masked_m_tensor, src=0)
     # to hack and test:
     # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
+    # print(f"Rank {rank}: masked_m = {masked_m_tensor}")
     for _ in range(iterations):
         dst_signals = (
             torch.zeros((l,), dtype=torch.uint32, device="cuda")
@@ -328,9 +328,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         )
         # Convert c back to f32 for comparison.
         ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
-        print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-        print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-        print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
+        # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
+        # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
+        # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
         cute.testing.convert(
             c_tensor,
             from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic(
@@ -472,70 +472,32 @@ def multi_process_parallel(
     ), f"Process {i} failed with exit code {procs[i].exitcode}"


-# @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
-# @pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
-# @pytest.mark.parametrize(
-#     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
-#     [
-#         ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-#         ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-#         ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
-#     ],
-# )
-# @pytest.mark.parametrize("a_major", ["k"])
-# @pytest.mark.parametrize("b_major", ["k"])
-# @pytest.mark.parametrize("c_major", ["n"])
-# @pytest.mark.parametrize("fuse_alpha", [False, True])
-# @pytest.mark.parametrize("alpha_dtype", ["float32"])
-# @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
-# @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
-# @pytest.mark.parametrize("sm_count", [132, None])
-# @pytest.mark.parametrize("tolerance", [1e-01])
-# @pytest.mark.parametrize("iterations", [3])
-# @pytest.mark.parametrize("enable_dst_signals", [False, True])
-
-# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
-
-
 @pytest.mark.skipif(
     not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
 )
 @pytest.mark.parametrize("world_size", [8])
 @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
-@pytest.mark.parametrize("kn", [(7168, 4096)])
+@pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
 @pytest.mark.parametrize(
     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
     [
-        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
     ],
 )
 @pytest.mark.parametrize("a_major", ["k"])
@@ -576,6 +538,39 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
         pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
+    # device = torch.device("cuda", rank)
+    major, minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
+    if not (major == 10 and minor == 0):
+        pytest.skip("Cute-dsl backend is only supported on SM100.")
+    if enable_dst_signals and (sm_count is None):
+        pytest.skip("dst_signals requires sm_count")
+
+    l, m = lm
+    k, n = kn
+    if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
+        get_cutlass_dtype(ab_dtype),
+        get_cutlass_dtype(sf_dtype),
+        sf_vec_size,
+        get_cutlass_dtype(c_dtype),
+        mma_tiler_mn,
+        cluster_shape_mn,
+        m,
+        n,
+        k,
+        l,
+        a_major,
+        b_major,
+        c_major,
+    ):
+        pytest.skip(
+            f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
+        )
+
+    if not (a_major == "k" and b_major == "k" and c_major == "n"):
+        # Not supported yet: we align with the deepgemm layout for now.
+        pytest.skip(
+            f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}. Might be added later"
+        )
     print(f"Running test for world_size={world_size}")
     multi_process_parallel(
         world_size,