
Commit aeb1815

Commit message: wip
Parent: a033ee6

File tree: 2 files changed (+60 lines, -65 lines)

flashinfer/cute_dsl/blockscaled_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -2734,7 +2734,7 @@ def can_implement(
         can_implement = True
         # Skip unsupported types
         if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
-            ab_dtype, sf_dtype, sf_vec_size, c_dtype,all_reduce
+            ab_dtype, sf_dtype, sf_vec_size, c_dtype, all_reduce
         ):
             can_implement = False
         # Skip unsupported layouts

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 59 additions & 64 deletions
@@ -158,7 +158,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     l, m = lm
     k, n = kn
 
-    print(f"device: {device}")
+    # print(f"device: {device}")
 
     if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
         get_cutlass_dtype(ab_dtype),
@@ -201,7 +201,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         init_type=cutlass_torch.TensorInitType.SCALAR,
         init_config=cutlass_torch.ScalarInitConfig(value=0.0),
     )
-    print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
+    # print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
     a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
@@ -226,9 +226,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         # (1 if c_major == "n" else 0),
         is_dynamic_layout=True,
     )
-    print(
-        f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
-    )
+    # print(
+    #     f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
+    # )
     alpha_tensor = (
         torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
     )
@@ -279,7 +279,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     torch.distributed.broadcast(masked_m_tensor, src=0)
     # to hack and test:
     # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
+    # print(f"Rank {rank}: masked_m = {masked_m_tensor}")
     for _ in range(iterations):
         dst_signals = (
             torch.zeros((l,), dtype=torch.uint32, device="cuda")
@@ -328,9 +328,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     )
     # Convert c back to f32 for comparison.
     ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
-    print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-    print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-    print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
+    # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
+    # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
+    # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
     cute.testing.convert(
         c_tensor,
         from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic(
@@ -472,70 +472,32 @@ def multi_process_parallel(
         ), f"Process {i} failed with exit code {procs[i].exitcode}"
 
 
-# @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
-# @pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
-# @pytest.mark.parametrize(
-#     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
-#     [
-#         ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-#         ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-#         ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-#         ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-#         ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
-#     ],
-# )
-# @pytest.mark.parametrize("a_major", ["k"])
-# @pytest.mark.parametrize("b_major", ["k"])
-# @pytest.mark.parametrize("c_major", ["n"])
-# @pytest.mark.parametrize("fuse_alpha", [False, True])
-# @pytest.mark.parametrize("alpha_dtype", ["float32"])
-# @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
-# @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
-# @pytest.mark.parametrize("sm_count", [132, None])
-# @pytest.mark.parametrize("tolerance", [1e-01])
-# @pytest.mark.parametrize("iterations", [3])
-# @pytest.mark.parametrize("enable_dst_signals", [False, True])
-
-# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
-
 @pytest.mark.skipif(
     not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
 )
 @pytest.mark.parametrize("world_size", [8])
 @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
-@pytest.mark.parametrize("kn", [(7168, 4096)])
+@pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
 @pytest.mark.parametrize(
     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
     [
         ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
     ],
 )
 @pytest.mark.parametrize("a_major", ["k"])
@@ -576,6 +538,39 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
         pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
+    # device = torch.device("cuda", rank)
+    major, minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
+    if not (major == 10 and minor == 0):
+        pytest.skip("Cute-dsl backend is only supported on SM100.")
+    if enable_dst_signals and (sm_count is None):
+        pytest.skip("dst_signals require sm_count")
+
+    l, m = lm
+    k, n = kn
+    if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
+        get_cutlass_dtype(ab_dtype),
+        get_cutlass_dtype(sf_dtype),
+        sf_vec_size,
+        get_cutlass_dtype(c_dtype),
+        mma_tiler_mn,
+        cluster_shape_mn,
+        m,
+        n,
+        k,
+        l,
+        a_major,
+        b_major,
+        c_major,
+    ):
+        pytest.skip(
+            f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}"
+        )
+
+    if not (a_major == "k" and b_major == "k" and c_major == "n"):
+        # not supported for now, since we align with deepgemm
+        pytest.skip(
+            f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}. Might be added later"
+        )
     print(f"Running test for world_size={world_size}")
     multi_process_parallel(
         world_size,
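
Note: the guard added to the test above can also be read as a reusable pre-flight check: verify the GPU is SM100 and ask the kernel itself, via can_implement(), whether a configuration is supported before launching anything. The sketch below is illustrative only; the import path and the helper name config_is_runnable are assumptions, not part of this commit, and the argument order simply mirrors the call made in the test.

# Minimal sketch of the pre-flight check pattern used in the test (assumed imports).
import torch
from flashinfer.cute_dsl.blockscaled_gemm import (  # assumed import path
    Sm100BlockScaledPersistentDenseGemmKernel,
    get_cutlass_dtype,
)

def config_is_runnable(
    ab_dtype, sf_dtype, sf_vec_size, c_dtype,
    mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1),
    m=1024, n=4096, k=7168, l=1,
    a_major="k", b_major="k", c_major="n",
):
    # The cute-dsl backend exercised here targets SM100 (compute capability 10.0) only.
    major, minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
    if (major, minor) != (10, 0):
        return False
    # Same can_implement() call the test now performs before spawning worker processes.
    return Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
        get_cutlass_dtype(ab_dtype),
        get_cutlass_dtype(sf_dtype),
        sf_vec_size,
        get_cutlass_dtype(c_dtype),
        mma_tiler_mn,
        cluster_shape_mn,
        m, n, k, l,
        a_major, b_major, c_major,
    )

# Example: check the ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32) config early.
# config_is_runnable("float8_e5m2", "float8_e8m0fnu", 32, "bfloat16")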
