Skip to content

Commit 7f26476

Browse files
authored
store drainage data attributes in a dataclass (ItziModel#127)
1 parent 6d2e8e2 commit 7f26476

File tree

5 files changed

+128
-125
lines changed

5 files changed

+128
-125
lines changed

src/itzi/data_containers.py

Lines changed: 65 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,91 +13,87 @@
1313
GNU General Public License for more details.
1414
"""
1515

16-
from typing import Dict, Tuple, ClassVar
17-
from dataclasses import dataclass
16+
from typing import Dict, Tuple
17+
import dataclasses
1818
from datetime import datetime
1919

2020
import numpy as np
2121

2222

23-
@dataclass(frozen=True)
23+
@dataclasses.dataclass(frozen=True)
24+
class DrainageAttributes:
25+
"""A base class for drainage data attributes."""
26+
27+
def get_columns_definition(self) -> list[tuple[str, str]]:
28+
"""Return a list of tuples to create DB columns"""
29+
type_corresp = {str: "TEXT", int: "INT", float: "REAL"}
30+
db_columns_def = [("cat", "INTEGER PRIMARY KEY")]
31+
for f in dataclasses.fields(self):
32+
db_field = (f.name, type_corresp[f.type])
33+
db_columns_def.append(db_field)
34+
return db_columns_def
35+
36+
37+
@dataclasses.dataclass(frozen=True)
38+
class DrainageLinkAttributes(DrainageAttributes):
39+
link_id: str
40+
link_type: str
41+
flow: float
42+
depth: float
43+
volume: float
44+
inlet_offset: float
45+
outlet_offset: float
46+
froude: float
47+
48+
49+
@dataclasses.dataclass(frozen=True)
2450
class DrainageLinkData:
2551
"""Store the instantaneous state of a node during a drainage simulation"""
2652

2753
vertices: Tuple[Tuple[float, float], ...]
28-
attributes: Tuple # one values for each columns, minus "cat"
29-
columns_definition: ClassVar[Tuple[Tuple[str, str], ...]] = (
30-
("cat", "INTEGER PRIMARY KEY"),
31-
("link_id", "TEXT"),
32-
("type", "TEXT"),
33-
("flow", "REAL"),
34-
("depth", "REAL"),
35-
# (u'velocity', 'REAL'),
36-
("volume", "REAL"),
37-
("offset1", "REAL"),
38-
("offset2", "REAL"),
39-
# (u'yFull', 'REAL'),
40-
("froude", "REAL"),
41-
)
42-
43-
def __post_init__(self):
44-
"""Validate attributes length after initialization."""
45-
expected_len = len(self.columns_definition) - 1
46-
if len(self.attributes) != expected_len:
47-
raise ValueError(
48-
f"DrainageLinkData: Incorrect number of attributes. "
49-
f"Expected {expected_len}, got {len(self.attributes)}"
50-
)
51-
52-
53-
@dataclass(frozen=True)
54+
attributes: DrainageLinkAttributes
55+
56+
57+
@dataclasses.dataclass(frozen=True)
58+
class DrainageNodeAttributes(DrainageAttributes):
59+
node_id: str
60+
node_type: str
61+
coupling_type: str
62+
coupling_flow: float
63+
inflow: float
64+
outflow: float
65+
lateral_inflow: float
66+
losses: float
67+
overflow: float
68+
depth: float
69+
head: float
70+
# crownElev: float
71+
crest_elevation: float
72+
invert_elevation: float
73+
initial_depth: float
74+
full_depth: float
75+
surcharge_depth: float
76+
ponding_area: float
77+
# degree: int
78+
volume: float
79+
full_volume: float
80+
81+
82+
@dataclasses.dataclass(frozen=True)
5483
class DrainageNodeData:
5584
"""Store the instantaneous state of a node during a drainage simulation"""
5685

5786
coordinates: Tuple[float, float]
58-
attributes: Tuple # one values for each columns, minus "cat"
59-
columns_definition: ClassVar[Tuple[Tuple[str, str], ...]] = (
60-
("cat", "INTEGER PRIMARY KEY"),
61-
("node_id", "TEXT"),
62-
("type", "TEXT"),
63-
("linkage_type", "TEXT"),
64-
("linkage_flow", "REAL"),
65-
("inflow", "REAL"),
66-
("outflow", "REAL"),
67-
("latFlow", "REAL"),
68-
("losses", "REAL"),
69-
("overflow", "REAL"),
70-
("depth", "REAL"),
71-
("head", "REAL"),
72-
# (u'crownElev', 'REAL'),
73-
("crestElev", "REAL"),
74-
("invertElev", "REAL"),
75-
("initDepth", "REAL"),
76-
("fullDepth", "REAL"),
77-
("surDepth", "REAL"),
78-
("pondedArea", "REAL"),
79-
# (u'degree', 'INT'),
80-
("newVolume", "REAL"),
81-
("fullVolume", "REAL"),
82-
)
83-
84-
def __post_init__(self):
85-
"""Validate attributes length after initialization."""
86-
expected_len = len(self.columns_definition) - 1
87-
if len(self.attributes) != expected_len:
88-
raise ValueError(
89-
f"DrainageNodeData: Incorrect number of attributes. "
90-
f"Expected {expected_len}, got {len(self.attributes)}"
91-
)
92-
93-
94-
@dataclass(frozen=True)
87+
attributes: DrainageNodeAttributes
88+
89+
90+
@dataclasses.dataclass(frozen=True)
9591
class DrainageNetworkData:
9692
nodes: Tuple[DrainageNodeData, ...]
9793
links: Tuple[DrainageLinkData, ...]
9894

9995

100-
@dataclass(frozen=True)
96+
@dataclasses.dataclass(frozen=True)
10197
class ContinuityData:
10298
"""Store information about simulation continuity"""
10399

@@ -107,7 +103,7 @@ class ContinuityData:
107103
continuity_error: float
108104

109105

110-
@dataclass(frozen=True)
106+
@dataclasses.dataclass(frozen=True)
111107
class SimulationData:
112108
"""Immutable data container for passing raw simulation state to Report.
113109
@@ -127,7 +123,7 @@ class SimulationData:
127123
drainage_network_data: DrainageNetworkData | None
128124

129125

130-
@dataclass(frozen=True)
126+
@dataclasses.dataclass(frozen=True)
131127
class MassBalanceData:
132128
"""Contains the fields written to the mass balance file"""
133129

src/itzi/drainage.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222

2323
from itzi import DefaultValues
2424
from itzi import messenger as msgr
25-
from itzi.data_containers import DrainageNodeData, DrainageLinkData, DrainageNetworkData
25+
from itzi.data_containers import (
26+
DrainageNodeData,
27+
DrainageLinkData,
28+
DrainageNetworkData,
29+
DrainageLinkAttributes,
30+
DrainageNodeAttributes,
31+
)
2632

2733

2834
class CouplingTypes(StrEnum):
@@ -168,37 +174,34 @@ def is_coupled(self):
168174
"""return True if the node is coupled to the 2D domain"""
169175
return self.coupling_type != CouplingTypes.NOT_COUPLED
170176

171-
def get_attrs(self):
172-
"""return a list of node data in the right DB order
173-
TODO: put the burden of DB order to the code actually writing the DB
174-
"""
175-
attrs = [
176-
self.node_id,
177-
self.node_type,
178-
self.coupling_type.value,
179-
self.coupling_flow,
180-
self.pyswmm_node.total_inflow,
181-
self.pyswmm_node.total_outflow,
182-
self.pyswmm_node.lateral_inflow,
183-
self.pyswmm_node.losses,
184-
self.get_overflow(),
185-
self.pyswmm_node.depth,
186-
self.pyswmm_node.head,
187-
# values['crown_elev'],
188-
self.get_crest_elev(),
189-
self.pyswmm_node.invert_elevation,
190-
self.pyswmm_node.initial_depth,
191-
self.pyswmm_node.full_depth,
192-
self.pyswmm_node.surcharge_depth,
193-
self.pyswmm_node.ponding_area,
194-
# values['degree'],
195-
self.pyswmm_node.volume,
196-
self.get_full_volume(),
197-
]
198-
return attrs
177+
def get_attrs(self) -> DrainageNodeAttributes:
178+
""" """
179+
return DrainageNodeAttributes(
180+
node_id=self.node_id,
181+
node_type=self.node_type,
182+
coupling_type=self.coupling_type.value,
183+
coupling_flow=self.coupling_flow,
184+
inflow=self.pyswmm_node.total_inflow,
185+
outflow=self.pyswmm_node.total_outflow,
186+
lateral_inflow=self.pyswmm_node.lateral_inflow,
187+
losses=self.pyswmm_node.losses,
188+
overflow=self.get_overflow(),
189+
depth=self.pyswmm_node.depth,
190+
head=self.pyswmm_node.head,
191+
# crownElev=values['crown_elev'],
192+
crest_elevation=self.get_crest_elev(),
193+
invert_elevation=self.pyswmm_node.invert_elevation,
194+
initial_depth=self.pyswmm_node.initial_depth,
195+
full_depth=self.pyswmm_node.full_depth,
196+
surcharge_depth=self.pyswmm_node.surcharge_depth,
197+
ponding_area=self.pyswmm_node.ponding_area,
198+
# degree=values['degree'],
199+
volume=self.pyswmm_node.volume,
200+
full_volume=self.get_full_volume(),
201+
)
199202

200203
def get_data(self) -> DrainageNodeData:
201-
return DrainageNodeData(coordinates=self.coordinates, attributes=tuple(self.get_attrs()))
204+
return DrainageNodeData(coordinates=self.coordinates, attributes=self.get_attrs())
202205

203206
def apply_coupling(self, z, h, dt_drainage, cell_surf):
204207
"""Apply the coupling to the node"""
@@ -333,23 +336,20 @@ def _get_link_type(self):
333336
raise ValueError(f"Unknown link type for link {self.link_id}")
334337
return link_type
335338

336-
def get_attrs(self):
337-
"""return a list of link data in the right DB order
338-
TODO: put the burden of DB order to the code actually writing the DB
339-
"""
340-
attrs = [
341-
self.link_id,
342-
self.link_type,
343-
self.pyswmm_link.flow,
344-
self.pyswmm_link.depth,
345-
# values['velocity'],
346-
self.pyswmm_link.volume,
347-
self.pyswmm_link.inlet_offset,
348-
self.pyswmm_link.outlet_offset,
349-
# values['full_depth'],
350-
self.pyswmm_link.froude,
351-
]
352-
return attrs
339+
def get_attrs(self) -> DrainageLinkAttributes:
340+
""" """
341+
return DrainageLinkAttributes(
342+
link_id=self.link_id,
343+
link_type=self.link_type,
344+
flow=self.pyswmm_link.flow,
345+
depth=self.pyswmm_link.depth,
346+
# velocity=values['velocity'],
347+
volume=self.pyswmm_link.volume,
348+
inlet_offset=self.pyswmm_link.inlet_offset,
349+
outlet_offset=self.pyswmm_link.outlet_offset,
350+
# full_depth=values['full_depth'],
351+
froude=self.pyswmm_link.froude,
352+
)
353353

354354
def get_data(self) -> DrainageLinkData:
355-
return DrainageLinkData(vertices=self.vertices, attributes=tuple(self.get_attrs()))
355+
return DrainageLinkData(vertices=self.vertices, attributes=self.get_attrs())

src/itzi/providers/grass_interface.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pathlib import Path
1818
from collections import namedtuple
1919
from datetime import datetime, timedelta
20+
import dataclasses
2021

2122
# from multiprocessing import Process, JoinableQueue
2223
from threading import Thread, Lock
@@ -547,8 +548,8 @@ def write_vector_map(self, drainage_data: DrainageNetworkData, map_name: str):
547548

548549
with VectorTopo(map_name, mode="w", overwrite=self.overwrite) as vect_map:
549550
# create db links and tables
550-
node_columns_def = drainage_data.nodes[0].columns_definition
551-
link_columns_def = drainage_data.links[0].columns_definition
551+
node_columns_def = drainage_data.nodes[0].attributes.get_columns_definition()
552+
link_columns_def = drainage_data.links[0].attributes.get_columns_definition()
552553
linking_elements = {
553554
"node": self.LayerDescr(
554555
table_suffix="_node", cols=node_columns_def, layer_number=1
@@ -573,7 +574,10 @@ def write_vector_map(self, drainage_data: DrainageNetworkData, map_name: str):
573574
map_layer, dbtable = dblinks["node"]
574575
self.write_vector_geometry(vect_map, point, cat_num, map_layer)
575576
# Get DB attributes
576-
attrs = (cat_num,) + node.attributes
577+
node_attributes = tuple(
578+
value for _, value in dataclasses.asdict(node.attributes).items()
579+
)
580+
attrs = (cat_num,) + node_attributes
577581
db_info["node"].append(attrs)
578582
# bump cat
579583
cat_num += 1
@@ -587,7 +591,10 @@ def write_vector_map(self, drainage_data: DrainageNetworkData, map_name: str):
587591
map_layer, dbtable = dblinks["link"]
588592
self.write_vector_geometry(vect_map, line_object, cat_num, map_layer)
589593
# keep DB info
590-
attrs = (cat_num,) + link.attributes
594+
link_attributes = tuple(
595+
value for _, value in dataclasses.asdict(link.attributes).items()
596+
)
597+
attrs = (cat_num,) + link_attributes
591598
db_info["link"].append(attrs)
592599
# bump cat
593600
cat_num += 1

tests/test_ea8b.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ def ea_test8b_sim(ea_test8b, test_data_path, test_data_temp_path):
149149

150150
@pytest.fixture(scope="module")
151151
def ea8b_itzi_drainage_results(ea_test8b_sim):
152-
"""Extract linkage flow from the drainage network"""
152+
"""Extract coupling flow from the drainage network"""
153153
current_mapset = gscript.read_command("g.mapset", flags="p").rstrip()
154154
assert current_mapset == "ea8b"
155-
select_col = ["start_time", "linkage_flow"]
155+
select_col = ["start_time", "coupling_flow"]
156156
itzi_results = gscript.read_command("t.vect.db.select", input="out_drainage")
157-
# translate to Pandas dataframe and keep only linkage_flow with start_time over 3000
157+
# translate to Pandas dataframe and keep only coupling_flow with start_time over 3000
158158
df_itzi_results = pd.read_csv(StringIO(itzi_results), sep="|")[select_col]
159159
df_itzi_results = df_itzi_results[df_itzi_results.start_time >= 3000]
160160
df_itzi_results.set_index("start_time", drop=True, inplace=True, verify_integrity=True)

tests/test_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_tutorial_drainage(itzi_tutorial, test_data_path, test_data_temp_path, h
152152
sim_runner.initialize(config_file)
153153
sim_runner.run().finalize()
154154
# Check the results
155-
select_cols = ["start_time", "linkage_flow"]
155+
select_cols = ["start_time", "coupling_flow"]
156156
drainage_results = gscript.read_command(
157157
"t.vect.db.select",
158158
input="nc_itzi_tutorial_drainage",

0 commit comments

Comments
 (0)