@@ -354,7 +354,29 @@ def run_blockscaled_gemm_all_reduce_python_interface(
        rtol=1e-02,
    )

-def _run_correctness_worker(world_size, rank, distributed_init_port):
+def _run_correctness_worker(
+    world_size,
+    rank,
+    distributed_init_port,
+    lm,
+    kn,
+    ab_dtype,
+    sf_dtype,
+    sf_vec_size,
+    c_dtype,
+    a_major,
+    b_major,
+    c_major,
+    fuse_alpha,
+    alpha_dtype,
+    mma_tiler_mn,
+    cluster_shape_mn,
+    sm_count,
+    tolerance,
+    iterations,
+    enable_dst_signals,
+    all_reduce,
+):
    assert rank >= 0
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
@@ -371,24 +393,24 @@ def _run_correctness_worker(world_size, rank, distributed_init_port):

    try:
        run_blockscaled_gemm_all_reduce_python_interface(
-            lm=(2, 512),  # (1, 1024), (2, 512), (4, 256)
-            kn=(7168, 4096),
-            ab_dtype="float8_e5m2",
-            sf_dtype="float8_e8m0fnu",
-            sf_vec_size=32,
-            c_dtype="bfloat16",
-            a_major="k",
-            b_major="k",
-            c_major="n",
-            fuse_alpha=False,
-            alpha_dtype="float32",
-            mma_tiler_mn=(128, 128),
-            cluster_shape_mn=(1, 1),
-            tolerance=1e-01,
-            iterations=1,
-            sm_count=148,
-            enable_dst_signals=True,
-            all_reduce="two_shot",
+            lm=lm,
+            kn=kn,
+            ab_dtype=ab_dtype,
+            sf_dtype=sf_dtype,
+            sf_vec_size=sf_vec_size,
+            c_dtype=c_dtype,
+            a_major=a_major,
+            b_major=b_major,
+            c_major=c_major,
+            fuse_alpha=fuse_alpha,
+            alpha_dtype=alpha_dtype,
+            mma_tiler_mn=mma_tiler_mn,
+            cluster_shape_mn=cluster_shape_mn,
+            tolerance=tolerance,
+            iterations=iterations,
+            sm_count=sm_count,
+            enable_dst_signals=enable_dst_signals,
+            all_reduce=all_reduce,
            rank=rank,
        )
    except Exception as e:
@@ -433,7 +455,48 @@ def multi_process_parallel(
    not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
)
@pytest.mark.parametrize("world_size", [8])
-def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(world_size):
+@pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
+@pytest.mark.parametrize("kn", [(7168, 4096)])
+@pytest.mark.parametrize(
+    "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
+    [
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        # Add more combinations as needed
+    ],
+)
+@pytest.mark.parametrize("a_major", ["k"])
+@pytest.mark.parametrize("b_major", ["k"])
+@pytest.mark.parametrize("c_major", ["n"])
+@pytest.mark.parametrize("fuse_alpha", [False])
+@pytest.mark.parametrize("alpha_dtype", ["float32"])
+@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
+@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
+@pytest.mark.parametrize("sm_count", [148])
+@pytest.mark.parametrize("tolerance", [1e-01])
+@pytest.mark.parametrize("iterations", [1])
+@pytest.mark.parametrize("enable_dst_signals", [True])
+@pytest.mark.parametrize("all_reduce", ["two_shot"])
+def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
+    world_size,
+    lm,
+    kn,
+    ab_dtype,
+    sf_dtype,
+    sf_vec_size,
+    c_dtype,
+    a_major,
+    b_major,
+    c_major,
+    fuse_alpha,
+    alpha_dtype,
+    mma_tiler_mn,
+    cluster_shape_mn,
+    sm_count,
+    tolerance,
+    iterations,
+    enable_dst_signals,
+    all_reduce,
+):
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(
@@ -443,6 +506,25 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(world_size):
    multi_process_parallel(
        world_size,
        _run_correctness_worker,
-        target_args=(),
+        target_args=(
+            lm,
+            kn,
+            ab_dtype,
+            sf_dtype,
+            sf_vec_size,
+            c_dtype,
+            a_major,
+            b_major,
+            c_major,
+            fuse_alpha,
+            alpha_dtype,
+            mma_tiler_mn,
+            cluster_shape_mn,
+            sm_count,
+            tolerance,
+            iterations,
+            enable_dst_signals,
+            all_reduce,
+        ),
    )
    print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")
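The diff relies on a calling convention: `multi_process_parallel` must prepend `(world_size, rank, distributed_init_port)` before unpacking `target_args`, since the test passes only the GEMM configuration while `_run_correctness_worker` expects the port as its third positional parameter. The helper's body is outside this diff, so the following is only a minimal sketch of one way such a harness could work, assuming a `torch.multiprocessing` spawn model; it is illustrative, not the repository's implementation.

```python
# Hypothetical sketch of multi_process_parallel; the real helper is not shown
# in this diff, so the body below is an assumption about its behavior.
import socket

import torch.multiprocessing as mp


def multi_process_parallel(world_size, target, target_args=()):
    # Pick a free TCP port so each test run gets its own rendezvous endpoint.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        distributed_init_port = s.getsockname()[1]

    ctx = mp.get_context("spawn")
    procs = []
    for rank in range(world_size):
        # Matches the worker signature: (world_size, rank, port, *target_args).
        p = ctx.Process(
            target=target,
            args=(world_size, rank, distributed_init_port, *target_args),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
        assert p.exitcode == 0, f"worker exited with code {p.exitcode}"
```

Under this convention, the order of the `target_args` tuple in the test must match the parameter order of `_run_correctness_worker` after `distributed_init_port`, which the diff above preserves exactly.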