32
32
)
33
33
from lightning .pytorch .callbacks .batch_size_finder import BatchSizeFinder
34
34
from lightning .pytorch .demos .boring_classes import BoringModel
35
- from lightning .pytorch .trainer .connectors .callback_connector import _CallbackConnector
35
+ from lightning .pytorch .trainer .connectors .callback_connector import _CallbackConnector , _validate_callbacks_list
36
36
37
37
38
38
@patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
@@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
104
104
]
105
105
106
106
107
- class StatefulCallback0 (Callback ):
107
+ class StatefulCallbackContent0 (Callback ):
108
108
def state_dict (self ):
109
109
return {"content0" : 0 }
110
110
111
111
112
- class StatefulCallback1 (Callback ):
112
+ class StatefulCallbackContent1 (Callback ):
113
113
def __init__ (self , unique = None , other = None ):
114
114
self ._unique = unique
115
115
self ._other = other
@@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
126
126
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
127
127
multiple callbacks of the same type."""
128
128
129
- callback0 = StatefulCallback0 ()
130
- callback1 = StatefulCallback1 (unique = "one" )
131
- callback2 = StatefulCallback1 (unique = "two" , other = 2 )
129
+ callback0 = StatefulCallbackContent0 ()
130
+ callback1 = StatefulCallbackContent1 (unique = "one" )
131
+ callback2 = StatefulCallbackContent1 (unique = "two" , other = 2 )
132
132
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "all_states" )
133
133
model = BoringModel ()
134
134
trainer = Trainer (
@@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
147
147
trainer .fit (model )
148
148
149
149
ckpt = torch .load (str (tmp_path / "all_states.ckpt" ), weights_only = True )
150
- state0 = ckpt ["callbacks" ]["StatefulCallback0 " ]
151
- state1 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'one'}" ]
152
- state2 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'two'}" ]
150
+ state0 = ckpt ["callbacks" ]["StatefulCallbackContent0 " ]
151
+ state1 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'one'}" ]
152
+ state2 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'two'}" ]
153
153
assert "content0" in state0
154
154
assert state0 ["content0" ] == 0
155
155
assert "content1" in state1
@@ -323,3 +323,80 @@ def state_dict(self):
323
323
324
324
with pytest .raises (RuntimeError , match = "Found more than one stateful callback of type `MockCallback`" ):
325
325
Trainer (callbacks = [MockCallback (), MockCallback ()])
326
+
327
+
328
+ # Test with single stateful callback
329
+ class StatefulCallback (Callback ):
330
+ def state_dict (self ):
331
+ return {"state" : 1 }
332
+
333
+
334
+ # Test with multiple stateful callbacks with unique state keys
335
+ class StatefulCallback1 (Callback ):
336
+ @property
337
+ def state_key (self ):
338
+ return "unique_key_1"
339
+
340
+ def state_dict (self ):
341
+ return {"state" : 1 }
342
+
343
+
344
+ class StatefulCallback2 (Callback ):
345
+ @property
346
+ def state_key (self ):
347
+ return "unique_key_2"
348
+
349
+ def state_dict (self ):
350
+ return {"state" : 2 }
351
+
352
+
353
+ @pytest .mark .parametrize (
354
+ ("callbacks" ),
355
+ [
356
+ [Callback (), Callback ()],
357
+ [StatefulCallback ()],
358
+ [StatefulCallback1 (), StatefulCallback2 ()],
359
+ ],
360
+ )
361
+ def test_validate_callbacks_list_function (callbacks : list ):
362
+ """Test the _validate_callbacks_list function directly with various scenarios."""
363
+ _validate_callbacks_list (callbacks )
364
+
365
+
366
+ # Test with multiple stateful callbacks with same state key
367
+ class ConflictingCallback (Callback ):
368
+ @property
369
+ def state_key (self ):
370
+ return "same_key"
371
+
372
+ def state_dict (self ):
373
+ return {"state" : 1 }
374
+
375
+
376
+ # Test with different types of stateful callbacks that happen to have same state key
377
+ class AnotherConflictingCallback (Callback ):
378
+ @property
379
+ def state_key (self ):
380
+ return "same_key" # Same key as ConflictingCallback
381
+
382
+ def state_dict (self ):
383
+ return {"state" : 3 }
384
+
385
+
386
+ @pytest .mark .parametrize (
387
+ ("callbacks" , "match_msg" ),
388
+ [
389
+ (
390
+ [ConflictingCallback (), ConflictingCallback ()],
391
+ "Found more than one stateful callback of type `ConflictingCallback`" ,
392
+ ),
393
+ (
394
+ [ConflictingCallback (), Callback (), ConflictingCallback ()],
395
+ "Found more than one stateful callback of type `ConflictingCallback`" ,
396
+ ),
397
+ ([ConflictingCallback (), AnotherConflictingCallback ()], "Found more than one stateful callback" ),
398
+ ],
399
+ )
400
+ def test_raising_error_validate_callbacks_list_function (callbacks : list , match_msg : str ):
401
+ with pytest .raises (RuntimeError , match = match_msg ):
402
+ _validate_callbacks_list (callbacks )
0 commit comments