Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flopy4/mf6/codec/reader/grammar/typed.lark
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ layered: "layered"i
netcdf: "netcdf"i
readarray: control [data]
control: constant | internal | external
constant: "constant"i _number
constant: "constant"i double
internal: "internal"i [factor] [iprn]
external: "open/close"i filename [factor] [binary] [iprn]
factor: "factor"i _number
iprn: "iprn"i _integer
factor: "factor"i double
iprn: "iprn"i integer
binary: "(binary)"i
filename: ESCAPED_STRING | _word
data: _number+
data: double+

_word: /[a-zA-Z0-9._'~,-\\(\\)]+/
_number: SIGNED_NUMBER | NUMBER
Expand Down
83 changes: 46 additions & 37 deletions flopy4/mf6/codec/reader/transformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from collections import ChainMap
from collections.abc import Mapping
from pathlib import Path
from typing import Any

import numpy as np
import xarray as xr
from lark import Token, Transformer
from modflow_devtools.dfn import _SCALAR_TYPES, Dfn, get_blocks, get_fields


class BasicTransformer(Transformer):
Expand Down Expand Up @@ -59,38 +62,47 @@ def INT(self, token: Token) -> int:
class TypedTransformer(Transformer):
"""Type-aware transformer for MF6 input files."""

def start(self, items: list[Any]) -> dict:
def __init__(self, visit_tokens=False, dfn: Dfn = None):
super().__init__(visit_tokens)
self.dfn = dfn
self.blocks = get_blocks(dfn) if dfn else None
self.fields = get_fields(dfn) if dfn else None

def start(self, items: list[Any]) -> Mapping:
return ChainMap(*items)

def block(self, items: list[Any]) -> dict:
return items[0]

def array(self, items: list[Any]) -> dict:
infos = items[0]
if isinstance(infos, list):
data = xr.concat([info["data"] for info in infos if "data" in info], dim="layer")
arrs = items[0]
if isinstance(arrs, list):
data = xr.concat([arr["data"] for arr in arrs if "data" in arr], dim="layer")
return {
"control": [info["control"] for info in infos if "control" in info],
"control": [arr["control"] for arr in arrs if "control" in arr],
"data": data,
"attrs": {k: v for k, v in infos[0].items() if k not in ["data"]},
"dims": {"layer": len(infos)},
"attrs": {k: v for k, v in arrs[0].items() if k not in ["data"]},
"dims": {"layer": len(arrs)},
}
return infos
return arrs

def single_array(self, items: list[Any]) -> dict:
netcdf = items[0]
info = items[-1]
arr = items[-1]
if netcdf:
info["netcdf"] = netcdf
return TypedTransformer.try_create_dataarray(info)
arr["netcdf"] = netcdf
return TypedTransformer.try_create_dataarray(arr)

def layered_array(self, items: list[Any]) -> list[dict]:
netcdf = items[0]
infos = []
for info in items[2:]:
if info is None:
layers = []
for arr in items[2:]:
if arr is None:
continue
if netcdf:
info["netcdf"] = netcdf
infos.append(TypedTransformer.try_create_dataarray(info))
return infos
arr["netcdf"] = netcdf
layers.append(TypedTransformer.try_create_dataarray(arr))
return layers

def readarray(self, items: list[Any]) -> dict[str, Any]:
control = items[0]
Expand Down Expand Up @@ -129,7 +141,7 @@ def binary(self, items: list[Any]) -> dict[str, bool]:
return {"binary": True}

def filename(self, items: list[Any]) -> Path:
return Path(items[0])
return Path(items[0].strip("\"'"))

def string(self, items: list[Any]) -> str:
return items[0].strip("\"'")
Expand All @@ -146,25 +158,6 @@ def data(self, items: list[Any]) -> np.ndarray:
def netcdf(self, items: list[Any]) -> dict[str, bool]:
return {"netcdf": True}

def NUMBER(self, token: Token) -> int | float:
return float(token)

def SIGNED_NUMBER(self, token: Token) -> int | float:
return self.NUMBER(token)

def INT(self, token: Token) -> int:
return int(token)

def SIGNED_INT(self, token: Token) -> int:
return int(token)

def ESCAPED_STRING(self, token: Token) -> str:
# Remove quotes from escaped string
value = str(token)
if value.startswith('"') and value.endswith('"'):
return value[1:-1]
return value

@staticmethod
def try_create_dataarray(array_info: dict) -> dict:
control = array_info["control"]
Expand All @@ -176,3 +169,19 @@ def try_create_dataarray(array_info: dict) -> dict:
case "external":
pass
return array_info

def __default__(self, data, children, meta):
if self.blocks is None or self.fields is None:
return super().__default__(data, children, meta)
if data.endswith("_block") and (block_name := data[:-6]) in self.blocks:
return {block_name: children[0]}
elif data.endswith("_vars"):
return {item[0].lower(): item[1] for item in children}
elif (field := self.fields.get(data, None)) is not None:
if field["type"] == "keyword":
return data, True
elif field["type"] in _SCALAR_TYPES and field.get("shape", None):
return data, TypedTransformer.try_create_dataarray(children[0])
else:
return data, children[0]
return super().__default__(data, children, meta)
Loading
Loading