
Commit ec92b1a

fix: model.set_requires_gradient_sync(False) should be called to turn off gradient synchronization in FSDP2 (#3762)
* fix: `model.set_requires_gradient_sync(False)` should be called to turn off gradient synchronization in FSDP2.
* fix: remove trailing whitespace
1 parent 62ede1e · commit ec92b1a

File tree: 1 file changed (+13, −6 lines)

src/accelerate/accelerator.py

Lines changed: 13 additions & 6 deletions
@@ -1169,13 +1169,20 @@ def no_sync(self, model):
         >>> optimizer.zero_grad()
         ```
         """
-        context = contextlib.nullcontext
-        if self.use_distributed:
-            if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
-                context = getattr(model, "no_sync", context)
+        if self.is_fsdp2:
+            model.set_requires_gradient_sync(False)
+            try:
+                yield
+            finally:
+                model.set_requires_gradient_sync(True)
+        else:
+            context = contextlib.nullcontext
+            if self.use_distributed:
+                if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
+                    context = getattr(model, "no_sync", context)
 
-        with context():
-            yield
+            with context():
+                yield
 
     @staticmethod
     @contextmanager
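
For reference, a minimal sketch of how `accelerator.no_sync` is typically used during gradient accumulation. The model/optimizer/dataloader setup and the `accumulation_steps` value are illustrative assumptions, not part of this commit; only the `no_sync` behavior reflects the change above.

    # Hypothetical usage sketch (not from this commit): skip gradient
    # synchronization on every micro-batch except the last one in an
    # accumulation window. Assumes model, optimizer, and dataloader are
    # already constructed before being passed to accelerator.prepare().
    from accelerate import Accelerator

    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    accumulation_steps = 4  # illustrative value
    for step, batch in enumerate(dataloader):
        if (step + 1) % accumulation_steps != 0:
            # With this fix, no_sync() calls model.set_requires_gradient_sync(False)
            # under FSDP2, since FSDP2-wrapped modules expose that setter rather
            # than the no_sync attribute the old getattr() lookup relied on.
            with accelerator.no_sync(model):
                loss = model(**batch).loss
                accelerator.backward(loss)
        else:
            loss = model(**batch).loss
            accelerator.backward(loss)  # gradients synchronize here
            optimizer.step()
            optimizer.zero_grad()

Without this fix, `getattr(model, "no_sync", context)` silently fell back to a null context under FSDP2, so gradients were synchronized on every micro-batch; the `try`/`finally` ensures synchronization is re-enabled even if the wrapped block raises.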
