
Commit 73bcb00

lrntct and karosc authored
Replace dataclasses by pydantic BaseModel (#209)
* data_container.py transition from dataclasses to pydantic BaseModel
* remove fail-fast from build wheel
* build swmm-toolkit without stable ABI

Co-authored-by: karosc <[email protected]>
1 parent ed1356d commit 73bcb00

File tree: 12 files changed, +179 −58 lines changed

.github/workflows/build_wheels.yml

Lines changed: 4 additions & 2 deletions
@@ -10,6 +10,7 @@ jobs:
     name: Build wheels on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     strategy:
+      fail-fast: false
       matrix:
         # macos-13 is intel, macos-14 is apple silicon, windows-11-arm
         os: [ubuntu-latest, ubuntu-24.04-arm, windows-latest, windows-11-arm]
@@ -27,10 +28,11 @@ jobs:
       env:
         CIBW_ARCHS: auto64 # Don't build 32 bits wheels
         CIBW_SKIP: pp* # Don't build pypy wheels
-        CIBW_ENVIRONMENT: "ITZI_BDIST_WHEEL=1" # Trigger generic compiler flags
+        CIBW_ENVIRONMENT: "ITZI_BDIST_WHEEL=1 NO_STABLE_ABI=1" # Trigger generic compiler flags
         CIBW_ENVIRONMENT_MACOS: >
           MACOSX_DEPLOYMENT_TARGET=14.0
-          ITZI_BDIST_WHEEL=1
+          ITZI_BDIST_WHEEL=1
+          NO_STABLE_ABI=1
         CIBW_BEFORE_TEST: "pip install pytest pytest-benchmark pandas"
         CIBW_TEST_SOURCES: tests
         CIBW_TEST_COMMAND: > # Fast tests which do not require GRASS
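The new NO_STABLE_ABI=1 entry matches the commit note "build swmm-toolkit without stable ABI". The diff does not show how the build script consumes the flag; the snippet below is only a hypothetical sketch of how a setuptools-based build could branch on such an environment variable (the extension name and source file are invented for illustration, this is not the project's actual build script).

# Hypothetical sketch -- not taken from the repository.
import os
from setuptools import Extension, setup

# cibuildwheel exports NO_STABLE_ABI=1 into the build environment (see CIBW_ENVIRONMENT above).
use_stable_abi = os.environ.get("NO_STABLE_ABI", "0") != "1"

ext = Extension(
    "itzi._demo_ext",               # invented extension name
    sources=["src/demo_ext.c"],     # invented source file
    py_limited_api=use_stable_abi,  # build abi3 (stable ABI) wheels only when the flag is unset
)

setup(ext_modules=[ext])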

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ dependencies = [
     "numpy>=2.2",
     "pyswmm>=2.1.0",
     "bmipy>=2.0.1",
+    "pydantic>=2.12.5",
 ]
 requires-python = ">=3.11,<3.14" # pyswmm does not support python > 3.13
 readme = "README.md"

src/itzi/data_containers.py

Lines changed: 42 additions & 31 deletions
@@ -12,22 +12,26 @@
 GNU General Public License for more details.
 """
 
+from __future__ import annotations
+
 from typing import Dict, Tuple, TYPE_CHECKING
-import dataclasses
 from datetime import datetime, timedelta
+from pathlib import Path
 
 import numpy as np
+from pydantic import BaseModel, ConfigDict
 
 from itzi.const import DefaultValues, TemporalType, InfiltrationModelType
 
 if TYPE_CHECKING:
     from itzi.drainage import DrainageNode
 
 
-@dataclasses.dataclass(frozen=True)
-class DrainageNodeCouplingData:
+class DrainageNodeCouplingData(BaseModel):
     """Store the translation between coordinates and array location for a given drainage node."""
 
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+
     node_id: str  # Name of the drainage node
     node_object: "DrainageNode"
     # Location in the coordinate system
@@ -38,24 +42,24 @@ class DrainageNodeCouplingData:
     col: int | None
 
 
-@dataclasses.dataclass(frozen=True)
-class DrainageAttributes:
+class DrainageAttributes(BaseModel):
     """A base class for drainage data attributes."""
 
+    model_config = ConfigDict(frozen=True)
+
     @classmethod
     def get_columns_definition(cls, cat_primary_key=True) -> list[tuple[str, str]]:
         """Return a list of tuples to create DB columns"""
         type_mapping = {str: "TEXT", int: "INT", float: "REAL"}
         db_columns_def = [("cat", "INTEGER PRIMARY KEY")]
         if not cat_primary_key:
             db_columns_def = []
-        for f in dataclasses.fields(cls):
-            db_field = (f.name, type_mapping[f.type])
+        for field_name, field_info in cls.model_fields.items():
+            db_field = (field_name, type_mapping[field_info.annotation])
             db_columns_def.append(db_field)
         return db_columns_def
 
 
-@dataclasses.dataclass(frozen=True)
 class DrainageLinkAttributes(DrainageAttributes):
     link_id: str
     link_type: str
@@ -67,16 +71,16 @@ class DrainageLinkAttributes(DrainageAttributes):
     froude: float
 
 
-@dataclasses.dataclass(frozen=True)
-class DrainageLinkData:
+class DrainageLinkData(BaseModel):
     """Store the instantaneous state of a node during a drainage simulation.
     Vertices include the coordinates of the start and end nodes."""
 
-    vertices: None | Tuple[Tuple[float, float], ...]
+    model_config = ConfigDict(frozen=True)
+
+    vertices: None | Tuple[Tuple[float, float] | None, ...]
     attributes: DrainageLinkAttributes
 
 
-@dataclasses.dataclass(frozen=True)
 class DrainageNodeAttributes(DrainageAttributes):
     node_id: str
     node_type: str
@@ -101,55 +105,60 @@ class DrainageNodeAttributes(DrainageAttributes):
     full_volume: float
 
 
-@dataclasses.dataclass(frozen=True)
-class DrainageNodeData:
+class DrainageNodeData(BaseModel):
     """Store the instantaneous state of a node during a drainage simulation"""
 
+    model_config = ConfigDict(frozen=True)
+
     coordinates: None | Tuple[float, float]
     attributes: DrainageNodeAttributes
 
 
-@dataclasses.dataclass(frozen=True)
-class DrainageNetworkData:
+class DrainageNetworkData(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
     nodes: Tuple[DrainageNodeData, ...]
     links: Tuple[DrainageLinkData, ...]
 
 
-@dataclasses.dataclass(frozen=True)
-class ContinuityData:
+class ContinuityData(BaseModel):
     """Store information about simulation continuity"""
 
+    model_config = ConfigDict(frozen=True)
+
     new_domain_vol: float
     volume_change: float
     volume_error: float
     continuity_error: float
 
 
-@dataclasses.dataclass(frozen=True)
-class SimulationData:
+class SimulationData(BaseModel):
     """Immutable data container for passing raw simulation state to Report.
 
     This is a pure data structure containing only the "raw ingredients"
    needed for a report. All report-specific calculations (e.g., WSE,
    average rates) are performed by the Report class itself.
    """
 
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+
     sim_time: datetime
     time_step: float  # time step duration
     time_steps_counter: int  # number of time steps since last update
-    continuity_data: ContinuityData
+    continuity_data: ContinuityData | None  # Made optional for use in tests
     raw_arrays: Dict[str, np.ndarray]
     accumulation_arrays: Dict[str, np.ndarray]
     cell_dx: float  # cell size in east-west direction
     cell_dy: float  # cell size in north-south direction
     drainage_network_data: DrainageNetworkData | None
 
 
-@dataclasses.dataclass(frozen=True)
-class MassBalanceData:
+class MassBalanceData(BaseModel):
     """Contains the fields written to the mass balance file"""
 
-    simulation_time: datetime
+    model_config = ConfigDict(frozen=True)
+
+    simulation_time: datetime | timedelta
     average_timestep: float
     timesteps: int
     boundary_volume: float
@@ -164,10 +173,11 @@ class MassBalanceData:
     percent_error: float
 
 
-@dataclasses.dataclass(frozen=True)
-class SurfaceFlowParameters:
+class SurfaceFlowParameters(BaseModel):
     """Parameters for the surface flow model."""
 
+    model_config = ConfigDict(frozen=True)
+
     hmin: float = DefaultValues.HFMIN
     cfl: float = DefaultValues.CFL
     theta: float = DefaultValues.THETA
@@ -179,22 +189,23 @@ class SurfaceFlowParameters:
     max_error: float = DefaultValues.MAX_ERROR
 
 
-@dataclasses.dataclass(frozen=True)
-class SimulationConfig:
+class SimulationConfig(BaseModel):
     """Configuration data for a simulation run."""
 
+    model_config = ConfigDict(frozen=True)
+
     # Simulation times
     start_time: datetime
     end_time: datetime
     record_step: timedelta
     temporal_type: TemporalType
     # Input and output raster maps
-    input_map_names: Dict[str, str]
-    output_map_names: Dict[str, str]
+    input_map_names: Dict[str, str | None]
+    output_map_names: Dict[str, str | None]
     # Surface flow parameters
     surface_flow_parameters: SurfaceFlowParameters
     # Mass balance file
-    stats_file: str
+    stats_file: str | Path | None = None
     # Hydrology parameters
     dtinf: float = DefaultValues.DTINF
     infiltration_model: InfiltrationModelType = InfiltrationModelType.NULL
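Every container above follows the same conversion: @dataclasses.dataclass(frozen=True) becomes a BaseModel subclass carrying model_config = ConfigDict(frozen=True), with arbitrary_types_allowed=True added where fields hold non-pydantic types (numpy arrays, DrainageNode), and field introspection switches from dataclasses.fields() to cls.model_fields, whose FieldInfo.annotation provides the resolved type used by get_columns_definition(). A minimal, self-contained sketch of the pattern, with invented class and field names rather than the real containers:

from pydantic import BaseModel, ConfigDict


class CellState(BaseModel):  # illustrative stand-in, not a class from this diff
    # frozen=True makes instances immutable, like dataclass(frozen=True)
    model_config = ConfigDict(frozen=True)

    depth: float
    velocity: float
    label: str = "wet"


# Field introspection, mirroring DrainageAttributes.get_columns_definition()
type_mapping = {str: "TEXT", int: "INT", float: "REAL"}
columns = [(name, type_mapping[info.annotation]) for name, info in CellState.model_fields.items()]
print(columns)  # [('depth', 'REAL'), ('velocity', 'REAL'), ('label', 'TEXT')]

state = CellState(depth=0.5, velocity=1.2)  # keyword arguments only
print(state.model_dump())                   # {'depth': 0.5, 'velocity': 1.2, 'label': 'wet'}
# state.depth = 1.0  # would raise a ValidationError because the model is frozen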

src/itzi/drainage.py

Lines changed: 6 additions & 0 deletions
@@ -352,3 +352,9 @@ def get_attrs(self) -> DrainageLinkAttributes:
 
     def get_data(self) -> DrainageLinkData:
         return DrainageLinkData(vertices=self.vertices, attributes=self.get_attrs())
+
+
+# Rebuild Pydantic models that have forward references to DrainageNode
+from itzi.data_containers import DrainageNodeCouplingData  # noqa E402
+
+DrainageNodeCouplingData.model_rebuild()
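DrainageNodeCouplingData annotates node_object with "DrainageNode", which is only imported under TYPE_CHECKING, so pydantic cannot finish building the model until the real class is in scope; calling model_rebuild() at the bottom of drainage.py resolves that deferred reference. A simplified, single-module sketch of the same mechanism (the names are stand-ins, not the itzi classes):

from pydantic import BaseModel, ConfigDict


class Coupling(BaseModel):  # stand-in for DrainageNodeCouplingData
    # arbitrary_types_allowed lets a plain (non-pydantic) class be used as a field type
    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    node_object: "Node"  # forward reference; unresolved when the model class is created


class Node:  # stand-in for DrainageNode, defined only afterwards
    pass


# Before this call, instantiating Coupling raises an error saying the model is not fully defined.
Coupling.model_rebuild()
print(Coupling(node_object=Node()))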

src/itzi/massbalance.py

Lines changed: 2 additions & 3 deletions
@@ -15,7 +15,6 @@
 from datetime import datetime
 import csv
 import numbers
-import dataclasses
 
 from itzi.data_containers import MassBalanceData
 
@@ -28,7 +27,7 @@ def __init__(
         file_name: str,
     ):
         """Initializes the logger and creates the output file with headers."""
-        self.fields = [f.name for f in dataclasses.fields(MassBalanceData)]
+        self.fields = list(MassBalanceData.model_fields.keys())
         self.file_name = self._set_file_name(file_name)
         self._create_file()
 
@@ -48,7 +47,7 @@ def log(self, report_data: MassBalanceData) -> None:
         """Writes a single line of data to the CSV file."""
         line_to_write = {}
 
-        for key, value in dataclasses.asdict(report_data).items():
+        for key, value in report_data.model_dump().items():
             if value != value:  # test for NaN
                 line_to_write[key] = "-"
             elif "percent_error" == key:

src/itzi/providers/csv_output.py

Lines changed: 4 additions & 5 deletions
@@ -18,7 +18,6 @@
 from typing import TypedDict, TYPE_CHECKING, Tuple, List
 from io import StringIO
 import csv
-import dataclasses
 
 import pandas as pd
 
@@ -74,8 +73,8 @@ def __init__(self, config: CSVVectorOutputConfig) -> None:
         self.append_mode = {"link": False, "node": False}
 
         for geom_type, obj in [("node", DrainageNodeAttributes), ("link", DrainageLinkAttributes)]:
-            base_headers = [field.name for field in dataclasses.fields(obj)]
-            self.headers[geom_type] = ["sim_time"] + list(base_headers) + ["srid", "geometry"]
+            base_headers = list(obj.model_fields.keys())
+            self.headers[geom_type] = ["sim_time"] + base_headers + ["srid", "geometry"]
 
             results_name = f"{config['drainage_results_name']}_{geom_type}s.csv"
             self.file_paths[geom_type] = results_prefix + "/" + results_name
@@ -207,7 +206,7 @@ def _update_csv(
         # IDs must match
         new_ids = set(
             [
-                dataclasses.asdict(drainage_elem.attributes)[f"{geom_type}_id"]
+                drainage_elem.attributes.model_dump()[f"{geom_type}_id"]
                 for drainage_elem in drainage_elements
             ]
         )
@@ -232,7 +231,7 @@ def _update_csv(
     def _attrs_line(self, drainage_element: DrainageNodeData | DrainageLinkData) -> List[str, ...]:
         """Return a list of attributes"""
         # Convert attributes to list
-        attributes = [str(a) for a in dataclasses.asdict(drainage_element.attributes).values()]
+        attributes = [str(a) for a in drainage_element.attributes.model_dump().values()]
         # Create geometry WKT
         if isinstance(drainage_element, DrainageNodeData):
             if drainage_element.coordinates is not None:

src/itzi/providers/grass_interface.py

Lines changed: 2 additions & 8 deletions
@@ -16,7 +16,6 @@
 from pathlib import Path
 from collections import namedtuple
 from datetime import datetime, timedelta
-import dataclasses
 from typing import Self
 
 # from multiprocessing import Process, JoinableQueue
@@ -225,7 +224,6 @@ def finalize(self):
 
     def cleanup(self):
         """Remove temporary region and mask."""
-        msgr.debug("Reset mask and region")
         if self.raster_mask_id:
             msgr.debug("Remove temp MASK...")
             self.del_temp_mask()
@@ -542,9 +540,7 @@ def write_vector_map(self, drainage_data: DrainageNetworkData, map_name: str) ->
                 # The write function of the vector map set the layer to the one we set earlier
                 vector_map.write(point, cat=cat_num)
             # Get DB attributes even if no associated geometry
-            node_attributes = tuple(
-                value for _, value in dataclasses.asdict(node.attributes).items()
-            )
+            node_attributes = tuple(value for _, value in node.attributes.model_dump().items())
             attrs = (cat_num,) + node_attributes
             db_info["node"].append(attrs)
             cat_num += 1
@@ -560,9 +556,7 @@ def write_vector_map(self, drainage_data: DrainageNetworkData, map_name: str) ->
                 # The write function of the vector map set the layer to the one we set earlier
                 vector_map.write(line_object, cat=cat_num)
             # Get DB attributes even if no associated geometry
-            link_attributes = tuple(
-                value for _, value in dataclasses.asdict(link.attributes).items()
-            )
+            link_attributes = tuple(value for _, value in link.attributes.model_dump().items())
             attrs = (cat_num,) + link_attributes
             db_info["link"].append(attrs)
             cat_num += 1

src/itzi/simulation.py

Lines changed: 6 additions & 1 deletion
@@ -413,7 +413,12 @@ def get_continuity_data(self) -> ContinuityData:
         else:
             continuity_error = volume_error / volume_change
 
-        return ContinuityData(new_domain_vol, volume_change, volume_error, continuity_error)
+        return ContinuityData(
+            new_domain_vol=new_domain_vol,
+            volume_change=volume_change,
+            volume_error=volume_error,
+            continuity_error=continuity_error,
+        )
 
     def _update_accum_array(self, k: str, sim_time: datetime) -> None:
         """Update the accumulation arrays."""

tests/grass_tests/test_itzi.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 
 import os
 from io import StringIO
-import dataclasses
 
 import pandas as pd
 import numpy as np
@@ -123,7 +122,7 @@ def test_stats_file(test_data_temp_path):
     assert os.path.exists(stats_path)
     df = pd.read_csv(stats_path)
 
-    expected_cols = [f.name for f in dataclasses.fields(MassBalanceData)]
+    expected_cols = list(MassBalanceData.model_fields.keys())
     assert df.columns.to_list() == expected_cols
 
     # Domain area in m2

tests/grass_tests/test_tutorial.py

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,6 @@
 import pathlib
 import tempfile
 from configparser import ConfigParser
-from dataclasses import fields
 
 import pytest
 import numpy as np
@@ -272,7 +271,7 @@ def test_tutorial_drainage(itzi_tutorial, test_data_path, test_data_temp_path, h
     assert len(link_entries) == 2
     # link DB columns are as expected
     actual_link_columns = v_db_select[0].split("|")
-    expected_link_columns = ["cat"] + [field.name for field in fields(DrainageLinkAttributes)]
+    expected_link_columns = ["cat"] + list(DrainageLinkAttributes.model_fields.keys())
     assert expected_link_columns == actual_link_columns
 
     # Check nodes DB table
@@ -282,5 +281,5 @@ def test_tutorial_drainage(itzi_tutorial, test_data_path, test_data_temp_path, h
     assert len(nodes_entries) == 3
     # node DB columns are as expected
     actual_nodes_columns = v_db_select[0].split("|")
-    expected_nodes_columns = ["cat"] + [field.name for field in fields(DrainageNodeAttributes)]
+    expected_nodes_columns = ["cat"] + list(DrainageNodeAttributes.model_fields.keys())
     assert expected_nodes_columns == actual_nodes_columns
