|
17 | 17 | import pickle |
18 | 18 | import sys |
19 | 19 | from argparse import Namespace |
20 | | -from dataclasses import dataclass |
| 20 | +from dataclasses import dataclass, field |
21 | 21 | from enum import Enum |
22 | 22 | from unittest import mock |
23 | 23 |
|
@@ -881,6 +881,32 @@ def test_dataclass_lightning_module(tmp_path): |
881 | 881 | assert model.hparams == {"mandatory": 33, "optional": "cocofruit"} |
882 | 882 |
|
883 | 883 |
|
| 884 | +def test_dataclass_with_init_false_fields(): |
| 885 | + """Test that save_hyperparameters() filters out fields with init=False and issues a warning.""" |
| 886 | + |
| 887 | + @dataclass |
| 888 | + class DataClassWithInitFalseFieldsModel(BoringModel): |
| 889 | + mandatory: int |
| 890 | + optional: str = "optional" |
| 891 | + non_init_field: int = field(default=999, init=False) |
| 892 | + another_non_init: str = field(default="not_in_init", init=False) |
| 893 | + |
| 894 | + def __post_init__(self): |
| 895 | + super().__init__() |
| 896 | + self.save_hyperparameters() |
| 897 | + |
| 898 | + with pytest.warns(UserWarning, match="Detected a dataclass with fields with `init=False`"): |
| 899 | + model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") |
| 900 | + |
| 901 | + expected_hparams = {"mandatory": 33, "optional": "cocofruit"} |
| 902 | + assert model.hparams == expected_hparams |
| 903 | + |
| 904 | + assert model.non_init_field == 999 |
| 905 | + assert model.another_non_init == "not_in_init" |
| 906 | + assert "non_init_field" not in model.hparams |
| 907 | + assert "another_non_init" not in model.hparams |
| 908 | + |
| 909 | + |
884 | 910 | class NoHparamsModel(BoringModel): |
885 | 911 | """Tests a model without hparams.""" |
886 | 912 |
|
|
0 commit comments