We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6808186 commit 0798ef7Copy full SHA for 0798ef7
src/lightning/fabric/wrappers.py
@@ -217,7 +217,7 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
217
hook = partial(_backward_hook, (strategy_requires or precision_requires))
218
tensor.register_hook(hook)
219
return tensor
220
-
+
221
def wrap_forward_method(self, method: Callable) -> Callable:
222
@wraps(method)
223
def wrapper(*args: Any, **kwargs: Any) -> Any:
@@ -231,6 +231,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
231
232
apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
233
return output
234
235
return wrapper
236
237
@override
0 commit comments