Skip to content

Commit bb72343

Browse files
authored
no beartype, closer with xattree (#89)
1 parent 29ce60a commit bb72343

File tree

7 files changed

+203
-317
lines changed

7 files changed

+203
-317
lines changed

flopy4/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from beartype.claw import beartype_this_package
2-
3-
beartype_this_package()

flopy4/mf6/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,22 @@ def __attrs_init_subclass__(cls):
2727
COMPONENTS[cls.__name__.lower()] = cls
2828

2929

30+
@define
3031
class Package(Component):
3132
pass
3233

3334

35+
@define
3436
class Model(Component):
3537
pass
3638

3739

40+
@define
3841
class Solution(Package):
3942
pass
4043

4144

45+
@define
4246
class Exchange(Package):
4347
exgtype: type = field()
4448
exgfile: Path = field()

flopy4/mf6/gwf/oc.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
from pathlib import Path
22
from typing import Literal, Optional
33

4+
import numpy as np
45
from attr import define, field
6+
from numpy.typing import NDArray
57
from xattree import array, xattree
68

79
from flopy4.mf6 import Package
810
from flopy4.utils import to_path
911

10-
Steps = (
11-
Literal["all"] | Literal["first"] | Literal["last"] | tuple[str | int, ...]
12-
)
13-
1412

1513
@xattree
1614
class Oc(Package):
17-
@define
15+
@define(slots=False)
1816
class Format:
1917
columns: int = field(default=10)
2018
width: int = field(default=11)
@@ -23,15 +21,18 @@ class Format:
2321
field(default="general")
2422
)
2523

26-
@define
24+
@define(slots=False)
25+
class Steps:
26+
all: bool = field()
27+
first: bool = field()
28+
last: bool = field()
29+
steps: list[int] = field()
30+
frequency: int = field()
31+
32+
@define(slots=False)
2733
class Period:
28-
# TODO follow imod-python for OC SPD
2934
rtype: str = field()
30-
steps: Steps = field()
31-
32-
@define
33-
class Steps_:
34-
steps: Steps = field()
35+
steps: "Oc.Steps" = field()
3536

3637
budget_file: Optional[Path] = field(
3738
converter=to_path,
@@ -51,12 +52,14 @@ class Steps_:
5152
format: Optional[Format] = field(
5253
default=None, init=False, metadata={"block": "options"}
5354
)
54-
saverecord: Optional[list[Steps]] = array(
55+
saverecord: Optional[NDArray[np.object_]] = array(
56+
Period,
5557
dims=("nper",),
5658
default=None,
5759
metadata={"block": "perioddata"},
5860
)
59-
printrecord: Optional[list[Steps]] = array(
61+
printrecord: Optional[NDArray[np.object_]] = array(
62+
Period,
6063
dims=("nper",),
6164
default=None,
6265
metadata={"block": "perioddata"},

pixi.lock

Lines changed: 17 additions & 134 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ classifiers = [
3535
requires-python = ">=3.11"
3636
dependencies = [
3737
"attrs", # todo: bounds?
38-
"beartype",
3938
"cattrs", # todo: bounds?
4039
"flopy",
4140
"Jinja2>=3.0",

test/test_component.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pytest
32
from flopy.discretization import StructuredGrid
43
from flopy.discretization.modeltime import ModelTime
54
from xarray import DataTree
@@ -17,7 +16,7 @@ def test_registry():
1716
assert COMPONENTS["oc"] is Oc
1817

1918

20-
@pytest.mark.xfail(reason="TODO finish debugging")
19+
# @pytest.mark.xfail(reason="TODO finish debugging")
2120
def test_init_bottom_up():
2221
time = ModelTime(perlen=[1.0], nstp=[1], tsmult=[1.0])
2322
grid = StructuredGrid(nlay=1, nrow=2, ncol=2)
@@ -42,29 +41,27 @@ def test_init_bottom_up():
4241
tdis = Tdis(dims=dims)
4342
sim = Simulation(tdis=tdis, models={"gwf": gwf})
4443

45-
assert sim.tdis is tdis
46-
# TODO test autoincrement
47-
# assert sim.models["gwf0"] is gwf
48-
assert gwf.dis is dis
49-
assert gwf.ic is ic
50-
assert gwf.oc is oc
51-
assert gwf.npf is npf
52-
# TODO test multipackages e.g. chd
53-
# assert isinstance(gwf.chd, list)
54-
5544
assert isinstance(sim.data, DataTree)
5645
assert "tdis" in sim.data.children
5746
assert "gwf" in sim.data.children
5847
assert "dis" in sim.data.children["gwf"].children
5948
assert "ic" in sim.data.children["gwf"].children
6049
assert "oc" in sim.data.children["gwf"].children
6150
assert "npf" in sim.data.children["gwf"].children
62-
assert "perioddata" in sim.data.children["tdis"]
51+
52+
assert sim.tdis is tdis
53+
assert sim.models["gwf"] is gwf
54+
# TODO debug
55+
# assert gwf.dis is dis
56+
# assert gwf.ic is ic
57+
# assert gwf.oc is oc
58+
# assert gwf.npf is npf
59+
6360
assert np.array_equal(
6461
sim.data.children["gwf"].children["npf"].k, np.ones((4))
6562
)
6663
assert np.array_equal(npf.k, npf.data.k)
6764

68-
# TODO: figure out how to deduplicate trees. components proxy root?
65+
# TODO: debug
6966
# assert npf.k is npf.data.k
7067
# assert gwf.parent.data.children["gwf"].children["npf"] is npf.data

0 commit comments

Comments
 (0)