
Commit 415da39

clean up
1 parent 66ddb23 commit 415da39

File tree

1 file changed: +16 -54 lines changed

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 16 additions & 54 deletions
@@ -86,21 +86,6 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
     barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
         m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
     )
-    #print("LOOK HERE", (barrier_size,))
-    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
-
-    # use_2cta_instrs = mma_tiler_mn[0] == 256
-    # cta_tile_shape_mn = (
-    #     mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
-    #     mma_tiler_mn[1],
-    # )
-    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
-    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
-    # num_tiles = num_tiles_per_batch * l
-    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
-    # +num_sms for final barrier
-    # num_tiles + num_sms
-
     barrier_flag = symm_mem.empty((barrier_size,), device="cuda", dtype=torch.int32)

     barrier_flag.fill_(0)
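Note: the deleted comment block above sketched how the barrier-flag count could be derived by hand. For reference, a minimal standalone sketch of that arithmetic, assuming compute_barrier_flag_size follows the same tiles-plus-SMs formula (an assumption based only on the deleted comments, not on the kernel source; the helper name is illustrative):

import torch

def estimate_barrier_flag_size(m, n, l, mma_tiler_mn):
    # Hypothetical re-derivation of the deleted comment block; the kernel's
    # compute_barrier_flag_size remains the source of truth.
    use_2cta_instrs = mma_tiler_mn[0] == 256
    cta_tile_shape_mn = (
        mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
        mma_tiler_mn[1],
    )
    problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
    num_tiles = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1] * l
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    # The deleted note reserved num_sms extra flags for the final barrier.
    return num_tiles + num_sms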
@@ -158,8 +143,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     l, m = lm
     k, n = kn

-    #print(f"device: {device}")
-
     if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
         get_cutlass_dtype(ab_dtype),
         get_cutlass_dtype(sf_dtype),
@@ -201,7 +184,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         init_type=cutlass_torch.TensorInitType.SCALAR,
         init_config=cutlass_torch.ScalarInitConfig(value=0.0),
     )
-    #print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
     a_tensor, a_torch = cutlass_torch.cute_tensor_like(
         a_ref,
         get_cutlass_dtype(ab_dtype),
@@ -214,21 +196,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         is_dynamic_layout=True,
         assumed_align=16,
     )
-    # c_tensor, c_torch = cutlass_torch.cute_tensor_like(
-    #     c_ref,
-    #     get_cutlass_dtype(c_dtype),
-    #     is_dynamic_layout=True,
-    #     assumed_align=16,
-    # )
     c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor(
         c_ref,
         get_cutlass_dtype(c_dtype),
         # (1 if c_major == "n" else 0),
         is_dynamic_layout=True,
     )
-    # print(
-    #     f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
-    # )
     alpha_tensor = (
         torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
     )
@@ -271,15 +244,11 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
     if rank == 0:
         masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
     else:
         masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
     torch.distributed.broadcast(masked_m_tensor, src=0)
-    # to hack and test:
-    # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    # print(f"Rank {rank}: masked_m = {masked_m_tensor}")
     for _ in range(iterations):
         dst_signals = (
             torch.zeros((l,), dtype=torch.uint32, device="cuda")
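Note: the surviving code keeps masked_m_tensor identical across ranks by drawing it on rank 0 and broadcasting it. A minimal sketch of that pattern in isolation, assuming torch.distributed is already initialized with a CUDA-capable backend such as NCCL (the helper name is illustrative only):

import torch
import torch.distributed as dist

def make_masked_m(l: int, m: int, device: torch.device) -> torch.Tensor:
    # Rank 0 draws the per-batch row counts; other ranks only allocate a buffer.
    if dist.get_rank() == 0:
        masked_m = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
    else:
        masked_m = torch.empty((l,), dtype=torch.int32, device=device)
    # Broadcast so every rank masks the same number of rows per batch entry.
    dist.broadcast(masked_m, src=0)
    return masked_m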
@@ -328,18 +297,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         )
         # Convert c back to f32 for comparison.
         ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
-        # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-        # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-        # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
         cute.testing.convert(
             c_tensor,
             from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic(
                 leading_dim=(1 if c_major == "n" else 0)
             ),
         )
-        # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
-        # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
-        # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
         if c_dtype in ("float32", "float16", "bfloat16"):
             for i in range(l):
                 # skip testing c_ref & ref
@@ -481,23 +444,23 @@ def multi_process_parallel(
 @pytest.mark.parametrize(
     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
     [
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
-        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
-        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
         # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
-        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+        ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
     ],
 )
 @pytest.mark.parametrize("a_major", ["k"])
@@ -538,7 +501,6 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
         pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
-    #device = torch.device("cuda", rank)
     major, minor = torch.cuda.get_device_capability(torch.device("cuda:0"))
     if not (major == 10 and minor == 0):
         pytest.skip("Cute-dsl backend is only supported on SM100.")
