Skip to content

Commit 95871f0

Browse files
committed
Add a small repro
1 parent 4d1b431 commit 95871f0

File tree

1 file changed

+58
-1
lines changed

1 file changed

+58
-1
lines changed

tests/python/multidevice/test_multidevice.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def test_sizes_and_ranks(multidevice_test):
2929
@pytest.mark.mpi
3030
def test_pointwise(multidevice_test):
3131
num_devices = multidevice_test.size
32-
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))
3332

3433
with FusionDefinition() as fd:
3534
inp_tv = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float)
3635
tv1 = fd.ops.relu(inp_tv)
3736
tv2 = fd.ops.add(tv1, tv1)
3837
fd.add_output(tv2)
3938

39+
mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))
4040
for tv in [inp_tv, tv1, tv2]:
4141
tv.set_device_mesh(mesh)
4242

@@ -50,6 +50,63 @@ def test_pointwise(multidevice_test):
5050
torch.testing.assert_close(out.cpu(), out_ref)
5151

5252

53+
@pytest.mark.mpi
def test_transpose(multidevice_test):
    """Small repro: run a copy fusion whose output uses a permuted allocation domain.

    The fusion is a single ``fd.ops.set`` on a 5-D bfloat16 input. Input and
    output share a 3-D device mesh of shape ``(dp_size, cp_size, cp_size)`` and
    are scheduled with the same split/parallelize sequence; only the output
    gets an explicit allocation domain.

    The test is skipped unless the number of devices is a multiple of
    ``cp_size * cp_size``. It only executes the fusion; nothing is asserted
    about the result.
    """
    d = multidevice_test.size
    cp_size = 2
    dp_size, leftover = divmod(d, cp_size * cp_size)
    if leftover != 0:
        pytest.skip(
            f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
        )

    c = 128
    with FusionDefinition() as fd:
        inp_tv = fd.define_tensor(
            (-1, c, -1, -1, cp_size), contiguity=True, dtype=DataType.BFloat16
        )
        out_tv = fd.ops.set(inp_tv)
        fd.add_output(out_tv)

        device_ids = torch.arange(d).reshape(dp_size, cp_size, cp_size)
        mesh = nvfuser.multidevice.DeviceMesh(device_ids)
        inp_tv.set_device_mesh(mesh)
        out_tv.set_device_mesh(mesh)

        # Input and output receive the identical scheduling sequence. The
        # statement order is kept exactly as-is: axis indices are looked up
        # again after every split, so the calls are order-dependent.
        parallel_type = nvfuser.ParallelType
        for tv in (inp_tv, out_tv):
            tv.axis(4).parallelize(parallel_type.mesh_y)
            tv.outer_split(3, cp_size)
            tv.axis(3).parallelize(parallel_type.mesh_x)
            tv.outer_split(0, dp_size)
            tv.axis(0).parallelize(parallel_type.mesh_z)

        # After the two splits the output is addressed through seven axes
        # (indices 0-6); the allocation order moves axis 3 to the front.
        allocation_order = (3, 0, 1, 2, 4, 5, 6)
        out_tv.set_allocation_domain(
            tuple(out_tv.axis(index) for index in allocation_order),
            True,
        )

    b = dp_size * 3
    s = cp_size * 5
    inp_ref = torch.randn(b, c, s, s, cp_size, dtype=torch.bfloat16)
    # NOTE(review): out_ref is never compared against the executed output.
    out_ref = inp_ref

    inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
    fd.execute([inp])
108+
109+
53110
class QkvFormat(Enum):
54111
BHSE = auto()
55112
BSHE = auto()

0 commit comments

Comments
 (0)