Skip to content

Commit 2061375

Browse files
committed
split tests
1 parent 675068b commit 2061375

File tree

1 file changed

+61
-68
lines changed

1 file changed

+61
-68
lines changed

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 61 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
104104
]
105105

106106

107-
class StatefulCallback0(Callback):
107+
class StatefulCallbackContent0(Callback):
108108
def state_dict(self):
109109
return {"content0": 0}
110110

111111

112-
class StatefulCallback1(Callback):
112+
class StatefulCallbackContent1(Callback):
113113
def __init__(self, unique=None, other=None):
114114
self._unique = unique
115115
self._other = other
@@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
126126
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
127127
multiple callbacks of the same type."""
128128

129-
callback0 = StatefulCallback0()
130-
callback1 = StatefulCallback1(unique="one")
131-
callback2 = StatefulCallback1(unique="two", other=2)
129+
callback0 = StatefulCallbackContent0()
130+
callback1 = StatefulCallbackContent1(unique="one")
131+
callback2 = StatefulCallbackContent1(unique="two", other=2)
132132
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="all_states")
133133
model = BoringModel()
134134
trainer = Trainer(
@@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
147147
trainer.fit(model)
148148

149149
ckpt = torch.load(str(tmp_path / "all_states.ckpt"), weights_only=True)
150-
state0 = ckpt["callbacks"]["StatefulCallback0"]
151-
state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
152-
state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
150+
state0 = ckpt["callbacks"]["StatefulCallbackContent0"]
151+
state1 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'one'}"]
152+
state2 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'two'}"]
153153
assert "content0" in state0
154154
assert state0["content0"] == 0
155155
assert "content1" in state1
@@ -325,72 +325,65 @@ def state_dict(self):
325325
Trainer(callbacks=[MockCallback(), MockCallback()])
326326

327327

328-
def test_validate_callbacks_list_function():
329-
"""Test the _validate_callbacks_list function directly with various scenarios."""
330-
331-
# Test with non-stateful callbacks
332-
callback1 = Callback()
333-
callback2 = Callback()
334-
_validate_callbacks_list([callback1, callback2])
335-
336-
# Test with single stateful callback
337-
class StatefulCallback(Callback):
338-
def state_dict(self):
339-
return {"state": 1}
340-
341-
stateful_cb = StatefulCallback()
342-
_validate_callbacks_list([stateful_cb])
343-
344-
# Test with multiple stateful callbacks with unique state keys
345-
class StatefulCallback1(Callback):
346-
@property
347-
def state_key(self):
348-
return "unique_key_1"
349-
350-
def state_dict(self):
351-
return {"state": 1}
352-
353-
class StatefulCallback2(Callback):
354-
@property
355-
def state_key(self):
356-
return "unique_key_2"
357-
358-
def state_dict(self):
359-
return {"state": 2}
360-
361-
stateful_cb1 = StatefulCallback1()
362-
stateful_cb2 = StatefulCallback2()
363-
_validate_callbacks_list([stateful_cb1, stateful_cb2])
328+
# Test with single stateful callback
329+
class StatefulCallback(Callback):
330+
def state_dict(self):
331+
return {"state": 1}
364332

365-
# Test with multiple stateful callbacks with same state key
366-
class ConflictingCallback(Callback):
367-
@property
368-
def state_key(self):
369-
return "same_key"
333+
# Test with multiple stateful callbacks with unique state keys
334+
class StatefulCallback1(Callback):
335+
@property
336+
def state_key(self):
337+
return "unique_key_1"
370338

371-
def state_dict(self):
372-
return {"state": 1}
339+
def state_dict(self):
340+
return {"state": 1}
373341

374-
conflicting_cb1 = ConflictingCallback()
375-
conflicting_cb2 = ConflictingCallback()
342+
class StatefulCallback2(Callback):
343+
@property
344+
def state_key(self):
345+
return "unique_key_2"
376346

377-
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
378-
_validate_callbacks_list([conflicting_cb1, conflicting_cb2])
347+
def state_dict(self):
348+
return {"state": 2}
349+
350+
@pytest.mark.parametrize(
351+
("callbacks"),
352+
[[Callback(), Callback()],
353+
[StatefulCallback()],
354+
[StatefulCallback1(), StatefulCallback2()],
355+
],
356+
)
357+
def test_validate_callbacks_list_function(callbacks: list):
358+
"""Test the _validate_callbacks_list function directly with various scenarios."""
359+
_validate_callbacks_list(callbacks)
379360

380-
# Test with mix of stateful and non-stateful callbacks where stateful ones conflict
381-
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
382-
_validate_callbacks_list([callback1, conflicting_cb1, callback2, conflicting_cb2])
383361

384-
# Test with different types of stateful callbacks that happen to have same state key
385-
class AnotherConflictingCallback(Callback):
386-
@property
387-
def state_key(self):
388-
return "same_key" # Same key as ConflictingCallback
362+
# Test with multiple stateful callbacks with same state key
363+
class ConflictingCallback(Callback):
364+
@property
365+
def state_key(self):
366+
return "same_key"
389367

390-
def state_dict(self):
391-
return {"state": 3}
368+
def state_dict(self):
369+
return {"state": 1}
392370

393-
another_conflicting_cb = AnotherConflictingCallback()
371+
# Test with different types of stateful callbacks that happen to have same state key
372+
class AnotherConflictingCallback(Callback):
373+
@property
374+
def state_key(self):
375+
return "same_key" # Same key as ConflictingCallback
394376

395-
with pytest.raises(RuntimeError, match="Found more than one stateful callback"):
396-
_validate_callbacks_list([conflicting_cb1, another_conflicting_cb])
377+
def state_dict(self):
378+
return {"state": 3}
379+
@pytest.mark.parametrize(
380+
("callbacks", "match_msg"),
381+
[
382+
([ConflictingCallback(), ConflictingCallback()], "Found more than one stateful callback of type `ConflictingCallback`"),
383+
([Callback(), ConflictingCallback(), Callback(),ConflictingCallback(), ], "Found more than one stateful callback of type `ConflictingCallback`"),
384+
([ConflictingCallback(), AnotherConflictingCallback()], "Found more than one stateful callback"),
385+
],
386+
)
387+
def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg:str):
388+
with pytest.raises(RuntimeError, match=match_msg):
389+
_validate_callbacks_list(callbacks)

0 commit comments

Comments
 (0)