|
 dev_type = torch.device(get_devtype())


+class PytreeTuple:
+    """
+    Tuple-like values that are treated as leaves of a PyTree.
+    """
+
+    def __init__(self, *values):
+        self._values = tuple(values)
+
+    def __repr__(self):
+        pr = repr(self._values)[1:-1]
+        return f"{type(self).__name__}({pr})"
+
+    def __getitem__(self, i):
+        return self._values[i]
+
+    def __iter__(self):
+        return iter(self._values)
+
+    def __len__(self):
+        return len(self._values)
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, self.__class__):
+            return self._values == other._values
+        elif isinstance(other, tuple):
+            return self._values == other
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self._values)
+    def __add__(self, other):
+        if isinstance(other, (self.__class__, tuple)):
+            return self.__class__(*self, *other)
+        return NotImplemented
+
+    def __radd__(self, other):
+        if isinstance(other, (self.__class__, tuple)):
+            return self.__class__(*other, *self)
+        return NotImplemented
| 106 | + |
| 107 | + def index(self, value): |
| 108 | + return self._values.index(value) |
| 109 | + |
| 110 | + def count(self, value): |
| 111 | + return self._values.count(value) |
| 112 | + |
| 113 | + |
 class SimpleModel(nn.Module):
     def __init__(self, device):
         super().__init__()
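
Aside (not part of the diff): the reason a helper like this is needed is that
torch.utils._pytree recursively flattens plain tuples and lists, while an
unregistered container class is treated as a single leaf. A minimal sketch of
that property, assuming the PytreeTuple defined above (note that
torch.utils._pytree is a private API and may change):

    import torch.utils._pytree as pytree

    # Plain tuples are pytree containers: flattening recurses into them.
    leaves, _ = pytree.tree_flatten((1, 2))
    assert leaves == [1, 2]

    # PytreeTuple is intentionally not registered with pytree, so the
    # whole container survives flattening as one opaque leaf.
    leaves, _ = pytree.tree_flatten(PytreeTuple(1, 2))
    assert leaves == [PytreeTuple(1, 2)]
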
@@ -767,6 +815,37 @@ def fn(x): |
         # this fails with an inductor stride assert
         out_dt.to_local().sum().backward()

+    def test_dynamo_to_local_grad_placements_sequence(self):
+        placements = PytreeTuple(Shard(0))
+
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            return x.to_local(grad_placements=placements) + 2
+
+        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        x = torch.ones(4)
+        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
+
+        out_ref = fn(dt)
+        out_test = fn_opt(dt)
+        self.assertEqual(out_ref, out_test)
+
+    def test_dynamo_to_local_grad_placements_sequence_intermediate(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            placements = PytreeTuple(Shard(0))
+            return x.to_local(grad_placements=placements) + 2
+
+        fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
+        x = torch.ones(4)
+        dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
+
+        out_ref = fn(dt)
+        out_test = fn_opt(dt)
+        self.assertEqual(out_ref, out_test)
+
     def test_dynamo_to_local_kwargs(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

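
For reference (a sketch, not part of the diff): outside torch.compile the same
call already works eagerly, since DTensor.to_local accepts any sequence of
Placement objects for grad_placements. Assuming an initialized process group
and a mesh built as in the tests above:

    # grad_placements tells to_local() which placements the incoming
    # gradient should be interpreted with during backward; a PytreeTuple
    # should be accepted anywhere a list like [Shard(0)] is.
    dt = DTensor.from_local(torch.ones(4), mesh, [Replicate()], run_check=False)
    local = dt.to_local(grad_placements=PytreeTuple(Shard(0)))
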
|
|