21
21
from torch .utils .data .dataloader import DataLoader
22
22
23
23
from lightning .fabric .fabric import Fabric
24
+ from lightning .fabric .plugins import Precision
24
25
from lightning .fabric .utilities .device_dtype_mixin import _DeviceDtypeModuleMixin
25
26
from lightning .fabric .wrappers import _FabricDataLoader , _FabricModule , _FabricOptimizer , is_wrapped
26
27
from tests_fabric .helpers .runif import RunIf
@@ -417,9 +418,10 @@ def validation_step(self, arg, kwarg=None):
417
418
def normal_method (self ):
418
419
pass
419
420
421
+ precision = Mock (wraps = Precision ())
420
422
original_module = LightningModule ()
421
423
forward_module = DDP (original_module )
422
- fabric_module = _FabricModule (forward_module = forward_module , precision = Mock () , original_module = original_module )
424
+ fabric_module = _FabricModule (forward_module = forward_module , precision = precision , original_module = original_module )
423
425
424
426
# Regular methods on the original_module are visible and identical on the fabric_module ...
425
427
assert fabric_module .normal_method == original_module .normal_method
@@ -441,6 +443,7 @@ def normal_method(self):
441
443
assert fabric_module .training_step ("train_arg" , kwarg = "train_kwarg" ) == "training_step_return"
442
444
assert fabric_module .training_step ("train_arg" , kwarg = "train_kwarg" ) == "training_step_return" # call 2nd time
443
445
assert fabric_module .validation_step ("val_arg" , kwarg = "val_kwarg" ) == "validation_step_return"
446
+ precision .forward_context .assert_called ()
444
447
445
448
# The forward method remains untouched/unpatched after the special methods have been called
446
449
assert original_module .forward .__name__ == "forward"
0 commit comments