1+ from collections import ChainMap
2+ from collections .abc import Mapping
13from pathlib import Path
24from typing import Any
35
46import numpy as np
57import xarray as xr
68from lark import Token , Transformer
9+ from modflow_devtools .dfn import _SCALAR_TYPES , Dfn , get_blocks , get_fields
710
811
912class BasicTransformer (Transformer ):
@@ -59,38 +62,47 @@ def INT(self, token: Token) -> int:
5962class TypedTransformer (Transformer ):
6063 """Type-aware transformer for MF6 input files."""
6164
62- def start (self , items : list [Any ]) -> dict :
65+ def __init__ (self , visit_tokens = False , dfn : Dfn = None ):
66+ super ().__init__ (visit_tokens )
67+ self .dfn = dfn
68+ self .blocks = get_blocks (dfn ) if dfn else None
69+ self .fields = get_fields (dfn ) if dfn else None
70+
71+ def start (self , items : list [Any ]) -> Mapping :
72+ return ChainMap (* items )
73+
74+ def block (self , items : list [Any ]) -> dict :
6375 return items [0 ]
6476
6577 def array (self , items : list [Any ]) -> dict :
66- infos = items [0 ]
67- if isinstance (infos , list ):
68- data = xr .concat ([info ["data" ] for info in infos if "data" in info ], dim = "layer" )
78+ arrs = items [0 ]
79+ if isinstance (arrs , list ):
80+ data = xr .concat ([arr ["data" ] for arr in arrs if "data" in arr ], dim = "layer" )
6981 return {
70- "control" : [info ["control" ] for info in infos if "control" in info ],
82+ "control" : [arr ["control" ] for arr in arrs if "control" in arr ],
7183 "data" : data ,
72- "attrs" : {k : v for k , v in infos [0 ].items () if k not in ["data" ]},
73- "dims" : {"layer" : len (infos )},
84+ "attrs" : {k : v for k , v in arrs [0 ].items () if k not in ["data" ]},
85+ "dims" : {"layer" : len (arrs )},
7486 }
75- return infos
87+ return arrs
7688
7789 def single_array (self , items : list [Any ]) -> dict :
7890 netcdf = items [0 ]
79- info = items [- 1 ]
91+ arr = items [- 1 ]
8092 if netcdf :
81- info ["netcdf" ] = netcdf
82- return TypedTransformer .try_create_dataarray (info )
93+ arr ["netcdf" ] = netcdf
94+ return TypedTransformer .try_create_dataarray (arr )
8395
8496 def layered_array (self , items : list [Any ]) -> list [dict ]:
8597 netcdf = items [0 ]
86- infos = []
87- for info in items [2 :]:
88- if info is None :
98+ layers = []
99+ for arr in items [2 :]:
100+ if arr is None :
89101 continue
90102 if netcdf :
91- info ["netcdf" ] = netcdf
92- infos .append (TypedTransformer .try_create_dataarray (info ))
93- return infos
103+ arr ["netcdf" ] = netcdf
104+ layers .append (TypedTransformer .try_create_dataarray (arr ))
105+ return layers
94106
95107 def readarray (self , items : list [Any ]) -> dict [str , Any ]:
96108 control = items [0 ]
@@ -129,7 +141,7 @@ def binary(self, items: list[Any]) -> dict[str, bool]:
129141 return {"binary" : True }
130142
131143 def filename (self , items : list [Any ]) -> Path :
132- return Path (items [0 ])
144+ return Path (items [0 ]. strip ( " \" '" ) )
133145
134146 def string (self , items : list [Any ]) -> str :
135147 return items [0 ].strip ("\" '" )
@@ -146,25 +158,6 @@ def data(self, items: list[Any]) -> np.ndarray:
146158 def netcdf (self , items : list [Any ]) -> dict [str , bool ]:
147159 return {"netcdf" : True }
148160
149- def NUMBER (self , token : Token ) -> int | float :
150- return float (token )
151-
152- def SIGNED_NUMBER (self , token : Token ) -> int | float :
153- return self .NUMBER (token )
154-
155- def INT (self , token : Token ) -> int :
156- return int (token )
157-
158- def SIGNED_INT (self , token : Token ) -> int :
159- return int (token )
160-
161- def ESCAPED_STRING (self , token : Token ) -> str :
162- # Remove quotes from escaped string
163- value = str (token )
164- if value .startswith ('"' ) and value .endswith ('"' ):
165- return value [1 :- 1 ]
166- return value
167-
168161 @staticmethod
169162 def try_create_dataarray (array_info : dict ) -> dict :
170163 control = array_info ["control" ]
@@ -176,3 +169,19 @@ def try_create_dataarray(array_info: dict) -> dict:
176169 case "external" :
177170 pass
178171 return array_info
172+
173+ def __default__ (self , data , children , meta ):
174+ if self .blocks is None or self .fields is None :
175+ return super ().__default__ (data , children , meta )
176+ if data .endswith ("_block" ) and (block_name := data [:- 6 ]) in self .blocks :
177+ return {block_name : children [0 ]}
178+ elif data .endswith ("_vars" ):
179+ return {item [0 ].lower (): item [1 ] for item in children }
180+ elif (field := self .fields .get (data , None )) is not None :
181+ if field ["type" ] == "keyword" :
182+ return data , True
183+ elif field ["type" ] in _SCALAR_TYPES and field .get ("shape" , None ):
184+ return data , TypedTransformer .try_create_dataarray (children [0 ])
185+ else :
186+ return data , children [0 ]
187+ return super ().__default__ (data , children , meta )
0 commit comments