Skip to content

Commit 6808186

Browse files
committed
wrap forward methods instead of monkeypatch based redirect
1 parent 4281b58 commit 6808186

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

src/lightning/fabric/wrappers.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -176,28 +176,6 @@ def mark_forward_method(self, method: Union[MethodType, str]) -> None:
176176
)
177177
self._forward_methods.add(name)
178178

179-
def _redirection_through_forward(self, method_name: str) -> Callable:
180-
assert method_name != "forward"
181-
original_forward = self._original_module.forward
182-
183-
def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
184-
# Unpatch ourselves immediately before calling the method `method_name`
185-
# because itself may want to call the real `forward`
186-
self._original_module.forward = original_forward
187-
# Call the actual method e.g. `.training_step(...)`
188-
method = getattr(self._original_module, method_name)
189-
return method(*args, **kwargs)
190-
191-
# We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
192-
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
193-
# have patched to redirect back to `original_module.method_name()`.
194-
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
195-
# Patch the original_module's forward, so we can redirect the arguments back to the real method
196-
self._original_module.forward = wrapped_forward
197-
return self.forward(*args, **kwargs)
198-
199-
return call_forward_module
200-
201179
def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
202180
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
203181
registering forward hooks on all submodules."""
@@ -239,6 +217,21 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
239217
hook = partial(_backward_hook, (strategy_requires or precision_requires))
240218
tensor.register_hook(hook)
241219
return tensor
220+
221+
def wrap_forward_method(self, method: Callable) -> Callable:
222+
@wraps(method)
223+
def wrapper(*args: Any, **kwargs: Any) -> Any:
224+
precision = self._strategy.precision
225+
args, kwargs = precision.convert_input((args, kwargs))
226+
227+
with precision.forward_context():
228+
output = method(*args, **kwargs)
229+
230+
output = precision.convert_output(output)
231+
232+
apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
233+
return output
234+
return wrapper
242235

243236
@override
244237
def __getattr__(self, item: Any) -> Any:
@@ -248,7 +241,7 @@ def __getattr__(self, item: Any) -> Any:
248241
and self._forward_module != self._original_module
249242
):
250243
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
251-
return self._redirection_through_forward(item)
244+
return self.wrap_forward_method(getattr(self._original_module, item))
252245

253246
try:
254247
# __getattr__ gets called as a last resort if the attribute does not exist

0 commit comments

Comments
 (0)