
Commit 06eb3cc

Pass enabled down to _BackwardSyncControl (#19577)
1 parent 3740546 commit 06eb3cc

10 files changed: +35 −19 lines

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))
 
 
--
+- `_BackwardSyncControl` can now control what to do when gradient accumulation is disabled ([#19577](https://github.com/Lightning-AI/lightning/pull/19577))
+
 
 
 ### Deprecated
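At the call site, the flag enters through `fabric.no_backward_sync()`. A minimal, self-contained sketch of the gradient-accumulation pattern this hook serves, assuming a toy model and random data (none of this code is part of the commit; the accumulation window is a hypothetical example value):

import torch
import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)  # strategy choice is illustrative
fabric.launch()

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

accumulate = 4  # accumulate gradients over 4 batches
for step in range(16):
    batch = torch.randn(8, 4)
    is_accumulating = (step + 1) % accumulate != 0
    # Skip gradient synchronization on accumulation steps; after this commit the
    # strategy's _BackwardSyncControl receives this flag in both cases.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = model(batch).sum()
        fabric.backward(loss)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()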

src/lightning/fabric/fabric.py

Lines changed: 2 additions & 2 deletions
@@ -672,7 +672,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
                 "You need to set up the model first before you can call `fabric.no_backward_sync()`:"
                 " `model = fabric.setup(model, ...)`"
             )
-        if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
+        if isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
             return nullcontext()
         if self._strategy._backward_sync_control is None:
             rank_zero_warn(
@@ -683,7 +683,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager:
             return nullcontext()
 
         forward_module, _ = _unwrap_compiled(module._forward_module)
-        return self._strategy._backward_sync_control.no_backward_sync(forward_module)
+        return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled)
 
     def sharded_model(self) -> ContextManager:
         r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.

src/lightning/fabric/strategies/ddp.py

Lines changed: 4 additions & 1 deletion
@@ -224,9 +224,12 @@ def _determine_ddp_device_ids(self) -> Optional[List[int]]:
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
         wrapper."""
+        if not enabled:
+            return nullcontext()
+
         if not isinstance(module, DistributedDataParallel):
             raise TypeError(
                 "Blocking backward sync is only possible if the module passed to"

src/lightning/fabric/strategies/fsdp.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
-from contextlib import ExitStack
+from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
@@ -768,9 +768,11 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa
 
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
         wrapper."""
+        if not enabled:
+            return nullcontext()
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
         if not isinstance(module, FullyShardedDataParallel):

src/lightning/fabric/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -424,7 +424,7 @@ class _BackwardSyncControl(ABC):
     """
 
     @abstractmethod
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks the synchronization of gradients during the backward pass.
 
         This is a context manager. It is only effective if it wraps a call to `.backward()`.
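Because the method is abstract, any custom strategy that ships its own control must adopt the two-argument signature. A minimal conforming sketch (the class name and both return values are placeholders; only the signature comes from this commit):

from contextlib import nullcontext
from typing import ContextManager

from torch.nn import Module
from typing_extensions import override

from lightning.fabric.strategies.strategy import _BackwardSyncControl


class _MyBackwardSyncControl(_BackwardSyncControl):
    @override
    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
        if not enabled:
            # New in this commit: the control, not Fabric, decides the disabled case.
            return nullcontext()
        # Return the wrapper-specific sync-skipping context here; nullcontext()
        # is only a stand-in for this sketch.
        return nullcontext()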

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 3 additions & 1 deletion
@@ -679,9 +679,11 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict
 
 class _XLAFSDPBackwardSyncControl(_BackwardSyncControl):
     @override
-    def no_backward_sync(self, module: Module) -> ContextManager:
+    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
         """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`
         wrapper."""
+        if not enabled:
+            return nullcontext()
         from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP
 
         if not isinstance(module, XLAFSDP):

tests/tests_fabric/strategies/test_ddp.py

Lines changed: 5 additions & 3 deletions
@@ -61,13 +61,15 @@ def test_ddp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock()):
+    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
         pass
 
     module = MagicMock(spec=DistributedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+    with strategy._backward_sync_control.no_backward_sync(module, False):
+        pass
+    module.no_sync.assert_not_called()
+    with strategy._backward_sync_control.no_backward_sync(module, True):
         pass
-
     module.no_sync.assert_called_once()

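The pattern in this test carries over to the FSDP and XLA-FSDP variants below: `MagicMock(spec=DistributedDataParallel)` lets the mock pass the control's `isinstance` check, and the paired assertions pin down both branches — `no_sync` untouched for `enabled=False`, called exactly once for `enabled=True`.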
tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 5 additions & 3 deletions
@@ -150,13 +150,15 @@ def test_fsdp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(Mock()):
+    ), strategy._backward_sync_control.no_backward_sync(Mock(), True):
         pass
 
     module = MagicMock(spec=FullyShardedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+    with strategy._backward_sync_control.no_backward_sync(module, False):
+        pass
+    module.no_sync.assert_not_called()
+    with strategy._backward_sync_control.no_backward_sync(module, True):
         pass
-
     module.no_sync.assert_called_once()

tests/tests_fabric/strategies/test_xla_fsdp.py

Lines changed: 6 additions & 2 deletions
@@ -50,13 +50,17 @@ def test_xla_fsdp_no_backward_sync():
 
     with pytest.raises(
         TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
-    ), strategy._backward_sync_control.no_backward_sync(object()):
+    ), strategy._backward_sync_control.no_backward_sync(object(), True):
         pass
 
     module = MagicMock(spec=XlaFullyShardedDataParallel)
-    with strategy._backward_sync_control.no_backward_sync(module):
+
+    with strategy._backward_sync_control.no_backward_sync(module, False):
         pass
+    module.no_sync.assert_not_called()
 
+    with strategy._backward_sync_control.no_backward_sync(module, True):
+        pass
     module.no_sync.assert_called_once()

tests/tests_fabric/test_fabric.py

Lines changed: 3 additions & 3 deletions
@@ -767,11 +767,11 @@ def test_no_backward_sync():
     # disabling the context manager makes it a no-op
     with fabric.no_backward_sync(model, enabled=False):
         pass
-    fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
-    # when enabled, the wrapped module gets passed down
+    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, False)
+    fabric._strategy._backward_sync_control.reset_mock()
     with fabric.no_backward_sync(model):
         pass
-    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
+    fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module, True)
 
 
 def test_launch_without_function():
