Skip to content

Commit 997538c

Browse files
authored
feat(dfn): add from_dict method (#241)
Allows using Dfn/Field spec objects explicitly by passing a dictionary to __init__ via double star syntax, or via from_dict which handles field structuring automatically, and with strict=False ignores unrecognized keys (like pydantic's extra="ignore"), if strict=True unrecognized keys cause an error (like pydantic's extra="forbid")
1 parent 37d2f9a commit 997538c

File tree

4 files changed

+315
-6
lines changed

4 files changed

+315
-6
lines changed

autotest/test_dfn.py

Lines changed: 235 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from dataclasses import asdict
12
from pathlib import Path
23

34
import pytest
5+
from packaging.version import Version
46

5-
from modflow_devtools.dfn import _load_common, load, load_flat
7+
from modflow_devtools.dfn import Dfn, _load_common, load, load_flat
68
from modflow_devtools.dfn.fetch import fetch_dfns
9+
from modflow_devtools.dfn.schema.v1 import FieldV1
10+
from modflow_devtools.dfn.schema.v2 import FieldV2
711
from modflow_devtools.dfn2toml import convert
812
from modflow_devtools.markers import requires_pkg
913

@@ -114,3 +118,233 @@ def test_convert(function_tmpdir):
114118
assert dis.parent == "gwf"
115119
assert "options" in (dis.blocks or {})
116120
assert "dimensions" in (dis.blocks or {})
121+
122+
123+
def test_dfn_from_dict_ignores_extra_keys():
124+
d = {
125+
"schema_version": Version("2"),
126+
"name": "test-dfn",
127+
"extra_key": "should be allowed",
128+
"another_extra": 123,
129+
}
130+
dfn = Dfn.from_dict(d)
131+
assert dfn.name == "test-dfn"
132+
assert dfn.schema_version == Version("2")
133+
134+
135+
def test_dfn_from_dict_strict_mode():
136+
d = {
137+
"schema_version": Version("2"),
138+
"name": "test-dfn",
139+
"extra_key": "should cause error",
140+
}
141+
with pytest.raises(ValueError, match="Unrecognized keys in DFN data"):
142+
Dfn.from_dict(d, strict=True)
143+
144+
145+
def test_dfn_from_dict_strict_mode_nested():
146+
d = {
147+
"schema_version": Version("2"),
148+
"name": "test-dfn",
149+
"blocks": {
150+
"options": {
151+
"test_field": {
152+
"name": "test_field",
153+
"type": "keyword",
154+
"extra_key": "should cause error",
155+
},
156+
},
157+
},
158+
}
159+
with pytest.raises(ValueError, match="Unrecognized keys in field data"):
160+
Dfn.from_dict(d, strict=True)
161+
162+
163+
def test_dfn_from_dict_roundtrip():
164+
original = Dfn(
165+
schema_version=Version("2"),
166+
name="gwf-nam",
167+
parent="sim-nam",
168+
advanced=False,
169+
multi=True,
170+
blocks={"options": {}},
171+
)
172+
d = asdict(original)
173+
reconstructed = Dfn.from_dict(d)
174+
assert reconstructed.name == original.name
175+
assert reconstructed.schema_version == original.schema_version
176+
assert reconstructed.parent == original.parent
177+
assert reconstructed.advanced == original.advanced
178+
assert reconstructed.multi == original.multi
179+
assert reconstructed.blocks == original.blocks
180+
181+
182+
def test_fieldv1_from_dict_ignores_extra_keys():
183+
d = {
184+
"name": "test_field",
185+
"type": "keyword",
186+
"extra_key": "should be allowed",
187+
"another_extra": 123,
188+
}
189+
field = FieldV1.from_dict(d)
190+
assert field.name == "test_field"
191+
assert field.type == "keyword"
192+
193+
194+
def test_fieldv1_from_dict_strict_mode():
195+
d = {
196+
"name": "test_field",
197+
"type": "keyword",
198+
"extra_key": "should cause error",
199+
}
200+
with pytest.raises(ValueError, match="Unrecognized keys in field data"):
201+
FieldV1.from_dict(d, strict=True)
202+
203+
204+
def test_fieldv1_from_dict_roundtrip():
205+
original = FieldV1(
206+
name="maxbound",
207+
type="integer",
208+
block="dimensions",
209+
description="maximum number of cells",
210+
tagged=True,
211+
)
212+
d = asdict(original)
213+
reconstructed = FieldV1.from_dict(d)
214+
assert reconstructed.name == original.name
215+
assert reconstructed.type == original.type
216+
assert reconstructed.block == original.block
217+
assert reconstructed.description == original.description
218+
assert reconstructed.tagged == original.tagged
219+
220+
221+
def test_fieldv2_from_dict_ignores_extra_keys():
222+
d = {
223+
"name": "test_field",
224+
"type": "keyword",
225+
"extra_key": "should be allowed",
226+
"another_extra": 123,
227+
}
228+
field = FieldV2.from_dict(d)
229+
assert field.name == "test_field"
230+
assert field.type == "keyword"
231+
232+
233+
def test_fieldv2_from_dict_strict_mode():
234+
d = {
235+
"name": "test_field",
236+
"type": "keyword",
237+
"extra_key": "should cause error",
238+
}
239+
with pytest.raises(ValueError, match="Unrecognized keys in field data"):
240+
FieldV2.from_dict(d, strict=True)
241+
242+
243+
def test_fieldv2_from_dict_roundtrip():
244+
original = FieldV2(
245+
name="nper",
246+
type="integer",
247+
block="dimensions",
248+
description="number of stress periods",
249+
optional=False,
250+
)
251+
d = asdict(original)
252+
reconstructed = FieldV2.from_dict(d)
253+
assert reconstructed.name == original.name
254+
assert reconstructed.type == original.type
255+
assert reconstructed.block == original.block
256+
assert reconstructed.description == original.description
257+
assert reconstructed.optional == original.optional
258+
259+
260+
def test_dfn_from_dict_with_v1_field_dicts():
261+
d = {
262+
"schema_version": Version("1"),
263+
"name": "test-dfn",
264+
"blocks": {
265+
"options": {
266+
"save_flows": {
267+
"name": "save_flows",
268+
"type": "keyword",
269+
"tagged": True,
270+
"in_record": False,
271+
},
272+
},
273+
},
274+
}
275+
dfn = Dfn.from_dict(d)
276+
assert dfn.schema_version == Version("1")
277+
assert dfn.name == "test-dfn"
278+
assert dfn.blocks is not None
279+
assert "options" in dfn.blocks
280+
assert "save_flows" in dfn.blocks["options"]
281+
282+
field = dfn.blocks["options"]["save_flows"]
283+
assert isinstance(field, FieldV1)
284+
assert field.name == "save_flows"
285+
assert field.type == "keyword"
286+
assert field.tagged is True
287+
assert field.in_record is False
288+
289+
290+
def test_dfn_from_dict_with_v2_field_dicts():
291+
d = {
292+
"schema_version": Version("2"),
293+
"name": "test-dfn",
294+
"blocks": {
295+
"dimensions": {
296+
"nper": {
297+
"name": "nper",
298+
"type": "integer",
299+
"optional": False,
300+
},
301+
},
302+
},
303+
}
304+
dfn = Dfn.from_dict(d)
305+
assert dfn.schema_version == Version("2")
306+
assert dfn.name == "test-dfn"
307+
assert dfn.blocks is not None
308+
assert "dimensions" in dfn.blocks
309+
assert "nper" in dfn.blocks["dimensions"]
310+
311+
field = dfn.blocks["dimensions"]["nper"]
312+
assert isinstance(field, FieldV2)
313+
assert field.name == "nper"
314+
assert field.type == "integer"
315+
assert field.optional is False
316+
317+
318+
def test_dfn_from_dict_defaults_to_v2_fields():
319+
d = {
320+
"name": "test-dfn",
321+
"blocks": {
322+
"options": {
323+
"some_field": {
324+
"name": "some_field",
325+
"type": "keyword",
326+
},
327+
},
328+
},
329+
}
330+
dfn = Dfn.from_dict(d)
331+
assert dfn.blocks is not None
332+
field = dfn.blocks["options"]["some_field"]
333+
assert isinstance(field, FieldV2)
334+
assert dfn.schema_version == Version("2")
335+
336+
337+
def test_dfn_from_dict_with_already_deserialized_fields():
338+
field = FieldV2(name="test", type="keyword")
339+
d = {
340+
"schema_version": Version("2"),
341+
"name": "test-dfn",
342+
"blocks": {
343+
"options": {
344+
"test": field,
345+
},
346+
},
347+
}
348+
dfn = Dfn.from_dict(d)
349+
assert dfn.blocks is not None
350+
assert dfn.blocks["options"]["test"] is field

modflow_devtools/dfn/__init__.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
try_parse_bool,
2525
try_parse_parent,
2626
)
27-
from modflow_devtools.dfn.schema.block import Block, Blocks
27+
from modflow_devtools.dfn.schema.block import Block, Blocks, block_sort_key
2828
from modflow_devtools.dfn.schema.field import SCALAR_TYPES, Field, Fields
2929
from modflow_devtools.dfn.schema.ref import Ref
3030
from modflow_devtools.dfn.schema.v1 import FieldV1
@@ -42,6 +42,7 @@
4242
"FieldV2",
4343
"Fields",
4444
"Ref",
45+
"block_sort_key",
4546
"load",
4647
"load_flat",
4748
"load_tree",
@@ -90,6 +91,52 @@ def fields(self) -> Fields:
9091
# TODO: change to normal dict after deprecating v1 schema
9192
return OMD(fields)
9293

