diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index b593c9f22ed23..2b0669690caa9 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -176,28 +176,6 @@ def mark_forward_method(self, method: Union[MethodType, str]) -> None: ) self._forward_methods.add(name) - def _redirection_through_forward(self, method_name: str) -> Callable: - assert method_name != "forward" - original_forward = self._original_module.forward - - def wrapped_forward(*args: Any, **kwargs: Any) -> Any: - # Unpatch ourselves immediately before calling the method `method_name` - # because itself may want to call the real `forward` - self._original_module.forward = original_forward - # Call the actual method e.g. `.training_step(...)` - method = getattr(self._original_module, method_name) - return method(*args, **kwargs) - - # We make the caller "unknowingly" send their arguments through the forward_module's `__call__`. - # We expect that the `forward_module` will eventually call `original_module.forward`, which we - # have patched to redirect back to `original_module.method_name()`. 
def wrap_forward_method(self, method: Callable) -> Callable:
    """Wrap ``method`` so that it runs with the strategy's precision handling.

    The returned callable mirrors the steps performed around a regular forward call:

    1. the positional and keyword arguments are passed together through the precision plugin's
       ``convert_input``,
    2. ``method`` is executed inside the plugin's ``forward_context()``,
    3. the result is passed through ``convert_output``,
    4. ``self._register_backward_hook`` is handed to ``apply_to_collection`` for the ``Tensor`` entries of
       the result.

    NOTE(review): ``method`` is invoked directly, i.e. the call is not routed through this wrapper's own
    ``forward``. Confirm this is intended for strategies that wrap the module (e.g. DDP).

    Args:
        method: The callable to wrap.

    Returns:
        A callable with the metadata of ``method`` (via ``functools.wraps``) that applies the steps above.

    """

    @wraps(method)
    def _precision_aware_call(*call_args: Any, **call_kwargs: Any) -> Any:
        # The plugin is looked up on every call, not once at wrap time.
        plugin = self._strategy.precision
        converted = plugin.convert_input((call_args, call_kwargs))
        call_args, call_kwargs = converted

        with plugin.forward_context():
            raw_result = method(*call_args, **call_kwargs)

        # Output conversion deliberately happens after the forward context has been exited.
        result = plugin.convert_output(raw_result)

        # Only the side effect matters here (hook registration); the return value is discarded.
        apply_to_collection(result, dtype=Tensor, function=self._register_backward_hook)
        return result

    return _precision_aware_call