
Commit a033ee6: wip
Parent: 22746b6

2 files changed (+152, -77 lines)


flashinfer/cute_dsl/blockscaled_gemm.py (3 additions, 2 deletions)
@@ -2755,8 +2755,9 @@ def can_implement(

         # check for all reduce constraints
         # TODO(asamani): expand the logic for mnnvl support
-        if torch.distributed.get_world_size() not in [2, 4, 8] and all_reduce != "none":
-            can_implement = False
+        if all_reduce != "none":
+            if torch.distributed.get_world_size() not in [2, 4, 8]:
+                can_implement = False
         return can_implement

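Note: the refactor above is behavior-preserving; nesting the world-size check under the all_reduce branch just leaves one place to add further constraints (such as the mnnvl support mentioned in the TODO). A minimal sketch of the resulting gating, using a hypothetical standalone helper rather than the kernel's actual method:

def _all_reduce_supported(all_reduce: str, world_size: int) -> bool:
    # Hypothetical helper for illustration; mirrors the nested check in can_implement.
    if all_reduce != "none":
        # An all-reduce epilogue is only reported as implementable for 2, 4, or 8 ranks.
        if world_size not in [2, 4, 8]:
            return False
    return True


assert _all_reduce_supported("none", 3)          # no all-reduce requested: always fine
assert not _all_reduce_supported("two_shot", 3)  # unsupported world size
assert _all_reduce_supported("two_shot", 8)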
tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py (149 additions, 75 deletions)
@@ -29,50 +29,50 @@

 def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
     m, n, l = torch_tensor_cpu.shape
-
+
     # Create flat symm_mem buffer
     total_elements = m * n * l
     torch_symm_flat = symm_mem.empty(
         (total_elements,), device="cuda", dtype=torch_tensor_cpu.dtype
     )
-
+
     # Reshape to match input's stride pattern using as_strided
     torch_symm_tensor = torch_symm_flat.as_strided(
-        size=torch_tensor_cpu.shape,
-        stride=torch_tensor_cpu.stride()
+        size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride()
     )
     torch_symm_tensor.copy_(torch_tensor_cpu)
-
+
     symm = symm_mem.rendezvous(torch_symm_flat, group=dist.group.WORLD.group_name)
     mc_ptr = symm.multicast_ptr
-
+
     # Create MC tensor with same stride
-    torch_tensor_mc_flat = cutlass_torch.as_tensor(mc_ptr, (total_elements,), torch_tensor_cpu.dtype)
+    torch_tensor_mc_flat = cutlass_torch.as_tensor(
+        mc_ptr, (total_elements,), torch_tensor_cpu.dtype
+    )
     torch_tensor_mc = torch_tensor_mc_flat.as_strided(
-        size=torch_tensor_cpu.shape,
-        stride=torch_tensor_cpu.stride()
+        size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride()
     )
-
+
     cute_tensor_mc = from_dlpack(torch_tensor_mc, assumed_align=16)
-
+
     if is_dynamic_layout:
         for i, stride in enumerate(torch_tensor_mc.stride()):
             if stride == 1:
                 leading_dim = i
                 break
         cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
-
+
     torch_tensor_gpu = torch_symm_tensor
     cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16)
     cute_tensor.element_type = dtype
-
+
     if is_dynamic_layout:
         for i, stride in enumerate(torch_tensor_gpu.stride()):
             if stride == 1:
                 leading_dim = i
                 break
         cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
-
+
     cute_tensor = cutlass_torch.convert_cute_tensor(
         torch_tensor_gpu,
         cute_tensor,
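Note: create_mc_tensor pairs a regular (unicast) view of a symmetric-memory buffer with a multicast view of the same storage. A minimal sketch of that allocation pattern, assuming symm_mem refers to torch.distributed._symmetric_memory (the imports are outside this diff) and that the process group is already initialized:

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def alloc_symmetric_multicast(numel: int, dtype: torch.dtype):
    # Every rank allocates an identically shaped flat symmetric buffer.
    buf = symm_mem.empty((numel,), device="cuda", dtype=dtype)
    # Rendezvous exchanges handles across the WORLD process group.
    handle = symm_mem.rendezvous(buf, group=dist.group.WORLD.group_name)
    # handle.multicast_ptr addresses the buffer on all ranks at once; the test
    # wraps it via cutlass_torch.as_tensor and from_dlpack to build CuTe tensors.
    return buf, handle.multicast_ptr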
@@ -81,44 +81,49 @@ def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
     )
     return cute_tensor, cute_tensor_mc, torch_tensor_gpu, torch_tensor_mc

+
 def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
-    barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
-        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
-    )
-    print("LOOK HERE",(barrier_size,))
-    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
-
-    # use_2cta_instrs = mma_tiler_mn[0] == 256
-    # cta_tile_shape_mn = (
-    #     mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
-    #     mma_tiler_mn[1],
-    # )
-    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
-    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
-    # num_tiles = num_tiles_per_batch * l
-    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
-    # +num_sms for final barrier
-    #num_tiles + num_sms
-
-    barrier_flag = symm_mem.empty(
-        (barrier_size,), device="cuda", dtype=torch.int32
-    )
-
-    barrier_flag.fill_(0)
-    symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name)
-    barrier_flag_mc_ptr = symm.multicast_ptr
-
-    barrier_flag_memref = from_dlpack(barrier_flag)
-    barrier_flag_memref = barrier_flag_memref.mark_layout_dynamic()
-    barrier_flag_mc_torch = cutlass_torch.as_tensor(
-        barrier_flag_mc_ptr, barrier_flag.shape, barrier_flag.dtype
-    )
-    barrier_flag_mc_memref = from_dlpack(
-        barrier_flag_mc_torch,
-    )
-    barrier_flag_mc_memref = barrier_flag_mc_memref.mark_layout_dynamic()
-    barrier_flag_torch = barrier_flag
-    return barrier_flag_memref, barrier_flag_mc_memref, barrier_flag_torch, barrier_flag_mc_torch
+    barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
+        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
+    )
+    print("LOOK HERE", (barrier_size,))
+    # NOTE: use_2cta_instrs from blockedscaled_gemm logic
+
+    # use_2cta_instrs = mma_tiler_mn[0] == 256
+    # cta_tile_shape_mn = (
+    #     mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
+    #     mma_tiler_mn[1],
+    # )
+    # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
+    # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
+    # num_tiles = num_tiles_per_batch * l
+    # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
+    # +num_sms for final barrier
+    # num_tiles + num_sms
+
+    barrier_flag = symm_mem.empty((barrier_size,), device="cuda", dtype=torch.int32)
+
+    barrier_flag.fill_(0)
+    symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name)
+    barrier_flag_mc_ptr = symm.multicast_ptr
+
+    barrier_flag_memref = from_dlpack(barrier_flag)
+    barrier_flag_memref = barrier_flag_memref.mark_layout_dynamic()
+    barrier_flag_mc_torch = cutlass_torch.as_tensor(
+        barrier_flag_mc_ptr, barrier_flag.shape, barrier_flag.dtype
+    )
+    barrier_flag_mc_memref = from_dlpack(
+        barrier_flag_mc_torch,
+    )
+    barrier_flag_mc_memref = barrier_flag_mc_memref.mark_layout_dynamic()
+    barrier_flag_torch = barrier_flag
+    return (
+        barrier_flag_memref,
+        barrier_flag_mc_memref,
+        barrier_flag_torch,
+        barrier_flag_mc_torch,
+    )
+

 def run_blockscaled_gemm_all_reduce_python_interface(
     lm: Tuple[int, int],
@@ -139,7 +144,7 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     iterations: int,
     enable_dst_signals: int,
     all_reduce: str,
-    rank:int,
+    rank: int,
 ):
     torch.manual_seed(42)
     device = torch.device("cuda", rank)
@@ -187,7 +192,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         l, n, k, b_major == "n", cutlass.Float32, device=device
     )
     c_ref = cutlass_torch.matrix(
-        l, m, n, c_major == "m", cutlass.Float32, device=device,
+        l,
+        m,
+        n,
+        c_major == "m",
+        cutlass.Float32,
+        device=device,
         init_type=cutlass_torch.TensorInitType.SCALAR,
         init_config=cutlass_torch.ScalarInitConfig(value=0.0),
     )
@@ -213,14 +223,21 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor(
         c_ref,
         get_cutlass_dtype(c_dtype),
-        #(1 if c_major == "n" else 0),
+        # (1 if c_major == "n" else 0),
         is_dynamic_layout=True,
     )
-    print(f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}")
+    print(
+        f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
+    )
     alpha_tensor = (
         torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
     )
-    barrier_flag_memref, barrier_flag_mc_memref, barrier_flag_torch, barrier_flag_mc_torch = create_barrier_flags(
+    (
+        barrier_flag_memref,
+        barrier_flag_mc_memref,
+        barrier_flag_torch,
+        barrier_flag_mc_torch,
+    ) = create_barrier_flags(
         m,
         n,
         l,
@@ -254,15 +271,15 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
         l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
     )
-    # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
-    # if rank == 0:
-    #     masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
-    # else:
-    #     masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
-    # torch.distributed.broadcast(masked_m_tensor, src=0)
+    # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
+    if rank == 0:
+        masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
+    else:
+        masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device)
+    torch.distributed.broadcast(masked_m_tensor, src=0)
     # to hack and test:
-    masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
-    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
+    # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
+    print(f"Rank {rank}: masked_m = {masked_m_tensor}")
     for _ in range(iterations):
         dst_signals = (
             torch.zeros((l,), dtype=torch.uint32, device="cuda")
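Note: the change above replaces the fixed masked_m hack (torch.full) with per-batch random masks; rank 0 draws the values and broadcasts them so every rank launches the masked GEMM with identical shapes. A self-contained sketch of that pattern, assuming the default process group is already initialized:

import torch
import torch.distributed as dist


def sample_masked_m(l: int, m: int, rank: int, device: torch.device) -> torch.Tensor:
    if rank == 0:
        # Only rank 0 samples, giving a single source of truth for the mask.
        masked_m = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
    else:
        masked_m = torch.empty((l,), dtype=torch.int32, device=device)
    # Broadcast rank 0's values so all ranks agree before the GEMM launch.
    dist.broadcast(masked_m, src=0)
    return masked_m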
@@ -306,7 +323,9 @@ def run_blockscaled_gemm_all_reduce_python_interface(
     ref = torch.einsum("mkl,nkl->mnl", res_a, res_b)
     ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor)
     ref = ref.contiguous()
-    torch.distributed.all_reduce(ref, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD)
+    torch.distributed.all_reduce(
+        ref, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD
+    )
     # Convert c back to f32 for comparison.
     ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
     print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
@@ -354,9 +373,10 @@ def run_blockscaled_gemm_all_reduce_python_interface(
         rtol=1e-02,
     )

+
 def _run_correctness_worker(
-    world_size,
-    rank,
+    world_size,
+    rank,
     distributed_init_port,
     lm,
     kn,
@@ -447,9 +467,48 @@ def multi_process_parallel(

     for i in range(world_size):
         procs[i].join()
-        assert procs[i].exitcode == 0, (
-            f"Process {i} failed with exit code {procs[i].exitcode}"
-        )
+        assert (
+            procs[i].exitcode == 0
+        ), f"Process {i} failed with exit code {procs[i].exitcode}"
+
+
+# @pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
+# @pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
+# @pytest.mark.parametrize(
+#     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
+#     [
+#         ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+#         ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+#         ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+#         ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+#         ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+#         ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
+#     ],
+# )
+# @pytest.mark.parametrize("a_major", ["k"])
+# @pytest.mark.parametrize("b_major", ["k"])
+# @pytest.mark.parametrize("c_major", ["n"])
+# @pytest.mark.parametrize("fuse_alpha", [False, True])
+# @pytest.mark.parametrize("alpha_dtype", ["float32"])
+# @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
+# @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
+# @pytest.mark.parametrize("sm_count", [132, None])
+# @pytest.mark.parametrize("tolerance", [1e-01])
+# @pytest.mark.parametrize("iterations", [3])
+# @pytest.mark.parametrize("enable_dst_signals", [False, True])
+
+# ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
+

 @pytest.mark.skipif(
     not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
@@ -460,21 +519,36 @@ def multi_process_parallel(
 @pytest.mark.parametrize(
     "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
     [
-        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
-        # Add more combinations as needed
+        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
+        # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
+        # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
+        # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
+        # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
+        # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
     ],
 )
 @pytest.mark.parametrize("a_major", ["k"])
 @pytest.mark.parametrize("b_major", ["k"])
 @pytest.mark.parametrize("c_major", ["n"])
-@pytest.mark.parametrize("fuse_alpha", [False])
+@pytest.mark.parametrize("fuse_alpha", [False, True])
 @pytest.mark.parametrize("alpha_dtype", ["float32"])
 @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
 @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
 @pytest.mark.parametrize("sm_count", [148])
 @pytest.mark.parametrize("tolerance", [1e-01])
 @pytest.mark.parametrize("iterations", [1])
-@pytest.mark.parametrize("enable_dst_signals", [True])
+@pytest.mark.parametrize("enable_dst_signals", [False, True])
 @pytest.mark.parametrize("all_reduce", ["two_shot"])
 def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
     world_size,
@@ -527,4 +601,4 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
             all_reduce,
         ),
     )
-    print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")
+    print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")
