Commit c614128

tianyu-l authored and pytorchmergebot committed
[DTensor] support Replicate -> Partial("avg") + support distribute_tensor with Partial placements (pytorch#168133)
Pull Request resolved: pytorch#168133 Approved by: https://github.com/ezyang
1 parent 9bca3c1 commit c614128
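
As a rough usage sketch of what this change enables (not taken from the PR; the mesh setup and tensor values below are illustrative and assume an already-initialized 1-D device mesh of 4 ranks):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Partial

    # Assumed setup: a 1-D mesh over 4 ranks ("cpu" and the world size are illustrative).
    mesh = init_device_mesh("cpu", (4,))

    x = torch.arange(8.0)

    # distribute_tensor now accepts Partial placements: the tensor is replicated on the
    # mesh dimension and then "partitioned" so that the reduce op reconstructs the value.
    dt_sum = distribute_tensor(x, mesh, [Partial(reduce_op="sum")])  # local = x / 4 on each rank
    dt_avg = distribute_tensor(x, mesh, [Partial(reduce_op="avg")])  # local = x on each rank

    # Reducing the partial values back should recover x in both cases.
    torch.testing.assert_close(dt_sum.full_tensor(), x)
    torch.testing.assert_close(dt_avg.full_tensor(), x)

    # Unsupported reduce ops such as "prod" raise ValueError, per the new test in test_api.py.

Since Partial now prints its reduce op, the DebugMode expectations in the diffs below change from a bare P to P(sum) or P(avg).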

7 files changed: +89 -35 lines changed


test/distributed/tensor/debug/test_debug_mode.py

Lines changed: 23 additions & 23 deletions
@@ -76,7 +76,7 @@ def test_debug_mode_mm(self):
   _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
   _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
   aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
-  <method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P
+  <method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P(sum)
   aten::sum(dt$6: f32[8, 32]| S(0))
   aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""",
   )
@@ -179,8 +179,8 @@ def test_debug_mode_backward(self):
   <method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8]| S(0))
   aten::sum(dt: f32[8, 8]| S(0))
   aten::sum(t: f32[1, 8])
-  torch._tensor.backward(dt: f32[]| P, gradient=None, retain_graph=None, create_graph=False, inputs=None)
-  aten::ones_like(dt: f32[]| P, pin_memory=False, memory_format=torch.preserve_format)
+  torch._tensor.backward(dt: f32[]| P(sum), gradient=None, retain_graph=None, create_graph=False, inputs=None)
+  aten::ones_like(dt: f32[]| P(sum), pin_memory=False, memory_format=torch.preserve_format)
   aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
   aten::expand(dt: f32[]| R, [8, 8])
   aten::expand(t: f32[], [8, 8])
@@ -189,9 +189,9 @@ def test_debug_mode_backward(self):
   aten::clone(t: f32[8, 1])
   aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
   redistribute_input(t: f32[8, 8], trace: R->S(0))
-  aten::detach(t: f32[8, 1])
   aten::split.Tensor(t: f32[8, 8], 1)
   aten::clone(t: f32[1, 8])
+  aten::detach(t: f32[8, 1])
   aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
   aten::detach(t: f32[1, 8])""",
   )
@@ -253,50 +253,50 @@ def test_debug_mode_einsum(self):
   self.assertExpectedInline(
   debug_mode.debug_string(),
   """\
-  torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| PR, dt: f32[8, 4, 4]| RP)
-  aten::unsqueeze(dt: f32[16, 6, 8]| PR, 3)
+  torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| P(sum)R, dt: f32[8, 4, 4]| RP(sum))
+  aten::unsqueeze(dt: f32[16, 6, 8]| P(sum)R, 3)
   aten::unsqueeze(t: f32[16, 6, 8], 3)
-  aten::unsqueeze(dt: f32[16, 6, 8, 1]| PR, 4)
+  aten::unsqueeze(dt: f32[16, 6, 8, 1]| P(sum)R, 4)
   aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
-  aten::permute(dt: f32[16, 6, 8, 1, 1]| PR, [0, 1, 3, 4, 2])
+  aten::permute(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [0, 1, 3, 4, 2])
   aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
-  aten::unsqueeze(dt: f32[8, 4, 4]| RP, 3)
+  aten::unsqueeze(dt: f32[8, 4, 4]| RP(sum), 3)
   aten::unsqueeze(t: f32[8, 4, 4], 3)
-  aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP, 4)
+  aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP(sum), 4)
   aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
-  aten::permute(dt: f32[8, 4, 4, 1, 1]| RP, [3, 4, 1, 2, 0])
+  aten::permute(dt: f32[8, 4, 4, 1, 1]| RP(sum), [3, 4, 1, 2, 0])
   aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
-  aten::permute(dt: f32[16, 6, 1, 1, 8]| PR, [0, 1, 4, 2, 3])
+  aten::permute(dt: f32[16, 6, 1, 1, 8]| P(sum)R, [0, 1, 4, 2, 3])
   aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
-  aten::view(dt: f32[16, 6, 8, 1, 1]| PR, [1, 96, 8])
+  aten::view(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [1, 96, 8])
   aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
-  aten::permute(dt: f32[1, 1, 4, 4, 8]| RP, [4, 2, 3, 0, 1])
+  aten::permute(dt: f32[1, 1, 4, 4, 8]| RP(sum), [4, 2, 3, 0, 1])
   aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
-  aten::view(dt: f32[8, 4, 4, 1, 1]| RP, [1, 8, 16])
+  aten::view(dt: f32[8, 4, 4, 1, 1]| RP(sum), [1, 8, 16])
   aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
-  aten::bmm(dt: f32[1, 96, 8]| PR, dt: f32[1, 8, 16]| RP)
-  redistribute_input(0, PR -> S(2)[0]S(2)[1])
-  redistribute_input(t: f32[1, 96, 8], trace: PR->S(2)R->S(2)[0]S(2)[1])
+  aten::bmm(dt: f32[1, 96, 8]| P(sum)R, dt: f32[1, 8, 16]| RP(sum))
+  redistribute_input(0, P(sum)R -> S(2)[0]S(2)[1])
+  redistribute_input(t: f32[1, 96, 8], trace: P(sum)R->S(2)R->S(2)[0]S(2)[1])
   aten::chunk(t: f32[1, 96, 8], 4, 2)
   aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
   _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
   _c10d_functional::wait_tensor(t: f32[1, 96, 2])
   aten::chunk(t: f32[1, 96, 2], 2, 2)
   aten::clone(t: f32[1, 96, 1])
-  redistribute_input(1, RP -> S(1)[0]S(1)[1])
-  redistribute_input(t: f32[1, 8, 16], trace: RP->S(1)P->S(1)[0]S(1)[1])
+  redistribute_input(1, RP(sum) -> S(1)[0]S(1)[1])
+  redistribute_input(t: f32[1, 8, 16], trace: RP(sum)->S(1)P(sum)->S(1)[0]S(1)[1])
   aten::chunk(t: f32[1, 8, 16], 4, 1)
   aten::clone(t: f32[1, 2, 16])
   aten::chunk(t: f32[1, 2, 16], 2, 1)
   aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
   _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
   _c10d_functional::wait_tensor(t: f32[1, 1, 16])
   aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
-  aten::view(dt: f32[1, 96, 16]| PP, [16, 6, 1, 4, 4])
+  aten::view(dt: f32[1, 96, 16]| P(sum)P(sum), [16, 6, 1, 4, 4])
   aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
-  aten::permute(dt: f32[16, 6, 1, 4, 4]| PP, [0, 1, 3, 4, 2])
+  aten::permute(dt: f32[16, 6, 1, 4, 4]| P(sum)P(sum), [0, 1, 3, 4, 2])
   aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
-  aten::view(dt: f32[16, 6, 4, 4, 1]| PP, [16, 6, 4, 4])
+  aten::view(dt: f32[16, 6, 4, 4, 1]| P(sum)P(sum), [16, 6, 4, 4])
   aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""",
   )

test/distributed/tensor/test_api.py

Lines changed: 11 additions & 1 deletion
@@ -79,7 +79,13 @@ def test_distribute_tensor_rank(self):
   dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
   self.assertEqual(dist_tensor.placements[0].dim, 1)

-  placement_combs = [[Shard(0)], [Shard(1)], [Replicate()]]
+  placement_combs = [
+      [Shard(0)],
+      [Shard(1)],
+      [Replicate()],
+      [Partial(reduce_op="sum")],
+      [Partial(reduce_op="avg")],
+  ]

   if not self.is_local_tensor_enabled:
       # test src_data_rank == 1
@@ -125,6 +131,10 @@ def test_distribute_tensor_errors(self):
   shard_spec = [Shard(0)]
   distribute_tensor(tensor_to_distribute, device_mesh, shard_spec)

+  with self.assertRaisesRegex(ValueError, "conversion is not supported"):
+      new_spec = [Replicate(), Partial(reduce_op="prod")]
+      distribute_tensor(tensor_to_distribute, device_mesh, new_spec)
+
   with self.assertRaisesRegex(RuntimeError, "distribute leaf tensor"):
       shard_spec = [Shard(0)]
       global_tensor = torch.randn(*tensor_shape, requires_grad=True)

test/distributed/tensor/test_dtensor_ops.py

Lines changed: 1 addition & 1 deletion
@@ -725,7 +725,7 @@ def run_mean(self):
   self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))

   if is_evenly_shardable:
-      self.assertTrue("P->R" in debug_mode.debug_string())
+      self.assertTrue("P(avg)->R" in debug_mode.debug_string())
   else:
       self.assertTrue("S(0)->R" in debug_mode.debug_string())

test/distributed/tensor/test_pointwise_ops.py

Lines changed: 25 additions & 0 deletions
@@ -148,6 +148,30 @@ def test_partial_add(self):
   d_3 = d_1 + d_2
   self.assertTrue(d_3._spec.placements[0].is_partial())

+  def test_partial_replicate_add(self):
+      device_mesh = self.build_device_mesh()
+      comm_mode = CommDebugMode()
+
+      for reduce_op in ("sum", "avg"):
+          d_1 = DTensor.from_local(
+              torch.rand(2, 2),
+              device_mesh,
+              [Partial(reduce_op=reduce_op)],
+          )
+          d_2 = DTensor.from_local(
+              torch.rand(2, 1),
+              device_mesh,
+              [Replicate()],
+              run_check=True,
+          )
+
+          with comm_mode:
+              d_3 = d_1 + d_2
+
+          self.assertEqual(comm_mode.get_total_counts(), 0)
+          self.assertEqual(d_3.placements, (Partial(reduce_op=reduce_op),))
+          self.assertEqual(d_3.full_tensor(), d_1.full_tensor() + d_2.full_tensor())
+
   def test_activations(self):
       device_mesh = self.build_device_mesh()
       self._run_sharded_elementwise_ops(
@@ -247,6 +271,7 @@ def test_dropout_backward(self):
       ),
   )

+  @skip_unless_torch_gpu
   def test_dropout_errors(self):
       device_mesh = self.build_device_mesh()
       with self.assertRaisesRegex(RuntimeError, "supported"):

torch/distributed/tensor/_api.py

Lines changed: 5 additions & 0 deletions
@@ -818,6 +818,11 @@ def distribute_tensor(
       local_tensor = Replicate._make_replicate_tensor(
           local_tensor, device_mesh, idx, src_data_rank
       )
+  elif isinstance(placement, Partial):
+      local_tensor = Replicate._make_replicate_tensor(
+          local_tensor, device_mesh, idx, src_data_rank
+      )
+      local_tensor = placement._partition_value(local_tensor, device_mesh, idx)
   else:
       raise RuntimeError(
           f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"

torch/distributed/tensor/_ops/_math_ops.py

Lines changed: 10 additions & 0 deletions
@@ -163,6 +163,16 @@ def __eq__(self, other: object) -> bool:
   def __hash__(self) -> int:
       return 1 + hash(self.norm_type)

+  def __repr__(self) -> str:
+      """
+      machine readable representation of the _NormPartial placement
+      """
+      return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})"
+
+  def __str__(self) -> str:
+      """human readable representation of the _NormPartial placement"""
+      return f"_NormP({self.reduce_op}, {self.norm_type})"
+

   def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]:
       if dims_arg is None:

torch/distributed/tensor/placement_types.py

Lines changed: 14 additions & 10 deletions
@@ -816,14 +816,18 @@ def _partition_value(
   # Partial placement contract #3:
   # _partition_value: partition the value of a replicated tensor on the mesh dimension

-  # _partition_value is the conjugate operation of _reduce_value
-  # - i.e. _partition_value on a sum reduce op is just a division operation
-  # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation
-  # TODO: if the reduce_op is min/max, etc. the _partition_value should be a
-  # different operation
-  assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!"
+  # _partition_value is the conjugate operation of _reduce_value, e.g.
+  # - _partition_value on a sum reduce op is just a division operation
+  # - _reduce_value on a sum reduce op would just be a sum(allreduce) operation
   num_chunks = mesh.size(mesh_dim=mesh_dim)
-  return tensor / num_chunks
+  if self.reduce_op == "sum":
+      return tensor / num_chunks
+  elif self.reduce_op in ("avg", "min", "max"):
+      return tensor
+  else:
+      raise ValueError(
+          f"Replicate to Partial({self.reduce_op}) conversion is not supported."
+      )

   def __hash__(self) -> int:
       return 1 + hash(self.reduce_op)
@@ -838,7 +842,7 @@ def __str__(self) -> str:
   """
   human readable representation of the Partial placement
   """
-  return "P"
+  return f"P({self.reduce_op})"


   # We keep the old _Partial name for a while for BC reason
@@ -982,10 +986,10 @@ def __repr__(self) -> str:
   """
   machine readable representation of the MaskPartial placement
   """
-  return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
+  return f"MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"

   def __str__(self) -> str:
       """
       human readable representation of the MaskPartial placement
       """
-      return "MaskP"
+      return f"MaskP({self.reduce_op}, {self.offset_shape}, {self.offset_dim})"
