Skip to content

Commit b03774f

Browse files
committed
wip
1 parent 55ce566 commit b03774f

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,18 @@
2828

2929

3030
def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
31-
torch_symm_tensor = symm_mem.empty(
32-
torch_tensor_cpu.shape, device="cuda", dtype=torch_tensor_cpu.dtype
31+
torch_tensor_cpu_lmn = torch_tensor_cpu.permute(2, 0, 1).contiguous()
32+
# torch_symm_tensor = symm_mem.empty(
33+
# torch_tensor_cpu_lmn, device="cuda", dtype=torch_tensor_cpu.dtype
34+
# )
35+
torch_symm_tensor_lmn = symm_mem.empty(
36+
torch_tensor_cpu_lmn.shape, # (l, m, n)
37+
device="cuda",
38+
dtype=torch_tensor_cpu.dtype
3339
)
34-
torch_symm_tensor.copy_(torch_tensor_cpu)
40+
41+
torch_symm_tensor_lmn.copy_(torch_tensor_cpu)
42+
torch_symm_tensor = torch_symm_tensor_lmn.permute(1, 2, 0)
3543
symm = symm_mem.rendezvous(torch_symm_tensor, group=dist.group.WORLD.group_name)
3644
mc_ptr = symm.multicast_ptr
3745
# create MC tensor memref

0 commit comments

Comments
 (0)