Skip to content

Commit 73a65b6

Browse files
committed
add testing
1 parent dbc6e96 commit 73a65b6

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

tests/tests_fabric/test_fabric.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import inspect
1415
import os
1516
import warnings
1617
from contextlib import nullcontext
@@ -20,7 +21,6 @@
2021

2122
import pytest
2223
import torch
23-
import torch.distributed
2424
import torch.nn.functional
2525
from lightning_utilities.test.warning import no_warning_call
2626
from torch import nn
@@ -1294,3 +1294,56 @@ def test_verify_launch_called():
12941294
fabric.launch()
12951295
assert fabric._launched
12961296
fabric._validate_launched()
1297+
1298+
1299+
def test_callback_kwargs_filtering():
1300+
"""Test that callbacks receive only the kwargs they can handle based on their signature."""
1301+
1302+
class CallbackWithLimitedKwargs:
1303+
def on_train_epoch_end(self, epoch: int):
1304+
self.epoch = epoch
1305+
1306+
class CallbackWithVarKeywords:
1307+
def on_train_epoch_end(self, epoch: int, **kwargs):
1308+
self.epoch = epoch
1309+
self.kwargs = kwargs
1310+
1311+
class CallbackWithNoParams:
1312+
def on_train_epoch_end(self):
1313+
self.called = True
1314+
1315+
callback1 = CallbackWithLimitedKwargs()
1316+
callback2 = CallbackWithVarKeywords()
1317+
callback3 = CallbackWithNoParams()
1318+
fabric = Fabric(callbacks=[callback1, callback2, callback3])
1319+
fabric.call("on_train_epoch_end", epoch=5, loss=0.1, metrics={"acc": 0.9})
1320+
1321+
assert callback1.epoch == 5
1322+
assert not hasattr(callback1, "loss")
1323+
assert callback2.epoch == 5
1324+
assert callback2.kwargs == {"loss": 0.1, "metrics": {"acc": 0.9}}
1325+
assert callback3.called is True
1326+
1327+
1328+
def test_callback_kwargs_filtering_signature_inspection_failure():
1329+
"""Test behavior when signature inspection fails - should fallback to passing all kwargs."""
1330+
callback = Mock()
1331+
fabric = Fabric(callbacks=[callback])
1332+
original_signature = inspect.signature
1333+
1334+
def mock_signature(obj):
1335+
if hasattr(obj, "_mock_name") or hasattr(obj, "_mock_new_name"):
1336+
raise ValueError("Cannot inspect mock signature")
1337+
return original_signature(obj)
1338+
1339+
# Temporarily replace signature function in fabric module
1340+
import lightning.fabric.fabric
1341+
1342+
lightning.fabric.fabric.inspect.signature = mock_signature
1343+
1344+
try:
1345+
# Should still work by passing all kwargs when signature inspection fails
1346+
fabric.call("on_test_hook", arg1="value1", arg2="value2")
1347+
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
1348+
finally:
1349+
lightning.fabric.fabric.inspect.signature = original_signature

0 commit comments

Comments
 (0)