
Commit 7a0f933

anshul-si authored and pytorchmergebot committed
[FSDP][Replicate] tests replicate casting module after init (pytorch#162636)
**Summary:** To ensure that replicate behaves as intended (a specialized version of HSDP), we need to make sure it can pass the same training tests that fully_shard can. This test verifies that we can cast a replicated module to a different dtype after initialization, an important feature for enabling mixed precision.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_to_float64_after_init

Pull Request resolved: pytorch#162636
Approved by: https://github.com/mori360
ghstack dependencies: pytorch#162631
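To illustrate the feature under test, here is a minimal sketch of the cast-after-init pattern: apply replicate to a module, then cast it to another dtype and check that the change propagates to every parameter. This is a hedged illustration, not the test itself: the toy nn.Sequential model, the single-rank gloo process group, and the replicate import path are assumptions for this sketch, while the real test uses the MLP helper, DTensor checks, and multi-rank setup shown in the diff below.

# Minimal sketch of casting a replicated module after init.
# Assumptions: replicate import path, toy model, 1-rank gloo group on CPU.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.replicate import replicate  # assumed import path

# A single-process group so replicate() has a process group to attach to.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
replicate(model)         # apply replicate first, with the default fp32 parameters
model.to(torch.float64)  # then cast the whole module after initialization

# The dtype change should be visible on every parameter of the replicated module.
for param in model.parameters():
    assert param.dtype == torch.float64

dist.destroy_process_group()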
1 parent 63276ed commit 7a0f933

File tree

1 file changed: +63 -2 lines changed


test/distributed/_composable/test_replicate_training.py

Lines changed: 63 additions & 2 deletions
@@ -10,8 +10,13 @@
 from torch.distributed.fsdp import FSDPModule
 from torch.distributed.tensor import DTensor
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_fsdp import (
+    check_sharded_parity,
+    FSDPTestMultiThread,
+    get_devtype,
+    MLP,
+)
+from torch.testing._internal.common_utils import run_tests, wrapSwapTensorsTest


 c10d_ops = torch.ops.c10d
@@ -169,5 +174,61 @@ def _assert_same_params(
         self.assertEqual(param, ref_param)


+class TestReplicateCastAfterInit(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(1)
+    @wrapSwapTensorsTest(True)
+    def test_to_float64_after_init(self):
+        """Tests that the user can cast the module to float64 after init."""
+        # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
+        # better numerics. The important part is changing the dtype.
+
+        torch.manual_seed(42)
+        mlp_dim, device, dtype = 4, device_type, torch.float64
+        model = MLP(mlp_dim, device=device)
+        for param in model.parameters():
+            dist.broadcast(param, src=0)
+        ref_model = copy.deepcopy(model).to(dtype)
+
+        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
+        for module in (model.in_proj, model.out_proj, model):
+            replicate(module)
+        model.to(dtype)
+        for param in model.parameters():
+            self.assertEqual(param.dtype, dtype)
+            self.assertEqual(param.to_local().dtype, dtype)
+            self.assertEqual(param._spec.tensor_meta.dtype, dtype)
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
+        check_sharded_parity(self, ref_model, model)
+        torch.manual_seed(42 + self.rank + 1)
+        inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype)
+        for iter_idx in range(10):
+            losses: list[torch.Tensor] = []
+            for _model in (ref_model, model):
+                losses.append(_model(inp).sum())
+                losses[-1].backward()
+
+            for param in ref_model.parameters():
+                if param.grad is not None:
+                    dist.all_reduce(param.grad)
+                    param.grad.div_(self.world_size)
+
+            self.assertEqual(losses[0], losses[1])
+            check_sharded_parity(self, ref_model, model)
+            for param in model.parameters():
+                self.assertEqual(param.dtype, dtype)
+                self.assertEqual(param.to_local().dtype, dtype)
+                self.assertEqual(param._spec.tensor_meta.dtype, dtype)
+                self.assertEqual(param.grad.dtype, dtype)
+                self.assertEqual(param.grad.to_local().dtype, dtype)
+                self.assertEqual(param.grad._spec.tensor_meta.dtype, dtype)
+            for _optim in (ref_optim, optim):
+                _optim.step()
+                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+
+
 if __name__ == "__main__":
     run_tests()
