@@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
104
104
]
105
105
106
106
107
- class StatefulCallback0 (Callback ):
107
+ class StatefulCallbackContent0 (Callback ):
108
108
def state_dict (self ):
109
109
return {"content0" : 0 }
110
110
111
111
112
- class StatefulCallback1 (Callback ):
112
+ class StatefulCallbackContent1 (Callback ):
113
113
def __init__ (self , unique = None , other = None ):
114
114
self ._unique = unique
115
115
self ._other = other
@@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
126
126
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
127
127
multiple callbacks of the same type."""
128
128
129
- callback0 = StatefulCallback0 ()
130
- callback1 = StatefulCallback1 (unique = "one" )
131
- callback2 = StatefulCallback1 (unique = "two" , other = 2 )
129
+ callback0 = StatefulCallbackContent0 ()
130
+ callback1 = StatefulCallbackContent1 (unique = "one" )
131
+ callback2 = StatefulCallbackContent1 (unique = "two" , other = 2 )
132
132
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "all_states" )
133
133
model = BoringModel ()
134
134
trainer = Trainer (
@@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
147
147
trainer .fit (model )
148
148
149
149
ckpt = torch .load (str (tmp_path / "all_states.ckpt" ), weights_only = True )
150
- state0 = ckpt ["callbacks" ]["StatefulCallback0 " ]
151
- state1 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'one'}" ]
152
- state2 = ckpt ["callbacks" ]["StatefulCallback1 {'unique': 'two'}" ]
150
+ state0 = ckpt ["callbacks" ]["StatefulCallbackContent0 " ]
151
+ state1 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'one'}" ]
152
+ state2 = ckpt ["callbacks" ]["StatefulCallbackContent1 {'unique': 'two'}" ]
153
153
assert "content0" in state0
154
154
assert state0 ["content0" ] == 0
155
155
assert "content1" in state1
@@ -325,72 +325,65 @@ def state_dict(self):
325
325
Trainer (callbacks = [MockCallback (), MockCallback ()])
326
326
327
327
328
- def test_validate_callbacks_list_function ():
329
- """Test the _validate_callbacks_list function directly with various scenarios."""
330
-
331
- # Test with non-stateful callbacks
332
- callback1 = Callback ()
333
- callback2 = Callback ()
334
- _validate_callbacks_list ([callback1 , callback2 ])
335
-
336
- # Test with single stateful callback
337
- class StatefulCallback (Callback ):
338
- def state_dict (self ):
339
- return {"state" : 1 }
340
-
341
- stateful_cb = StatefulCallback ()
342
- _validate_callbacks_list ([stateful_cb ])
343
-
344
- # Test with multiple stateful callbacks with unique state keys
345
- class StatefulCallback1 (Callback ):
346
- @property
347
- def state_key (self ):
348
- return "unique_key_1"
349
-
350
- def state_dict (self ):
351
- return {"state" : 1 }
352
-
353
- class StatefulCallback2 (Callback ):
354
- @property
355
- def state_key (self ):
356
- return "unique_key_2"
357
-
358
- def state_dict (self ):
359
- return {"state" : 2 }
360
-
361
- stateful_cb1 = StatefulCallback1 ()
362
- stateful_cb2 = StatefulCallback2 ()
363
- _validate_callbacks_list ([stateful_cb1 , stateful_cb2 ])
328
+ # Test with single stateful callback
329
+ class StatefulCallback (Callback ):
330
+ def state_dict (self ):
331
+ return {"state" : 1 }
364
332
365
- # Test with multiple stateful callbacks with same state key
366
- class ConflictingCallback (Callback ):
367
- @property
368
- def state_key (self ):
369
- return "same_key "
333
+ # Test with multiple stateful callbacks with unique state keys
334
+ class StatefulCallback1 (Callback ):
335
+ @property
336
+ def state_key (self ):
337
+ return "unique_key_1 "
370
338
371
- def state_dict (self ):
372
- return {"state" : 1 }
339
+ def state_dict (self ):
340
+ return {"state" : 1 }
373
341
374
- conflicting_cb1 = ConflictingCallback ()
375
- conflicting_cb2 = ConflictingCallback ()
342
+ class StatefulCallback2 (Callback ):
343
+ @property
344
+ def state_key (self ):
345
+ return "unique_key_2"
376
346
377
- with pytest .raises (RuntimeError , match = "Found more than one stateful callback of type `ConflictingCallback`" ):
378
- _validate_callbacks_list ([conflicting_cb1 , conflicting_cb2 ])
347
+ def state_dict (self ):
348
+ return {"state" : 2 }
349
+
350
+ @pytest .mark .parametrize (
351
+ ("callbacks" ),
352
+ [[Callback (), Callback ()],
353
+ [StatefulCallback ()],
354
+ [StatefulCallback1 (), StatefulCallback2 ()],
355
+ ],
356
+ )
357
+ def test_validate_callbacks_list_function (callbacks : list ):
358
+ """Test the _validate_callbacks_list function directly with various scenarios."""
359
+ _validate_callbacks_list (callbacks )
379
360
380
- # Test with mix of stateful and non-stateful callbacks where stateful ones conflict
381
- with pytest .raises (RuntimeError , match = "Found more than one stateful callback of type `ConflictingCallback`" ):
382
- _validate_callbacks_list ([callback1 , conflicting_cb1 , callback2 , conflicting_cb2 ])
383
361
384
- # Test with different types of stateful callbacks that happen to have same state key
385
- class AnotherConflictingCallback (Callback ):
386
- @property
387
- def state_key (self ):
388
- return "same_key" # Same key as ConflictingCallback
362
+ # Test with multiple stateful callbacks with same state key
363
+ class ConflictingCallback (Callback ):
364
+ @property
365
+ def state_key (self ):
366
+ return "same_key"
389
367
390
- def state_dict (self ):
391
- return {"state" : 3 }
368
+ def state_dict (self ):
369
+ return {"state" : 1 }
392
370
393
- another_conflicting_cb = AnotherConflictingCallback ()
371
+ # Test with different types of stateful callbacks that happen to have same state key
372
+ class AnotherConflictingCallback (Callback ):
373
+ @property
374
+ def state_key (self ):
375
+ return "same_key" # Same key as ConflictingCallback
394
376
395
- with pytest .raises (RuntimeError , match = "Found more than one stateful callback" ):
396
- _validate_callbacks_list ([conflicting_cb1 , another_conflicting_cb ])
377
+ def state_dict (self ):
378
+ return {"state" : 3 }
379
+ @pytest .mark .parametrize (
380
+ ("callbacks" , "match_msg" ),
381
+ [
382
+ ([ConflictingCallback (), ConflictingCallback ()], "Found more than one stateful callback of type `ConflictingCallback`" ),
383
+ ([Callback (), ConflictingCallback (), Callback (),ConflictingCallback (), ], "Found more than one stateful callback of type `ConflictingCallback`" ),
384
+ ([ConflictingCallback (), AnotherConflictingCallback ()], "Found more than one stateful callback" ),
385
+ ],
386
+ )
387
+ def test_raising_error_validate_callbacks_list_function (callbacks : list , match_msg :str ):
388
+ with pytest .raises (RuntimeError , match = match_msg ):
389
+ _validate_callbacks_list (callbacks )
0 commit comments