
Commit 962f13f

anijain2305 authored and pytorchmergebot committed
[compile][to_local] Support Sequence-like placement user defined objects (pytorch#168149)
grad_placements is a sequence-like data structure and can therefore be a UserDefinedObject. In that case, we can extract the underlying tuple and pass it along.

Pull Request resolved: pytorch#168149
Approved by: https://github.com/bdhirsh
1 parent eefc0f8 commit 962f13f
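For context, here is a minimal end-to-end sketch of the scenario this commit enables, modeled on the new tests below. It is not code from the commit: the PlacementSeq class and the single-rank gloo/CPU setup are illustrative assumptions, used only to show a user-defined sequence-like object passed as grad_placements to DTensor.to_local inside a torch.compile'd function.

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard


class PlacementSeq:
    # Hypothetical sequence-like container (not a tuple subclass), analogous
    # to the PytreeTuple helper added in the test file below.
    def __init__(self, values):
        self._values = tuple(values)

    def __getitem__(self, i):
        return self._values[i]

    def __iter__(self):
        return iter(self._values)

    def __len__(self):
        return len(self._values)


# Single-rank process group purely so DeviceMesh/DTensor can be constructed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

dt = DTensor.from_local(torch.ones(4), mesh, [Replicate()], run_check=False)
placements = PlacementSeq([Shard(0)])


def fn(d):
    # grad_placements is a user-defined sequence-like object, not a plain tuple;
    # with this change, dynamo can extract the underlying tuple while tracing.
    return d.to_local(grad_placements=placements) + 2


out_eager = fn(dt)
out_compiled = torch.compile(fn, backend="aot_eager", fullgraph=True)(dt)
torch.testing.assert_close(out_eager, out_compiled)
dist.destroy_process_group()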

File tree

test/distributed/tensor/test_dtensor_compile.py
torch/_dynamo/variables/tensor.py

2 files changed: +92, -0 lines changed

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 79 additions & 0 deletions
@@ -63,6 +63,54 @@
 dev_type = torch.device(get_devtype())
 
 
+class PytreeTuple:
+    """
+    Tuple-like values that are treated as leaves of a PyTree.
+    """
+
+    def __init__(self, *values):
+        self._values = tuple(values)
+
+    def __repr__(self):
+        pr = repr(self._values)[1:-1]
+        return f"{type(self).__name__}({pr})"
+
+    def __getitem__(self, i):
+        return self._values[i]
+
+    def __iter__(self):
+        return iter(self._values)
+
+    def __len__(self):
+        return len(self._values)
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
+            return self._values == other._values
+        elif isinstance(other, tuple):
+            return self._values == other
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self._values)
+
+    def __add__(self, other):
+        if isinstance(other, (self.__class__, tuple)):
+            return self.__class__(*self, *other)
+        raise NotImplementedError(type(other))
+
+    def __radd__(self, other):
+        if isinstance(other, (self.__class__, tuple)):
+            return self.__class__(*other, *self)
+        raise NotImplementedError(type(other))
+
+    def index(self, value):
+        return self._values.index(value)
+
+    def count(self, value):
+        return self._values.count(value)
+
+
 class SimpleModel(nn.Module):
     def __init__(self, device):
         super().__init__()

@@ -767,6 +815,37 @@ def fn(x):
         # this fails with an inductor stride assert
         out_dt.to_local().sum().backward()
 
+    def test_dynamo_to_local_grad_placements_sequence(self):
+        placements = PytreeTuple([Shard(0)])
+
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            return dt.to_local(grad_placements=placements) + 2
+
+        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        x = torch.ones(4)
+        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
+
+        out_ref = fn(dt)
+        out_test = fn_opt(dt)
+        self.assertEqual(out_ref, out_test)
+
+    def test_dynamo_to_local_grad_placements_sequence_intermediate(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            placements = PytreeTuple([Shard(0)])
+            return dt.to_local(grad_placements=placements) + 2
+
+        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        x = torch.ones(4)
+        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
+
+        out_ref = fn(dt)
+        out_test = fn_opt(dt)
+        self.assertEqual(out_ref, out_test)
+
     def test_dynamo_to_local_kwargs(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

torch/_dynamo/variables/tensor.py

Lines changed: 13 additions & 0 deletions
@@ -1266,6 +1266,19 @@ def method_to_local(self, *args, **kwargs):
         tx = InstructionTranslator.current_tx()
         # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
         # and rewrite args to have only proxyable args, then insert call_function
+
+        grad_placements_vt = kwargs.get(
+            "grad_placements", ConstantVariable.create(None)
+        )
+        if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable):
+            # grad_placement is a sequence-like structure, iterate over the value
+            grad_placements_vt = variables.BuiltinVariable(tuple).call_function(
+                tx, [grad_placements_vt], {}
+            )
+
+        if kwargs.get("grad_placements") is not None:
+            kwargs["grad_placements"] = grad_placements_vt
+
         args_as_value = [x.as_python_constant() for x in args]
         kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

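In plain Python terms, the dynamo change above amounts to calling tuple() on the sequence-like object, which only requires that the object be iterable. A tiny illustrative sketch follows; Seq is a made-up class, not part of the PR.

class Seq:
    # Minimal hypothetical sequence-like object: not a tuple, but iterable.
    def __init__(self, *values):
        self._values = tuple(values)

    def __iter__(self):
        return iter(self._values)

    def __len__(self):
        return len(self._values)


# tuple() walks __iter__, which is what the traced
# variables.BuiltinVariable(tuple).call_function(tx, [grad_placements_vt], {})
# models for the grad_placements argument.
assert tuple(Seq("a", "b")) == ("a", "b")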