|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import inspect |
14 | 15 | import os
|
15 | 16 | import warnings
|
16 | 17 | from contextlib import nullcontext
|
|
20 | 21 |
|
21 | 22 | import pytest
|
22 | 23 | import torch
|
23 |
| -import torch.distributed |
24 | 24 | import torch.nn.functional
|
25 | 25 | from lightning_utilities.test.warning import no_warning_call
|
26 | 26 | from torch import nn
|
@@ -1294,3 +1294,56 @@ def test_verify_launch_called():
|
1294 | 1294 | fabric.launch()
|
1295 | 1295 | assert fabric._launched
|
1296 | 1296 | fabric._validate_launched()
|
| 1297 | + |
| 1298 | + |
| 1299 | +def test_callback_kwargs_filtering(): |
| 1300 | + """Test that callbacks receive only the kwargs they can handle based on their signature.""" |
| 1301 | + |
| 1302 | + class CallbackWithLimitedKwargs: |
| 1303 | + def on_train_epoch_end(self, epoch: int): |
| 1304 | + self.epoch = epoch |
| 1305 | + |
| 1306 | + class CallbackWithVarKeywords: |
| 1307 | + def on_train_epoch_end(self, epoch: int, **kwargs): |
| 1308 | + self.epoch = epoch |
| 1309 | + self.kwargs = kwargs |
| 1310 | + |
| 1311 | + class CallbackWithNoParams: |
| 1312 | + def on_train_epoch_end(self): |
| 1313 | + self.called = True |
| 1314 | + |
| 1315 | + callback1 = CallbackWithLimitedKwargs() |
| 1316 | + callback2 = CallbackWithVarKeywords() |
| 1317 | + callback3 = CallbackWithNoParams() |
| 1318 | + fabric = Fabric(callbacks=[callback1, callback2, callback3]) |
| 1319 | + fabric.call("on_train_epoch_end", epoch=5, loss=0.1, metrics={"acc": 0.9}) |
| 1320 | + |
| 1321 | + assert callback1.epoch == 5 |
| 1322 | + assert not hasattr(callback1, "loss") |
| 1323 | + assert callback2.epoch == 5 |
| 1324 | + assert callback2.kwargs == {"loss": 0.1, "metrics": {"acc": 0.9}} |
| 1325 | + assert callback3.called is True |
| 1326 | + |
| 1327 | + |
| 1328 | +def test_callback_kwargs_filtering_signature_inspection_failure(): |
| 1329 | + """Test behavior when signature inspection fails - should fallback to passing all kwargs.""" |
| 1330 | + callback = Mock() |
| 1331 | + fabric = Fabric(callbacks=[callback]) |
| 1332 | + original_signature = inspect.signature |
| 1333 | + |
| 1334 | + def mock_signature(obj): |
| 1335 | + if hasattr(obj, "_mock_name") or hasattr(obj, "_mock_new_name"): |
| 1336 | + raise ValueError("Cannot inspect mock signature") |
| 1337 | + return original_signature(obj) |
| 1338 | + |
| 1339 | + # Temporarily replace signature function in fabric module |
| 1340 | + import lightning.fabric.fabric |
| 1341 | + |
| 1342 | + lightning.fabric.fabric.inspect.signature = mock_signature |
| 1343 | + |
| 1344 | + try: |
| 1345 | + # Should still work by passing all kwargs when signature inspection fails |
| 1346 | + fabric.call("on_test_hook", arg1="value1", arg2="value2") |
| 1347 | + callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2") |
| 1348 | + finally: |
| 1349 | + lightning.fabric.fabric.inspect.signature = original_signature |
0 commit comments