Skip to content

Commit 196e7e9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2061375 commit 196e7e9

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ class StatefulCallback(Callback):
330330
def state_dict(self):
331331
return {"state": 1}
332332

333+
333334
# Test with multiple stateful callbacks with unique state keys
334335
class StatefulCallback1(Callback):
335336
@property
@@ -339,6 +340,7 @@ def state_key(self):
339340
def state_dict(self):
340341
return {"state": 1}
341342

343+
342344
class StatefulCallback2(Callback):
343345
@property
344346
def state_key(self):
@@ -347,12 +349,14 @@ def state_key(self):
347349
def state_dict(self):
348350
return {"state": 2}
349351

352+
350353
@pytest.mark.parametrize(
351354
("callbacks"),
352-
[[Callback(), Callback()],
353-
[StatefulCallback()],
354-
[StatefulCallback1(), StatefulCallback2()],
355-
],
355+
[
356+
[Callback(), Callback()],
357+
[StatefulCallback()],
358+
[StatefulCallback1(), StatefulCallback2()],
359+
],
356360
)
357361
def test_validate_callbacks_list_function(callbacks: list):
358362
"""Test the _validate_callbacks_list function directly with various scenarios."""
@@ -368,6 +372,7 @@ def state_key(self):
368372
def state_dict(self):
369373
return {"state": 1}
370374

375+
371376
# Test with different types of stateful callbacks that happen to have same state key
372377
class AnotherConflictingCallback(Callback):
373378
@property
@@ -376,14 +381,27 @@ def state_key(self):
376381

377382
def state_dict(self):
378383
return {"state": 3}
384+
385+
379386
@pytest.mark.parametrize(
380387
("callbacks", "match_msg"),
381388
[
382-
([ConflictingCallback(), ConflictingCallback()], "Found more than one stateful callback of type `ConflictingCallback`"),
383-
([Callback(), ConflictingCallback(), Callback(),ConflictingCallback(), ], "Found more than one stateful callback of type `ConflictingCallback`"),
389+
(
390+
[ConflictingCallback(), ConflictingCallback()],
391+
"Found more than one stateful callback of type `ConflictingCallback`",
392+
),
393+
(
394+
[
395+
Callback(),
396+
ConflictingCallback(),
397+
Callback(),
398+
ConflictingCallback(),
399+
],
400+
"Found more than one stateful callback of type `ConflictingCallback`",
401+
),
384402
([ConflictingCallback(), AnotherConflictingCallback()], "Found more than one stateful callback"),
385403
],
386404
)
387-
def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg:str):
405+
def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg: str):
388406
with pytest.raises(RuntimeError, match=match_msg):
389407
_validate_callbacks_list(callbacks)

0 commit comments

Comments
 (0)