3232)
3333from lightning .pytorch .callbacks .batch_size_finder import BatchSizeFinder
3434from lightning .pytorch .demos .boring_classes import BoringModel
35- from lightning .pytorch .trainer .connectors .callback_connector import _CallbackConnector
35+ from lightning .pytorch .trainer .connectors .callback_connector import _CallbackConnector , _validate_callbacks_list
3636
3737
3838@patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
@@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
104104 ]
105105
106106
107- class StatefulCallback0 (Callback ):
107+ class StatefulCallbackContent0 (Callback ):
108108 def state_dict (self ):
109109 return {"content0" : 0 }
110110
111111
112- class StatefulCallback1 (Callback ):
112+ class StatefulCallbackContent1 (Callback ):
113113 def __init__ (self , unique = None , other = None ):
114114 self ._unique = unique
115115 self ._other = other
@@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
126126 """Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
127127 multiple callbacks of the same type."""
128128
129- callback0 = StatefulCallback0 ()
130- callback1 = StatefulCallback1 (unique = "one" )
131- callback2 = StatefulCallback1 (unique = "two" , other = 2 )
129+ callback0 = StatefulCallbackContent0 ()
130+ callback1 = StatefulCallbackContent1 (unique = "one" )
131+ callback2 = StatefulCallbackContent1 (unique = "two" , other = 2 )
132132 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "all_states" )
133133 model = BoringModel ()
134134 trainer = Trainer (
@@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
147147 trainer .fit (model )
148148
149149 ckpt = torch .load (str (tmp_path / "all_states.ckpt" ), weights_only = True )
150- state0 = ckpt ["callbacks" ]["StatefulCallback0 " ]
151- state1 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'one'}" ]
152- state2 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'two'}" ]
150+ state0 = ckpt ["callbacks" ]["StatefulCallbackContent0 " ]
151+ state1 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'one'}" ]
152+ state2 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'two'}" ]
153153 assert "content0" in state0
154154 assert state0 ["content0" ] == 0
155155 assert "content1" in state1
@@ -323,3 +323,80 @@ def state_dict(self):
323323
324324 with pytest .raises (RuntimeError , match = "Found more than one stateful callback of type `MockCallback`" ):
325325 Trainer (callbacks = [MockCallback (), MockCallback ()])
326+
327+
328+ # Test with single stateful callback
329+ class StatefulCallback (Callback ):
330+ def state_dict (self ):
331+ return {"state" : 1 }
332+
333+
334+ # Test with multiple stateful callbacks with unique state keys
335+ class StatefulCallback1 (Callback ):
336+ @property
337+ def state_key (self ):
338+ return "unique_key_1"
339+
340+ def state_dict (self ):
341+ return {"state" : 1 }
342+
343+
344+ class StatefulCallback2 (Callback ):
345+ @property
346+ def state_key (self ):
347+ return "unique_key_2"
348+
349+ def state_dict (self ):
350+ return {"state" : 2 }
351+
352+
353+ @pytest .mark .parametrize (
354+ ("callbacks" ),
355+ [
356+ [Callback (), Callback ()],
357+ [StatefulCallback ()],
358+ [StatefulCallback1 (), StatefulCallback2 ()],
359+ ],
360+ )
361+ def test_validate_callbacks_list_function (callbacks : list ):
362+ """Test the _validate_callbacks_list function directly with various scenarios."""
363+ _validate_callbacks_list (callbacks )
364+
365+
366+ # Test with multiple stateful callbacks with same state key
367+ class ConflictingCallback (Callback ):
368+ @property
369+ def state_key (self ):
370+ return "same_key"
371+
372+ def state_dict (self ):
373+ return {"state" : 1 }
374+
375+
376+ # Test with different types of stateful callbacks that happen to have same state key
377+ class AnotherConflictingCallback (Callback ):
378+ @property
379+ def state_key (self ):
380+ return "same_key" # Same key as ConflictingCallback
381+
382+ def state_dict (self ):
383+ return {"state" : 3 }
384+
385+
386+ @pytest .mark .parametrize (
387+ ("callbacks" , "match_msg" ),
388+ [
389+ (
390+ [ConflictingCallback (), ConflictingCallback ()],
391+ "Found more than one stateful callback of type `ConflictingCallback`" ,
392+ ),
393+ (
394+ [ConflictingCallback (), Callback (), ConflictingCallback ()],
395+ "Found more than one stateful callback of type `ConflictingCallback`" ,
396+ ),
397+ ([ConflictingCallback (), AnotherConflictingCallback ()], "Found more than one stateful callback" ),
398+ ],
399+ )
400+ def test_raising_error_validate_callbacks_list_function (callbacks : list , match_msg : str ):
401+ with pytest .raises (RuntimeError , match = match_msg ):
402+ _validate_callbacks_list (callbacks )
0 commit comments