 10 |  10 | from torch.distributed.fsdp import FSDPModule
 11 |  11 | from torch.distributed.tensor import DTensor
 12 |  12 | from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 13 |     | -from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
 14 |     | -from torch.testing._internal.common_utils import run_tests
     |  13 | +from torch.testing._internal.common_fsdp import (
     |  14 | +    check_sharded_parity,
     |  15 | +    FSDPTestMultiThread,
     |  16 | +    get_devtype,
     |  17 | +    MLP,
     |  18 | +)
     |  19 | +from torch.testing._internal.common_utils import run_tests, wrapSwapTensorsTest
 15 |  20 |
 16 |  21 |
 17 |  22 | c10d_ops = torch.ops.c10d

@@ -169,5 +174,61 @@ def _assert_same_params(
169 | 174 |             self.assertEqual(param, ref_param)
170 | 175 |
171 | 176 |
    | 177 | +class TestReplicateCastAfterInit(FSDPTestMultiThread):
    | 178 | +    @property
    | 179 | +    def world_size(self) -> int:
    | 180 | +        return 2
    | 181 | +
    | 182 | +    @skip_if_lt_x_gpu(1)
    | 183 | +    @wrapSwapTensorsTest(True)
    | 184 | +    def test_to_float64_after_init(self):
    | 185 | +        """Tests that the user can cast the module to float64 after init."""
    | 186 | +        # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
    | 187 | +        # better numerics. The important part is changing the dtype.
    | 188 | +
    | 189 | +        torch.manual_seed(42)
    | 190 | +        mlp_dim, device, dtype = 4, device_type, torch.float64
    | 191 | +        model = MLP(mlp_dim, device=device)
    | 192 | +        for param in model.parameters():
    | 193 | +            dist.broadcast(param, src=0)
    | 194 | +        ref_model = copy.deepcopy(model).to(dtype)
    | 195 | +
    | 196 | +        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
    | 197 | +        for module in (model.in_proj, model.out_proj, model):
    | 198 | +            replicate(module)
    | 199 | +        model.to(dtype)
    | 200 | +        for param in model.parameters():
    | 201 | +            self.assertEqual(param.dtype, dtype)
    | 202 | +            self.assertEqual(param.to_local().dtype, dtype)
    | 203 | +            self.assertEqual(param._spec.tensor_meta.dtype, dtype)
    | 204 | +        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    | 205 | +        check_sharded_parity(self, ref_model, model)
    | 206 | +        torch.manual_seed(42 + self.rank + 1)
    | 207 | +        inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype)
    | 208 | +        for iter_idx in range(10):
    | 209 | +            losses: list[torch.Tensor] = []
    | 210 | +            for _model in (ref_model, model):
    | 211 | +                losses.append(_model(inp).sum())
    | 212 | +                losses[-1].backward()
    | 213 | +
    | 214 | +            for param in ref_model.parameters():
    | 215 | +                if param.grad is not None:
    | 216 | +                    dist.all_reduce(param.grad)
    | 217 | +                    param.grad.div_(self.world_size)
    | 218 | +
    | 219 | +            self.assertEqual(losses[0], losses[1])
    | 220 | +            check_sharded_parity(self, ref_model, model)
    | 221 | +            for param in model.parameters():
    | 222 | +                self.assertEqual(param.dtype, dtype)
    | 223 | +                self.assertEqual(param.to_local().dtype, dtype)
    | 224 | +                self.assertEqual(param._spec.tensor_meta.dtype, dtype)
    | 225 | +                self.assertEqual(param.grad.dtype, dtype)
    | 226 | +                self.assertEqual(param.grad.to_local().dtype, dtype)
    | 227 | +                self.assertEqual(param.grad._spec.tensor_meta.dtype, dtype)
    | 228 | +            for _optim in (ref_optim, optim):
    | 229 | +                _optim.step()
    | 230 | +                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
    | 231 | +
    | 232 | +
172 | 233 | if __name__ == "__main__":
173 | 234 |     run_tests()
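
For readers skimming the diff, here is a minimal sketch of the user-facing pattern the new test exercises: apply replicate() to each submodule and then to the root, and only afterwards cast the module with .to(). The import path for replicate, the model definition, and the process-group setup are assumptions for illustration and are not taken from this diff.

# Minimal sketch of the cast-after-init pattern exercised by the test above.
# Assumptions: a default process group is already initialized (e.g. launched
# via torchrun), and `replicate` is importable from torch.distributed._composable.
import torch
import torch.nn as nn
from torch.distributed._composable import replicate  # assumed import path

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for submodule in model:
    replicate(submodule)  # install replicate on each submodule first
replicate(model)  # then on the root module
model.to(torch.float64)  # cast after init; parameters follow the new dtype
assert next(model.parameters()).dtype == torch.float64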