@@ -330,6 +330,7 @@ class StatefulCallback(Callback):
330
330
def state_dict (self ):
331
331
return {"state" : 1 }
332
332
333
+
333
334
# Test with multiple stateful callbacks with unique state keys
334
335
class StatefulCallback1 (Callback ):
335
336
@property
@@ -339,6 +340,7 @@ def state_key(self):
339
340
def state_dict (self ):
340
341
return {"state" : 1 }
341
342
343
+
342
344
class StatefulCallback2 (Callback ):
343
345
@property
344
346
def state_key (self ):
@@ -347,12 +349,14 @@ def state_key(self):
347
349
def state_dict (self ):
348
350
return {"state" : 2 }
349
351
352
+
350
353
@pytest .mark .parametrize (
351
354
("callbacks" ),
352
- [[Callback (), Callback ()],
353
- [StatefulCallback ()],
354
- [StatefulCallback1 (), StatefulCallback2 ()],
355
- ],
355
+ [
356
+ [Callback (), Callback ()],
357
+ [StatefulCallback ()],
358
+ [StatefulCallback1 (), StatefulCallback2 ()],
359
+ ],
356
360
)
357
361
def test_validate_callbacks_list_function (callbacks : list ):
358
362
"""Test the _validate_callbacks_list function directly with various scenarios."""
@@ -368,6 +372,7 @@ def state_key(self):
368
372
def state_dict (self ):
369
373
return {"state" : 1 }
370
374
375
+
371
376
# Test with different types of stateful callbacks that happen to have same state key
372
377
class AnotherConflictingCallback (Callback ):
373
378
@property
@@ -376,14 +381,27 @@ def state_key(self):
376
381
377
382
def state_dict (self ):
378
383
return {"state" : 3 }
384
+
385
+
379
386
@pytest .mark .parametrize (
380
387
("callbacks" , "match_msg" ),
381
388
[
382
- ([ConflictingCallback (), ConflictingCallback ()], "Found more than one stateful callback of type `ConflictingCallback`" ),
383
- ([Callback (), ConflictingCallback (), Callback (),ConflictingCallback (), ], "Found more than one stateful callback of type `ConflictingCallback`" ),
389
+ (
390
+ [ConflictingCallback (), ConflictingCallback ()],
391
+ "Found more than one stateful callback of type `ConflictingCallback`" ,
392
+ ),
393
+ (
394
+ [
395
+ Callback (),
396
+ ConflictingCallback (),
397
+ Callback (),
398
+ ConflictingCallback (),
399
+ ],
400
+ "Found more than one stateful callback of type `ConflictingCallback`" ,
401
+ ),
384
402
([ConflictingCallback (), AnotherConflictingCallback ()], "Found more than one stateful callback" ),
385
403
],
386
404
)
387
- def test_raising_error_validate_callbacks_list_function (callbacks : list , match_msg :str ):
405
+ def test_raising_error_validate_callbacks_list_function (callbacks : list , match_msg : str ):
388
406
with pytest .raises (RuntimeError , match = match_msg ):
389
407
_validate_callbacks_list (callbacks )
0 commit comments