Skip to content

Commit 8c72438

Browse files
carmoccalantiga
authored andcommitted
Remove warning on no_backward_sync with XLA strategy (#17761)
(cherry picked from commit f3c49b8)
1 parent ddfd5fe commit 8c72438

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
1212

1313

14+
- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies ([#17761](https://github.com/Lightning-AI/lightning/pull/17761))
15+
16+
1417
## [2.0.7] - 2023-08-14
1518

1619
### Changed

src/lightning/fabric/fabric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
581581
"You need to set up the model first before you can call `self.no_backward_sync()`:"
582582
" `model = self.setup(model, ...)`"
583583
)
584-
if not enabled or isinstance(self._strategy, SingleDeviceStrategy):
584+
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
585585
context = nullcontext()
586586
elif self._strategy._backward_sync_control is None:
587587
rank_zero_warn(

tests/tests_fabric/test_fabric.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,14 +625,19 @@ def test_no_backward_sync():
625625
with fabric.no_backward_sync(model):
626626
pass
627627
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
628+
# same for XLA
629+
fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock())
630+
with fabric.no_backward_sync(model):
631+
pass
632+
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
628633

629634
# pretend that the strategy supports skipping backward sync
630635
fabric._strategy = Mock(_backward_sync_control=MagicMock())
631636
# disabling the context manager makes it a no-op
632637
with fabric.no_backward_sync(model, enabled=False):
633638
pass
634639
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
635-
# when enabld, the wrapped module gets passed down
640+
# when enabled, the wrapped module gets passed down
636641
with fabric.no_backward_sync(model):
637642
pass
638643
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)

0 commit comments

Comments
 (0)