Skip to content

Commit 0869293

Browse files
committed
add testing
1 parent 2af9944 commit 0869293

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pickle
1818
import sys
1919
from argparse import Namespace
20-
from dataclasses import dataclass
20+
from dataclasses import dataclass, field
2121
from enum import Enum
2222
from unittest import mock
2323

@@ -881,6 +881,32 @@ def test_dataclass_lightning_module(tmp_path):
881881
assert model.hparams == {"mandatory": 33, "optional": "cocofruit"}
882882

883883

884+
def test_dataclass_with_init_false_fields():
885+
"""Test that save_hyperparameters() filters out fields with init=False and issues a warning."""
886+
887+
@dataclass
888+
class DataClassWithInitFalseFieldsModel(BoringModel):
889+
mandatory: int
890+
optional: str = "optional"
891+
non_init_field: int = field(default=999, init=False)
892+
another_non_init: str = field(default="not_in_init", init=False)
893+
894+
def __post_init__(self):
895+
super().__init__()
896+
self.save_hyperparameters()
897+
898+
with pytest.warns(UserWarning, match="Detected a dataclass with fields with `init=False`"):
899+
model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit")
900+
901+
expected_hparams = {"mandatory": 33, "optional": "cocofruit"}
902+
assert model.hparams == expected_hparams
903+
904+
assert model.non_init_field == 999
905+
assert model.another_non_init == "not_in_init"
906+
assert "non_init_field" not in model.hparams
907+
assert "another_non_init" not in model.hparams
908+
909+
884910
class NoHparamsModel(BoringModel):
885911
"""Tests a model without hparams."""
886912

0 commit comments

Comments
 (0)