|
32 | 32 | )
|
33 | 33 | from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
|
34 | 34 | from lightning.pytorch.demos.boring_classes import BoringModel
|
35 |
| -from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector |
| 35 | +from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list |
36 | 36 |
|
37 | 37 |
|
38 | 38 | @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
|
@@ -323,3 +323,74 @@ def state_dict(self):
|
323 | 323 |
|
324 | 324 | with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
|
325 | 325 | Trainer(callbacks=[MockCallback(), MockCallback()])
|
| 326 | + |
| 327 | + |
| 328 | +def test_validate_callbacks_list_function(): |
| 329 | + """Test the _validate_callbacks_list function directly with various scenarios.""" |
| 330 | + |
| 331 | + # Test with non-stateful callbacks |
| 332 | + callback1 = Callback() |
| 333 | + callback2 = Callback() |
| 334 | + _validate_callbacks_list([callback1, callback2]) |
| 335 | + |
| 336 | + # Test with single stateful callback |
| 337 | + class StatefulCallback(Callback): |
| 338 | + def state_dict(self): |
| 339 | + return {"state": 1} |
| 340 | + |
| 341 | + stateful_cb = StatefulCallback() |
| 342 | + _validate_callbacks_list([stateful_cb]) |
| 343 | + |
| 344 | + # Test with multiple stateful callbacks with unique state keys |
| 345 | + class StatefulCallback1(Callback): |
| 346 | + @property |
| 347 | + def state_key(self): |
| 348 | + return "unique_key_1" |
| 349 | + |
| 350 | + def state_dict(self): |
| 351 | + return {"state": 1} |
| 352 | + |
| 353 | + class StatefulCallback2(Callback): |
| 354 | + @property |
| 355 | + def state_key(self): |
| 356 | + return "unique_key_2" |
| 357 | + |
| 358 | + def state_dict(self): |
| 359 | + return {"state": 2} |
| 360 | + |
| 361 | + stateful_cb1 = StatefulCallback1() |
| 362 | + stateful_cb2 = StatefulCallback2() |
| 363 | + _validate_callbacks_list([stateful_cb1, stateful_cb2]) |
| 364 | + |
| 365 | + # Test with multiple stateful callbacks with same state key |
| 366 | + class ConflictingCallback(Callback): |
| 367 | + @property |
| 368 | + def state_key(self): |
| 369 | + return "same_key" |
| 370 | + |
| 371 | + def state_dict(self): |
| 372 | + return {"state": 1} |
| 373 | + |
| 374 | + conflicting_cb1 = ConflictingCallback() |
| 375 | + conflicting_cb2 = ConflictingCallback() |
| 376 | + |
| 377 | + with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"): |
| 378 | + _validate_callbacks_list([conflicting_cb1, conflicting_cb2]) |
| 379 | + |
| 380 | + # Test with mix of stateful and non-stateful callbacks where stateful ones conflict |
| 381 | + with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"): |
| 382 | + _validate_callbacks_list([callback1, conflicting_cb1, callback2, conflicting_cb2]) |
| 383 | + |
| 384 | + # Test with different types of stateful callbacks that happen to have same state key |
| 385 | + class AnotherConflictingCallback(Callback): |
| 386 | + @property |
| 387 | + def state_key(self): |
| 388 | + return "same_key" # Same key as ConflictingCallback |
| 389 | + |
| 390 | + def state_dict(self): |
| 391 | + return {"state": 3} |
| 392 | + |
| 393 | + another_conflicting_cb = AnotherConflictingCallback() |
| 394 | + |
| 395 | + with pytest.raises(RuntimeError, match="Found more than one stateful callback"): |
| 396 | + _validate_callbacks_list([conflicting_cb1, another_conflicting_cb]) |
0 commit comments