2121from torch .utils .data .dataloader import DataLoader
2222
2323from lightning .fabric .fabric import Fabric
24+ from lightning .fabric .plugins import Precision
2425from lightning .fabric .utilities .device_dtype_mixin import _DeviceDtypeModuleMixin
2526from lightning .fabric .wrappers import _FabricDataLoader , _FabricModule , _FabricOptimizer , is_wrapped
2627from tests_fabric .helpers .runif import RunIf
@@ -417,9 +418,10 @@ def validation_step(self, arg, kwarg=None):
417418 def normal_method (self ):
418419 pass
419420
421+ precision = Mock (wraps = Precision ())
420422 original_module = LightningModule ()
421423 forward_module = DDP (original_module )
422- fabric_module = _FabricModule (forward_module = forward_module , precision = Mock () , original_module = original_module )
424+ fabric_module = _FabricModule (forward_module = forward_module , precision = precision , original_module = original_module )
423425
424426 # Regular methods on the original_module are visible and identical on the fabric_module ...
425427 assert fabric_module .normal_method == original_module .normal_method
@@ -441,6 +443,7 @@ def normal_method(self):
441443 assert fabric_module .training_step ("train_arg" , kwarg = "train_kwarg" ) == "training_step_return"
442444 assert fabric_module .training_step ("train_arg" , kwarg = "train_kwarg" ) == "training_step_return" # call 2nd time
443445 assert fabric_module .validation_step ("val_arg" , kwarg = "val_kwarg" ) == "validation_step_return"
446+ precision .forward_context .assert_called ()
444447
445448 # The forward method remains untouched/unpatched after the special methods have been called
446449 assert original_module .forward .__name__ == "forward"
0 commit comments