Skip to content

Commit 22746b6

Browse files
committed
wip
1 parent 2e96016 commit 22746b6

File tree

1 file changed

+103
-21
lines changed

1 file changed

+103
-21
lines changed

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 103 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,29 @@ def run_blockscaled_gemm_all_reduce_python_interface(
354354
rtol=1e-02,
355355
)
356356

357-
def _run_correctness_worker(world_size, rank, distributed_init_port):
357+
def _run_correctness_worker(
358+
world_size,
359+
rank,
360+
distributed_init_port,
361+
lm,
362+
kn,
363+
ab_dtype,
364+
sf_dtype,
365+
sf_vec_size,
366+
c_dtype,
367+
a_major,
368+
b_major,
369+
c_major,
370+
fuse_alpha,
371+
alpha_dtype,
372+
mma_tiler_mn,
373+
cluster_shape_mn,
374+
sm_count,
375+
tolerance,
376+
iterations,
377+
enable_dst_signals,
378+
all_reduce,
379+
):
358380
assert rank >= 0
359381
torch.cuda.set_device(rank)
360382
device = torch.device("cuda", rank)
@@ -371,24 +393,24 @@ def _run_correctness_worker(world_size, rank, distributed_init_port):
371393

372394
try:
373395
run_blockscaled_gemm_all_reduce_python_interface(
374-
lm=(2, 512), # (1, 1024), (2, 512), (4, 256)
375-
kn=(7168, 4096),
376-
ab_dtype="float8_e5m2",
377-
sf_dtype="float8_e8m0fnu",
378-
sf_vec_size=32,
379-
c_dtype="bfloat16",
380-
a_major="k",
381-
b_major="k",
382-
c_major="n",
383-
fuse_alpha=False,
384-
alpha_dtype="float32",
385-
mma_tiler_mn=(128, 128),
386-
cluster_shape_mn=(1, 1),
387-
tolerance=1e-01,
388-
iterations=1,
389-
sm_count=148,
390-
enable_dst_signals=True,
391-
all_reduce="two_shot",
396+
lm=lm,
397+
kn=kn,
398+
ab_dtype=ab_dtype,
399+
sf_dtype=sf_dtype,
400+
sf_vec_size=sf_vec_size,
401+
c_dtype=c_dtype,
402+
a_major=a_major,
403+
b_major=b_major,
404+
c_major=c_major,
405+
fuse_alpha=fuse_alpha,
406+
alpha_dtype=alpha_dtype,
407+
mma_tiler_mn=mma_tiler_mn,
408+
cluster_shape_mn=cluster_shape_mn,
409+
tolerance=tolerance,
410+
iterations=iterations,
411+
sm_count=sm_count,
412+
enable_dst_signals=enable_dst_signals,
413+
all_reduce=all_reduce,
392414
rank=rank,
393415
)
394416
except Exception as e:
@@ -433,7 +455,48 @@ def multi_process_parallel(
433455
not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
434456
)
435457
@pytest.mark.parametrize("world_size", [8])
436-
def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(world_size):
458+
@pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
459+
@pytest.mark.parametrize("kn", [(7168, 4096)])
460+
@pytest.mark.parametrize(
461+
"ab_dtype,sf_dtype,c_dtype,sf_vec_size",
462+
[
463+
("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
464+
# Add more combinations as needed
465+
],
466+
)
467+
@pytest.mark.parametrize("a_major", ["k"])
468+
@pytest.mark.parametrize("b_major", ["k"])
469+
@pytest.mark.parametrize("c_major", ["n"])
470+
@pytest.mark.parametrize("fuse_alpha", [False])
471+
@pytest.mark.parametrize("alpha_dtype", ["float32"])
472+
@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
473+
@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
474+
@pytest.mark.parametrize("sm_count", [148])
475+
@pytest.mark.parametrize("tolerance", [1e-01])
476+
@pytest.mark.parametrize("iterations", [1])
477+
@pytest.mark.parametrize("enable_dst_signals", [True])
478+
@pytest.mark.parametrize("all_reduce", ["two_shot"])
479+
def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
480+
world_size,
481+
lm,
482+
kn,
483+
ab_dtype,
484+
sf_dtype,
485+
sf_vec_size,
486+
c_dtype,
487+
a_major,
488+
b_major,
489+
c_major,
490+
fuse_alpha,
491+
alpha_dtype,
492+
mma_tiler_mn,
493+
cluster_shape_mn,
494+
sm_count,
495+
tolerance,
496+
iterations,
497+
enable_dst_signals,
498+
all_reduce,
499+
):
437500
available_gpus = torch.cuda.device_count()
438501
if world_size > available_gpus:
439502
pytest.skip(
@@ -443,6 +506,25 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(world_size):
443506
multi_process_parallel(
444507
world_size,
445508
_run_correctness_worker,
446-
target_args=(),
509+
target_args=(
510+
lm,
511+
kn,
512+
ab_dtype,
513+
sf_dtype,
514+
sf_vec_size,
515+
c_dtype,
516+
a_major,
517+
b_major,
518+
c_major,
519+
fuse_alpha,
520+
alpha_dtype,
521+
mma_tiler_mn,
522+
cluster_shape_mn,
523+
sm_count,
524+
tolerance,
525+
iterations,
526+
enable_dst_signals,
527+
all_reduce,
528+
),
447529
)
448530
print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")

0 commit comments

Comments
 (0)