Skip to content

Commit 00958a9

Browse files
committed
wip
1 parent 97d8cb5 commit 00958a9

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,45 +28,51 @@
2828

2929

3030
def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
31-
torch_tensor_cpu_lmn = torch_tensor_cpu.permute(2, 0, 1).contiguous()
32-
# torch_symm_tensor = symm_mem.empty(
33-
# torch_tensor_cpu_lmn, device="cuda", dtype=torch_tensor_cpu.dtype
34-
# )
35-
torch_symm_tensor_lmn = symm_mem.empty(
36-
torch_tensor_cpu_lmn.shape, # (l, m, n)
37-
device="cuda",
38-
dtype=torch_tensor_cpu.dtype
31+
m, n, l = torch_tensor_cpu.shape
32+
33+
# Create flat symm_mem buffer
34+
total_elements = m * n * l
35+
torch_symm_flat = symm_mem.empty(
36+
(total_elements,), device="cuda", dtype=torch_tensor_cpu.dtype
3937
)
40-
41-
torch_symm_tensor_lmn.copy_(torch_tensor_cpu_lmn)
42-
torch_symm_tensor = torch_symm_tensor_lmn.permute(1, 2, 0)
43-
symm = symm_mem.rendezvous(torch_symm_tensor, group=dist.group.WORLD.group_name)
38+
39+
# Reshape to match input's stride pattern using as_strided
40+
torch_symm_tensor = torch_symm_flat.as_strided(
41+
size=torch_tensor_cpu.shape,
42+
stride=torch_tensor_cpu.stride()
43+
)
44+
torch_symm_tensor.copy_(torch_tensor_cpu)
45+
46+
symm = symm_mem.rendezvous(torch_symm_flat, group=dist.group.WORLD.group_name)
4447
mc_ptr = symm.multicast_ptr
45-
# create MC tensor memref
46-
torch_tensor_mc = cutlass_torch.as_tensor(mc_ptr, torch_tensor_cpu.shape, torch_tensor_cpu.dtype)
47-
cute_tensor_mc = from_dlpack(
48-
torch_tensor_mc,
49-
assumed_align=16,
48+
49+
# Create MC tensor with same stride
50+
torch_tensor_mc_flat = cutlass_torch.as_tensor(mc_ptr, (total_elements,), torch_tensor_cpu.dtype)
51+
torch_tensor_mc = torch_tensor_mc_flat.as_strided(
52+
size=torch_tensor_cpu.shape,
53+
stride=torch_tensor_cpu.stride()
5054
)
51-
# if is_dynamic_layout:
52-
# cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
55+
56+
cute_tensor_mc = from_dlpack(torch_tensor_mc, assumed_align=16)
57+
5358
if is_dynamic_layout:
5459
for i, stride in enumerate(torch_tensor_mc.stride()):
5560
if stride == 1:
5661
leading_dim = i
5762
break
5863
cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
64+
5965
torch_tensor_gpu = torch_symm_tensor
6066
cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16)
6167
cute_tensor.element_type = dtype
62-
# if is_dynamic_layout:
63-
# cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
68+
6469
if is_dynamic_layout:
65-
for i, stride in enumerate(torch_tensor_mc.stride()):
70+
for i, stride in enumerate(torch_tensor_gpu.stride()):
6671
if stride == 1:
6772
leading_dim = i
6873
break
6974
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
75+
7076
cute_tensor = cutlass_torch.convert_cute_tensor(
7177
torch_tensor_gpu,
7278
cute_tensor,

0 commit comments

Comments
 (0)