94+
@classmethod
95+
def from_dict(cls, d: dict, strict: bool = False) -> "Dfn":
96+
"""
97+
Create a Dfn instance from a dictionary.
98+
99+
Parameters
100+
----------
101+
d : dict
102+
Dictionary containing DFN data
103+
strict : bool, optional
104+
If True, raise ValueError if dict contains unrecognized keys at the
105+
top level or in nested field dicts. If False (default), ignore
106+
unrecognized keys.
107+
"""
108+
keys = list(cls.__annotations__.keys())
109+
if strict:
110+
extra_keys = set(d.keys()) - set(keys)
111+
if extra_keys:
112+
raise ValueError(f"Unrecognized keys in DFN data: {extra_keys}")
113+
data = {k: v for k, v in d.items() if k in keys}
114+
schema_version = data.get("schema_version", Version("2"))
115+
field_cls = FieldV1 if schema_version == Version("1") else FieldV2
116+
117+
def _fields(block_name, block_data):
118+
fields = {}
119+
for field_name, field_data in block_data.items():
120+
if isinstance(field_data, dict):
121+
fields[field_name] = field_cls.from_dict(field_data, strict=strict)
122+
elif isinstance(field_data, field_cls):
123+
fields[field_name] = field_data
124+
else:
125+
raise TypeError(
126+
f"Invalid field data for {field_name} in block {block_name}: "
127+
f"expected dict or Field, got {type(field_data)}"
128+
)
129+
return fields
130+
131+
if blocks := data.get("blocks"):
132+
data["schema_version"] = schema_version
133+
data["blocks"] = {
134+
block_name: _fields(block_name, block_data)
135+
for block_name, block_data in blocks.items()
136+
}
137+
138+
return cls(**data)
139+
93140

