 
 
 def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True):
-    torch_tensor_cpu_lmn = torch_tensor_cpu.permute(2, 0, 1).contiguous()
-    # torch_symm_tensor = symm_mem.empty(
-    #     torch_tensor_cpu_lmn, device="cuda", dtype=torch_tensor_cpu.dtype
-    # )
-    torch_symm_tensor_lmn = symm_mem.empty(
-        torch_tensor_cpu_lmn.shape,  # (l, m, n)
-        device="cuda",
-        dtype=torch_tensor_cpu.dtype
+    m, n, l = torch_tensor_cpu.shape
+
+    # Create flat symm_mem buffer
+    total_elements = m * n * l
+    torch_symm_flat = symm_mem.empty(
+        (total_elements,), device="cuda", dtype=torch_tensor_cpu.dtype
     )
-
-    torch_symm_tensor_lmn.copy_(torch_tensor_cpu_lmn)
-    torch_symm_tensor = torch_symm_tensor_lmn.permute(1, 2, 0)
-    symm = symm_mem.rendezvous(torch_symm_tensor, group=dist.group.WORLD.group_name)
+
+    # Reshape to match input's stride pattern using as_strided
+    torch_symm_tensor = torch_symm_flat.as_strided(
+        size=torch_tensor_cpu.shape,
+        stride=torch_tensor_cpu.stride()
+    )
+    torch_symm_tensor.copy_(torch_tensor_cpu)
+
+    symm = symm_mem.rendezvous(torch_symm_flat, group=dist.group.WORLD.group_name)
     mc_ptr = symm.multicast_ptr
-    # create MC tensor memref
-    torch_tensor_mc = cutlass_torch.as_tensor(mc_ptr, torch_tensor_cpu.shape, torch_tensor_cpu.dtype)
-    cute_tensor_mc = from_dlpack(
-        torch_tensor_mc,
-        assumed_align=16,
+
+    # Create MC tensor with same stride
+    torch_tensor_mc_flat = cutlass_torch.as_tensor(mc_ptr, (total_elements,), torch_tensor_cpu.dtype)
+    torch_tensor_mc = torch_tensor_mc_flat.as_strided(
+        size=torch_tensor_cpu.shape,
+        stride=torch_tensor_cpu.stride()
     )
-    # if is_dynamic_layout:
-    #     cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
+
+    cute_tensor_mc = from_dlpack(torch_tensor_mc, assumed_align=16)
+
     if is_dynamic_layout:
         for i, stride in enumerate(torch_tensor_mc.stride()):
             if stride == 1:
                 leading_dim = i
                 break
         cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim)
+
     torch_tensor_gpu = torch_symm_tensor
     cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16)
     cute_tensor.element_type = dtype
-    # if is_dynamic_layout:
-    #     cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
+
     if is_dynamic_layout:
-        for i, stride in enumerate(torch_tensor_mc.stride()):
+        for i, stride in enumerate(torch_tensor_gpu.stride()):
             if stride == 1:
                 leading_dim = i
                 break
         cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
+
     cute_tensor = cutlass_torch.convert_cute_tensor(
         torch_tensor_gpu,
         cute_tensor,
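For context, a rough sketch of how `create_mc_tensor` might be driven from a torchrun-launched multi-GPU job. The environment handling, the (m, n, l) problem shape, and the `cutlass.Float16` dtype argument are assumptions for illustration, and the function's return value is not shown in this hunk, so the call result is left unused here:

```python
import os
import torch
import torch.distributed as dist
import cutlass

def main():
    # Assumption: launched via torchrun, which sets RANK/LOCAL_RANK.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    m, n, l = 256, 128, 1                                 # illustrative shape
    torch_tensor_cpu = torch.randn(m, n, l, dtype=torch.float16)

    # create_mc_tensor is assumed to be defined/imported as in the diff above;
    # its return structure is not visible in this hunk.
    create_mc_tensor(torch_tensor_cpu, cutlass.Float16, is_dynamic_layout=True)

    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```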