@@ -176,28 +176,6 @@ def mark_forward_method(self, method: Union[MethodType, str]) -> None:
176
176
)
177
177
self ._forward_methods .add (name )
178
178
179
def _redirection_through_forward(self, method_name: str) -> Callable:
    """Build a callable that routes a call to ``method_name`` through ``self.forward``.

    The returned callable temporarily replaces ``self._original_module.forward`` with a
    shim and then calls ``self.forward(*args, **kwargs)``. When the shim is invoked it
    first restores the real ``forward`` and then calls
    ``getattr(self._original_module, method_name)`` with the arguments it received.

    Args:
        method_name: Name of the method on ``self._original_module`` to redirect to.
            Must not be ``"forward"``.

    Returns:
        A callable accepting arbitrary positional and keyword arguments and returning
        whatever ``self.forward`` returns for them.

    """
    assert method_name != "forward"
    original_forward = self._original_module.forward

    def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
        # Unpatch ourselves immediately before calling the method `method_name`
        # because the method itself may want to call the real `forward`
        self._original_module.forward = original_forward
        # Call the actual method e.g. `.training_step(...)`
        method = getattr(self._original_module, method_name)
        return method(*args, **kwargs)

    # We make the caller "unknowingly" send their arguments through `self.forward`.
    # We expect that this will eventually call `original_module.forward`, which we
    # have patched to redirect back to `original_module.method_name()`.
    def call_forward_module(*args: Any, **kwargs: Any) -> Any:
        # Patch the original_module's forward, so we can redirect the arguments back to the real method
        self._original_module.forward = wrapped_forward
        try:
            return self.forward(*args, **kwargs)
        finally:
            # If `self.forward` raised (or returned) before the patched forward was ever
            # reached, the shim is still installed. Remove it so a later, unrelated call to
            # `forward` is not silently redirected to `method_name`. Only undo our own patch:
            # anything else found here was put in place by someone else and is left alone.
            if self._original_module.forward is wrapped_forward:
                self._original_module.forward = original_forward

    return call_forward_module
200
-
201
179
def _wrap_method_with_module_call_tracker (self , method : Callable , name : str ) -> Callable :
202
180
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
203
181
registering forward hooks on all submodules."""
@@ -239,6 +217,21 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
239
217
hook = partial (_backward_hook , (strategy_requires or precision_requires ))
240
218
tensor .register_hook (hook )
241
219
return tensor
220
+
221
def wrap_forward_method(self, method: Callable) -> Callable:
    """Return a version of ``method`` that runs with this wrapper's precision handling.

    On every call, the returned callable:

    1. reads ``self._strategy.precision``,
    2. passes the ``(args, kwargs)`` pair through its ``convert_input``,
    3. invokes ``method`` inside its ``forward_context()``,
    4. passes the result through its ``convert_output``,
    5. applies ``self._register_backward_hook`` to every ``Tensor`` in the converted result.

    Args:
        method: The callable to wrap. Its metadata is copied onto the returned callable.

    Returns:
        The wrapping callable; it returns the converted output of ``method``.

    """

    @wraps(method)
    def _with_precision(*call_args: Any, **call_kwargs: Any) -> Any:
        # Looked up per call rather than once at wrap time.
        precision = self._strategy.precision
        converted_inputs = precision.convert_input((call_args, call_kwargs))
        call_args, call_kwargs = converted_inputs

        with precision.forward_context():
            raw_output = method(*call_args, **call_kwargs)

        result = precision.convert_output(raw_output)
        # Called for its side effect on the tensors; the return value is not used.
        apply_to_collection(result, dtype=Tensor, function=self._register_backward_hook)
        return result

    return _with_precision
242
235
243
236
@override
244
237
def __getattr__ (self , item : Any ) -> Any :
@@ -248,7 +241,7 @@ def __getattr__(self, item: Any) -> Any:
248
241
and self ._forward_module != self ._original_module
249
242
):
250
243
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
251
- return self ._redirection_through_forward ( item )
244
+ return self .wrap_forward_method ( getattr ( self . _original_module , item ) )
252
245
253
246
try :
254
247
# __getattr__ gets called as a last resort if the attribute does not exist
0 commit comments