Skip to content

Commit e33f35f

Browse files
authored
generalize typed transformer (#180)
instead of hard-coding block/field rules, inject the component specification into the transformer and handle them dynamically. coming next, instead of hard-coding the grammar, generate it from dfn.
1 parent faf27ea commit e33f35f

File tree

5 files changed

+183
-188
lines changed

5 files changed

+183
-188
lines changed

flopy4/mf6/codec/reader/grammar/typed.lark

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ layered: "layered"i
1010
netcdf: "netcdf"i
1111
readarray: control [data]
1212
control: constant | internal | external
13-
constant: "constant"i _number
13+
constant: "constant"i double
1414
internal: "internal"i [factor] [iprn]
1515
external: "open/close"i filename [factor] [binary] [iprn]
16-
factor: "factor"i _number
17-
iprn: "iprn"i _integer
16+
factor: "factor"i double
17+
iprn: "iprn"i integer
1818
binary: "(binary)"i
1919
filename: ESCAPED_STRING | _word
20-
data: _number+
20+
data: double+
2121

2222
_word: /[a-zA-Z0-9._'~,-\\(\\)]+/
2323
_number: SIGNED_NUMBER | NUMBER

flopy4/mf6/codec/reader/transformer.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from collections import ChainMap
2+
from collections.abc import Mapping
13
from pathlib import Path
24
from typing import Any
35

46
import numpy as np
57
import xarray as xr
68
from lark import Token, Transformer
9+
from modflow_devtools.dfn import _SCALAR_TYPES, Dfn, get_blocks, get_fields
710

811

912
class BasicTransformer(Transformer):
@@ -59,38 +62,47 @@ def INT(self, token: Token) -> int:
5962
class TypedTransformer(Transformer):
6063
"""Type-aware transformer for MF6 input files."""
6164

62-
def start(self, items: list[Any]) -> dict:
65+
def __init__(self, visit_tokens=False, dfn: Dfn = None):
66+
super().__init__(visit_tokens)
67+
self.dfn = dfn
68+
self.blocks = get_blocks(dfn) if dfn else None
69+
self.fields = get_fields(dfn) if dfn else None
70+
71+
def start(self, items: list[Any]) -> Mapping:
72+
return ChainMap(*items)
73+
74+
def block(self, items: list[Any]) -> dict:
6375
return items[0]
6476

6577
def array(self, items: list[Any]) -> dict:
66-
infos = items[0]
67-
if isinstance(infos, list):
68-
data = xr.concat([info["data"] for info in infos if "data" in info], dim="layer")
78+
arrs = items[0]
79+
if isinstance(arrs, list):
80+
data = xr.concat([arr["data"] for arr in arrs if "data" in arr], dim="layer")
6981
return {
70-
"control": [info["control"] for info in infos if "control" in info],
82+
"control": [arr["control"] for arr in arrs if "control" in arr],
7183
"data": data,
72-
"attrs": {k: v for k, v in infos[0].items() if k not in ["data"]},
73-
"dims": {"layer": len(infos)},
84+
"attrs": {k: v for k, v in arrs[0].items() if k not in ["data"]},
85+
"dims": {"layer": len(arrs)},
7486
}
75-
return infos
87+
return arrs
7688

7789
def single_array(self, items: list[Any]) -> dict:
7890
netcdf = items[0]
79-
info = items[-1]
91+
arr = items[-1]
8092
if netcdf:
81-
info["netcdf"] = netcdf
82-
return TypedTransformer.try_create_dataarray(info)
93+
arr["netcdf"] = netcdf
94+
return TypedTransformer.try_create_dataarray(arr)
8395

8496
def layered_array(self, items: list[Any]) -> list[dict]:
8597
netcdf = items[0]
86-
infos = []
87-
for info in items[2:]:
88-
if info is None:
98+
layers = []
99+
for arr in items[2:]:
100+
if arr is None:
89101
continue
90102
if netcdf:
91-
info["netcdf"] = netcdf
92-
infos.append(TypedTransformer.try_create_dataarray(info))
93-
return infos
103+
arr["netcdf"] = netcdf
104+
layers.append(TypedTransformer.try_create_dataarray(arr))
105+
return layers
94106

95107
def readarray(self, items: list[Any]) -> dict[str, Any]:
96108
control = items[0]
@@ -129,7 +141,7 @@ def binary(self, items: list[Any]) -> dict[str, bool]:
129141
return {"binary": True}
130142

131143
def filename(self, items: list[Any]) -> Path:
132-
return Path(items[0])
144+
return Path(items[0].strip("\"'"))
133145

134146
def string(self, items: list[Any]) -> str:
135147
return items[0].strip("\"'")
@@ -146,25 +158,6 @@ def data(self, items: list[Any]) -> np.ndarray:
146158
def netcdf(self, items: list[Any]) -> dict[str, bool]:
147159
return {"netcdf": True}
148160

149-
def NUMBER(self, token: Token) -> int | float:
150-
return float(token)
151-
152-
def SIGNED_NUMBER(self, token: Token) -> int | float:
153-
return self.NUMBER(token)
154-
155-
def INT(self, token: Token) -> int:
156-
return int(token)
157-
158-
def SIGNED_INT(self, token: Token) -> int:
159-
return int(token)
160-
161-
def ESCAPED_STRING(self, token: Token) -> str:
162-
# Remove quotes from escaped string
163-
value = str(token)
164-
if value.startswith('"') and value.endswith('"'):
165-
return value[1:-1]
166-
return value
167-
168161
@staticmethod
169162
def try_create_dataarray(array_info: dict) -> dict:
170163
control = array_info["control"]
@@ -176,3 +169,19 @@ def try_create_dataarray(array_info: dict) -> dict:
176169
case "external":
177170
pass
178171
return array_info
172+
173+
def __default__(self, data, children, meta):
174+
if self.blocks is None or self.fields is None:
175+
return super().__default__(data, children, meta)
176+
if data.endswith("_block") and (block_name := data[:-6]) in self.blocks:
177+
return {block_name: children[0]}
178+
elif data.endswith("_vars"):
179+
return {item[0].lower(): item[1] for item in children}
180+
elif (field := self.fields.get(data, None)) is not None:
181+
if field["type"] == "keyword":
182+
return data, True
183+
elif field["type"] in _SCALAR_TYPES and field.get("shape", None):
184+
return data, TypedTransformer.try_create_dataarray(children[0])
185+
else:
186+
return data, children[0]
187+
return super().__default__(data, children, meta)

0 commit comments

Comments
 (0)