
Commit f8a2ce3

RohitRathore1 authored and pytorchmergebot committed
Fix inplace ops on Partial DTensors to preserve aliasing semantics (pytorch#164729)
Fixes pytorch#163374. Here is the output from the reproducible code:

```
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811]
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1006 09:09:26.329000 2457 /home/fedora/github/pytorch/torch/distributed/run.py:811] *****************************************
aten::clamp_(dt: f32[][R], None, 2)
redistribute_input(0, [P] -> [R])
redistribute_input(t: f32[], [P] -> [R])
_c10d_functional::all_reduce(t: f32[], sum, 0)
_c10d_functional::wait_tensor(t: f32[])
aten::clamp_(t: f32[], None, 2)
aten::view(t: f32[], [])
(Replicate(),)
tensor(2., device='cuda:0')
```

The behavior now matches what was expected in issue pytorch#163374.

Expected behavior (from the issue):
1. Placement should change from Partial(sum) to Replicate()
2. Value should be tensor(2.) instead of tensor(144.)

Actual output from this build:
1. (Replicate(),) - placement is correct
2. tensor(2., device='cuda:0') - value is correct

So the in-place operation now properly redistributes the partial DTensor to replicate before performing the clamp and maintains the correct aliasing semantics. It also produces the expected clamped value.

Pull Request resolved: pytorch#164729
Approved by: https://github.com/ezyang
1 parent e2c6834 commit f8a2ce3
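The "aliasing semantics" point can be illustrated without DTensor at all: a true in-place update is visible through existing views of the same storage, while replacing the storage (as a redistribution would) leaves stale views behind. A minimal stdlib-only sketch, with `buf` and `view` as illustrative stand-ins for a tensor and a view into it:

```python
# Illustrative sketch of aliasing semantics: a bytearray plays the role of
# tensor storage, and a memoryview plays the role of a view into it.
buf = bytearray(b"\x01\x02\x03\x04")
view = memoryview(buf)  # aliases buf's storage, like a tensor view

# A true in-place op mutates existing storage: the view observes the change.
buf[0] = 9
assert view[0] == 9

# "Redistribution" would allocate fresh storage and rebind the name; the old
# view still points at the original buffer and silently goes stale.
buf = bytearray(b"\x07\x07\x07\x07")  # rebinding, not in-place mutation
assert view[0] == 9  # the view did not follow the rebinding
```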

File tree

2 files changed: +37 −3 lines changed


test/distributed/tensor/test_pointwise_ops.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -331,6 +331,25 @@ def test_mul_partial(self):
         self.assertEqual(z.placements, (Replicate(),))
         self.assertEqual(z.to_local(), input)
 
+    def test_inplace_op_partial_to_replicate(self):
+        # test that in-place operations that require redistribution raise an error
+        # to preserve aliasing semantics (issue #163374)
+        device_mesh = self.build_device_mesh()
+
+        input_tensor = torch.tensor(64.0, device=self.device_type)
+        partial_dt = DTensor.from_local(
+            input_tensor, device_mesh, placements=(Partial(),)
+        )
+
+        self.assertTrue(partial_dt.placements[0].is_partial())
+
+        # Inplace ops that require placement changes (Partial -> Replicate) should error
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "in-place operations that require placement changes are not supported",
+        ):
+            partial_dt.clamp_(max=10)
+
 
 if __name__ == "__main__":
     run_tests()
```

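The new test leans on `unittest`'s `assertRaisesRegex` to pin down both the exception type and its message. The same pattern can be sketched in a self-contained way; `fake_clamp_` below is a hypothetical stand-in for the DTensor op, not code from the PR:

```python
import unittest

# Hypothetical stand-in for DTensor.clamp_ on a Partial placement; in the real
# code the error is raised inside DTensor's op dispatch, not here.
def fake_clamp_(placement):
    if placement == "Partial":
        raise RuntimeError(
            "in-place operations that require placement changes are not supported"
        )

class DemoTest(unittest.TestCase):
    def test_partial_inplace_raises(self):
        # Same shape as the new test: assert exception type and message together.
        with self.assertRaisesRegex(
            RuntimeError,
            "in-place operations that require placement changes are not supported",
        ):
            fake_clamp_("Partial")

suite = unittest.defaultTestLoader.loadTestsFromTestCase(DemoTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```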
torch/distributed/tensor/_dispatch.py

Lines changed: 18 additions & 3 deletions

```diff
@@ -337,19 +337,34 @@ def _dispatch_fast_path_python_tail(
     if is_inplace_op:
         # inplace op should return self instead of re-wrapping
         if output_sharding.output_spec is not None:
+            output_spec = output_sharding.output_spec
+            assert isinstance(output_spec, DTensorSpec)
+            assert isinstance(args[0], dtensor.DTensor)
+
             # NOTE: aten.squeeze_.dim is an inplace op but it also may change
             # the inplace argument's tensor meta. Here we choose to special case
             # this op because as far as I know this is the only inplace op that
             # has such as behavior. We can extend this special case if necessary.
             if op_call == aten.squeeze_.dim:
-                output_spec = output_sharding.output_spec
-                assert isinstance(output_spec, DTensorSpec)
-                assert isinstance(args[0], dtensor.DTensor)
+                # update the spec to handle tensor meta changes
                 args[0]._spec = output_spec
                 # use return_and_correct_aliasing to match the outer and the inner
                 # aliasing. See https://github.com/pytorch/pytorch/pull/158954
                 return return_and_correct_aliasing(op_call, args, kwargs, args[0])
             else:
+                # For all other inplace ops, check if placement changes are required
+                # Inplace operations that change placement are not supported because
+                # they would require redistribution, which breaks aliasing semantics.
+                # If there are views into the tensor, the views would not be updated.
+                if args[0]._spec.placements != output_spec.placements:
+                    raise RuntimeError(
+                        f"{op_call}: in-place operations that require placement changes "
+                        f"are not supported. The operation would change placement from "
+                        f"{args[0]._spec.placements} to {output_spec.placements}, "
+                        f"which requires redistribution and breaks aliasing semantics. "
+                        f"Please use the out-of-place version of this operation instead."
+                    )
+                # Most inplace ops don't change tensor meta, so no spec update needed
                 return args[0]
         else:
             return None
```

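The control flow of the new `else` branch can be sketched with plain Python objects. `Spec`, `inplace_dispatch_tail`, and the placement strings below are simplified stand-ins for `DTensorSpec` and the real dispatch tail, not the actual PyTorch API:

```python
# Simplified mock of the placement-change guard added to the inplace dispatch
# path; the real code compares DTensorSpec.placements tuples of Placement objects.
class Spec:
    def __init__(self, placements):
        self.placements = placements

def inplace_dispatch_tail(op_name, input_spec, output_spec):
    """Return the input spec unchanged, or raise if placements would change."""
    if input_spec.placements != output_spec.placements:
        raise RuntimeError(
            f"{op_name}: in-place operations that require placement changes "
            f"are not supported. The operation would change placement from "
            f"{input_spec.placements} to {output_spec.placements}."
        )
    return input_spec  # most inplace ops keep tensor meta, so no spec update

# Matching placements pass through; Partial -> Replicate is rejected.
ok = inplace_dispatch_tail("aten.add_", Spec(("Replicate",)), Spec(("Replicate",)))
try:
    inplace_dispatch_tail("aten.clamp_", Spec(("Partial",)), Spec(("Replicate",)))
    raised = False
except RuntimeError:
    raised = True
```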