94141
class SchemaMap(ABC):
95142
@abstractmethod

modflow_devtools/dfn/schema/v1.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,21 @@ class FieldV1(Field):
1717
mf6internal: str | None = None
1818

1919
@classmethod
20-
def from_dict(cls, d: dict) -> "FieldV1":
21-
"""Create a FieldV1 instance from a dictionary."""
20+
def from_dict(cls, d: dict, strict: bool = False) -> "FieldV1":
21+
"""
22+
Create a FieldV1 instance from a dictionary.
23+
24+
Parameters
25+
----------
26+
d : dict
27+
Dictionary containing field data
28+
strict : bool, optional
29+
If True, raise ValueError if dict contains unrecognized keys.
30+
If False (default), ignore unrecognized keys.
31+
"""
2232
keys = list(cls.__annotations__.keys()) + list(Field.__annotations__.keys())
33+
if strict:
34+
extra_keys = set(d.keys()) - set(keys)
35+
if extra_keys:
36+
raise ValueError(f"Unrecognized keys in field data: {extra_keys}")
2337
return cls(**{k: v for k, v in d.items() if k in keys})

modflow_devtools/dfn/schema/v2.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,21 @@ class FieldV2(Field):
1313
pass
1414

1515
@classmethod
16-
def from_dict(cls, d: dict) -> "FieldV2":
17-
"""Create a FieldV2 instance from a dictionary."""
16+
def from_dict(cls, d: dict, strict: bool = False) -> "FieldV2":
17+
"""
18+
Create a FieldV2 instance from a dictionary.
19+
20+
Parameters
21+
----------
22+
d : dict
23+
Dictionary containing field data
24+
strict : bool, optional
25+
If True, raise ValueError if dict contains unrecognized keys.
26+
If False (default), ignore unrecognized keys.
27+
"""
1828
keys = list(cls.__annotations__.keys()) + list(Field.__annotations__.keys())
29+
if strict:
30+
extra_keys = set(d.keys()) - set(keys)
31+
if extra_keys:
32+
raise ValueError(f"Unrecognized keys in field data: {extra_keys}")
1933
return cls(**{k: v for k, v in d.items() if k in keys})

0 commit comments

Comments
 (0)