Skip to content

Commit 5a0df4d

Browse files
committed
add unittesting of function
1 parent 4e35c22 commit 5a0df4d

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
3434
from lightning.pytorch.demos.boring_classes import BoringModel
35-
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
35+
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list
3636

3737

3838
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
@@ -323,3 +323,74 @@ def state_dict(self):
323323

324324
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
325325
Trainer(callbacks=[MockCallback(), MockCallback()])
326+
327+
328+
def test_validate_callbacks_list_function():
329+
"""Test the _validate_callbacks_list function directly with various scenarios."""
330+
331+
# Test with non-stateful callbacks
332+
callback1 = Callback()
333+
callback2 = Callback()
334+
_validate_callbacks_list([callback1, callback2])
335+
336+
# Test with single stateful callback
337+
class StatefulCallback(Callback):
338+
def state_dict(self):
339+
return {"state": 1}
340+
341+
stateful_cb = StatefulCallback()
342+
_validate_callbacks_list([stateful_cb])
343+
344+
# Test with multiple stateful callbacks with unique state keys
345+
class StatefulCallback1(Callback):
346+
@property
347+
def state_key(self):
348+
return "unique_key_1"
349+
350+
def state_dict(self):
351+
return {"state": 1}
352+
353+
class StatefulCallback2(Callback):
354+
@property
355+
def state_key(self):
356+
return "unique_key_2"
357+
358+
def state_dict(self):
359+
return {"state": 2}
360+
361+
stateful_cb1 = StatefulCallback1()
362+
stateful_cb2 = StatefulCallback2()
363+
_validate_callbacks_list([stateful_cb1, stateful_cb2])
364+
365+
# Test with multiple stateful callbacks with same state key
366+
class ConflictingCallback(Callback):
367+
@property
368+
def state_key(self):
369+
return "same_key"
370+
371+
def state_dict(self):
372+
return {"state": 1}
373+
374+
conflicting_cb1 = ConflictingCallback()
375+
conflicting_cb2 = ConflictingCallback()
376+
377+
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
378+
_validate_callbacks_list([conflicting_cb1, conflicting_cb2])
379+
380+
# Test with mix of stateful and non-stateful callbacks where stateful ones conflict
381+
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
382+
_validate_callbacks_list([callback1, conflicting_cb1, callback2, conflicting_cb2])
383+
384+
# Test with different types of stateful callbacks that happen to have same state key
385+
class AnotherConflictingCallback(Callback):
386+
@property
387+
def state_key(self):
388+
return "same_key" # Same key as ConflictingCallback
389+
390+
def state_dict(self):
391+
return {"state": 3}
392+
393+
another_conflicting_cb = AnotherConflictingCallback()
394+
395+
with pytest.raises(RuntimeError, match="Found more than one stateful callback"):
396+
_validate_callbacks_list([conflicting_cb1, another_conflicting_cb])

0 commit comments

Comments
 (0)