Skip to content

Commit 7a30b6c

Browse files
awaelchlilantiga
authored andcommitted
Avoid false-positive warnings about method calls on the Fabric-wrapped module (#18819)
(cherry picked from commit 97303b0)
1 parent 2285582 commit 7a30b6c

File tree

4 files changed

+62
-36
lines changed

4 files changed

+62
-36
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414

1515
### Changed
1616

17-
-
17+
- Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))
18+
1819

1920

2021
### Deprecated
@@ -29,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2930

3031
### Fixed
3132

32-
-
33+
- Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))
3334

3435

3536
## [2.1.0] - 2023-10-11

src/lightning/fabric/wrappers.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import inspect
15+
from functools import wraps
1516
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload
1617

1718
import torch
18-
from lightning_utilities import WarningCache
1919
from lightning_utilities.core.apply_func import apply_to_collection
2020
from torch import Tensor
2121
from torch import nn as nn
@@ -30,9 +30,7 @@
3030
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
3131
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
3232
from lightning.fabric.utilities.types import Optimizable
33-
from lightning.fabric.utilities.warnings import PossibleUserWarning
3433

35-
warning_cache = WarningCache()
3634
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
3735
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
3836

@@ -161,25 +159,40 @@ def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
161159
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
162160
# have patched to redirect back to `original_module.method_name()`.
163161
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
164-
# Patch the original_module's forward so we can redirect the arguments back to the real method
162+
# Patch the original_module's forward, so we can redirect the arguments back to the real method
165163
self._original_module.forward = wrapped_forward
166164
return self.forward(*args, **kwargs)
167165

168166
return call_forward_module
169167

170-
def _validate_method_access(self, name: str, attribute: Any) -> None:
171-
if (
172-
inspect.ismethod(attribute)
173-
and inspect.signature(attribute).parameters
174-
and self._forward_module != self._original_module
175-
):
176-
warning_cache.warn(
177-
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
178-
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
179-
" `.backward()`. You should pass your inputs through"
180-
f" `{type(self._original_module).__name__}.forward()`.",
181-
category=PossibleUserWarning,
182-
)
168+
def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
169+
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
170+
registering forward hooks on all submodules."""
171+
module_called = False
172+
173+
def hook(*_: Any, **__: Any) -> None:
174+
nonlocal module_called
175+
module_called = True
176+
177+
@wraps(method)
178+
def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
179+
handles = []
180+
for module in self._original_module.modules():
181+
handles.append(module.register_forward_hook(hook))
182+
183+
output = method(*args, **kwargs)
184+
185+
if module_called:
186+
raise RuntimeError(
187+
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
188+
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
189+
" `.backward()`. You should pass your inputs through `forward()`.",
190+
)
191+
for handle in handles:
192+
handle.remove()
193+
return output
194+
195+
return _wrapped_method
183196

184197
def __getattr__(self, item: Any) -> Any:
185198
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
@@ -194,7 +207,9 @@ def __getattr__(self, item: Any) -> Any:
194207
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
195208
original_module = super().__getattr__("_original_module")
196209
attr = getattr(original_module, item)
197-
self._validate_method_access(item, attr)
210+
211+
if inspect.ismethod(attr) and self._forward_module != self._original_module:
212+
attr = self._wrap_method_with_module_call_tracker(attr, item)
198213
return attr
199214

200215
def __setattr__(self, name: str, value: Any) -> None:

tests/tests_fabric/test_wrappers.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
_FabricOptimizer,
2727
_unwrap_objects,
2828
is_wrapped,
29-
warning_cache,
3029
)
31-
from lightning_utilities.test.warning import no_warning_call
3230
from torch.utils.data import BatchSampler, DistributedSampler
3331
from torch.utils.data.dataloader import DataLoader
3432

@@ -79,12 +77,24 @@ def test_fabric_module_method_lookup():
7977
"""Test that access to methods warns about improper use when a wrapper from a strategy is involved."""
8078

8179
class OriginalModule(torch.nn.Module):
82-
def method_no_args(self):
80+
def __init__(self):
81+
super().__init__()
82+
self.submodule = torch.nn.Linear(2, 3)
83+
84+
def forward(self, x):
85+
return x
86+
87+
def method_without_module_invocation(self):
8388
return 100
8489

85-
def method_with_args(self, arg, kwarg=1):
90+
def method_with_submodule_invocation(self):
91+
self.submodule(torch.rand(2, 2))
8692
return 101
8793

94+
def method_with_self_invocation(self):
95+
self(None)
96+
return 102
97+
8898
class ModuleWrapper(torch.nn.Module):
8999
def __init__(self, module):
90100
super().__init__()
@@ -93,21 +103,21 @@ def __init__(self, module):
93103
# Regular case: forward_module == original_module -> no warnings
94104
original_module = OriginalModule()
95105
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
96-
warning_cache.clear()
97-
with no_warning_call(UserWarning):
98-
assert fabric_module.method_with_args(0) == 101
99-
assert not warning_cache
106+
assert fabric_module.method_without_module_invocation() == 100
100107

101108
# Special case: original module wrapped by forward module: -> warn if method accepts args
102109
original_module = OriginalModule()
103110
wrapped_module = ModuleWrapper(original_module)
104111
fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
105-
warning_cache.clear()
106-
with no_warning_call(UserWarning):
107-
assert fabric_module.method_no_args() == 100
108-
with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"):
109-
assert fabric_module.method_with_args(0) == 101
110-
warning_cache.clear()
112+
assert fabric_module.method_without_module_invocation() == 100
113+
with pytest.raises(
114+
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
115+
):
116+
assert fabric_module.method_with_submodule_invocation() == 101
117+
with pytest.raises(
118+
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
119+
):
120+
assert fabric_module.method_with_self_invocation() == 102
111121

112122

113123
def test_fabric_module_setattr():
@@ -555,7 +565,7 @@ def normal_method(self):
555565
fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)
556566

557567
# Regular methods on the original_module are visible and identical on the fabric_module ...
558-
assert fabric_module.normal_method == original_module.normal_method
568+
assert fabric_module.normal_method.__wrapped__ == original_module.normal_method
559569

560570
# ... but special methods like training_step get redirected to the forward_module
561571
assert fabric_module.training_step.__name__ == "call_forward_module"

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
14881488
callback = ModelCheckpoint(dirpath=first, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
14891489
trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
14901490
trainer.fit(model)
1491-
assert os.listdir(first) == ["epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"]
1491+
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}
14921492

14931493
# Continue training from checkpoint
14941494
callback = ModelCheckpoint(dirpath=new_dirpath, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)

0 commit comments

Comments
 (0)