Skip to content

Commit 7073b1f

Browse files
awaelchlilantiga
authored andcommitted
Enable precision autocast for LightningModule step methods in Fabric (#17439)
(cherry picked from commit d9b4ebd)
1 parent 9dce034 commit 7073b1f

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Changed
1111

12-
-
12+
13+
- Enable precision autocast for LightningModule step methods in Fabric ([#17439](https://github.com/Lightning-AI/lightning/pull/17439))
1314

1415

1516
### Fixed

src/lightning/fabric/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
154154
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
155155
# Patch the original_module's forward so we can redirect the arguments back to the real method
156156
self._original_module.forward = wrapped_forward
157-
return self._forward_module(*args, **kwargs)
157+
return self.forward(*args, **kwargs)
158158

159159
return call_forward_module
160160

tests/tests_fabric/test_wrappers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch.utils.data.dataloader import DataLoader
2222

2323
from lightning.fabric.fabric import Fabric
24+
from lightning.fabric.plugins import Precision
2425
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
2526
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer, is_wrapped
2627
from tests_fabric.helpers.runif import RunIf
@@ -417,9 +418,10 @@ def validation_step(self, arg, kwarg=None):
417418
def normal_method(self):
418419
pass
419420

421+
precision = Mock(wraps=Precision())
420422
original_module = LightningModule()
421423
forward_module = DDP(original_module)
422-
fabric_module = _FabricModule(forward_module=forward_module, precision=Mock(), original_module=original_module)
424+
fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)
423425

424426
# Regular methods on the original_module are visible and identical on the fabric_module ...
425427
assert fabric_module.normal_method == original_module.normal_method
@@ -441,6 +443,7 @@ def normal_method(self):
441443
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
442444
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return" # call 2nd time
443445
assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
446+
precision.forward_context.assert_called()
444447

445448
# The forward method remains untouched/unpatched after the special methods have been called
446449
assert original_module.forward.__name__ == "forward"

0 commit comments

Comments
 (0)