Skip to content

Commit 60d92b4

Browse files
authored
Add in-memory output providers (ItziModel#129)
* move DrainageNodeCouplingData to data_containers * get_node_list() no longer dependent on grass_interface * analytic tests now use in-memory results and use generated arrays as input
1 parent 7f26476 commit 60d92b4

File tree

21 files changed

+440
-630
lines changed

21 files changed

+440
-630
lines changed

src/itzi/data_containers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,29 @@
1313
GNU General Public License for more details.
1414
"""
1515

16-
from typing import Dict, Tuple
16+
from typing import Dict, Tuple, TYPE_CHECKING
1717
import dataclasses
1818
from datetime import datetime
1919

2020
import numpy as np
2121

22+
if TYPE_CHECKING:
23+
from itzi.drainage import DrainageNode
24+
25+
26+
@dataclasses.dataclass(frozen=True)
27+
class DrainageNodeCouplingData:
28+
"""Store the translation between coordinates and array location for a given drainage node."""
29+
30+
node_id: str # Name of the drainage node
31+
node_object: "DrainageNode"
32+
# Location in the coordinate system
33+
x: float | None
34+
y: float | None
35+
# Location in the array
36+
row: int | None
37+
col: int | None
38+
2239

2340
@dataclasses.dataclass(frozen=True)
2441
class DrainageAttributes:

src/itzi/itzi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def initialize(self, conf_file):
110110
raster_mask_id=self.conf.grass_params["mask"],
111111
)
112112
# Instantiate Simulation object and initialize it
113-
from itzi.simulation_factories import create_simulation
113+
from itzi.simulation_factories import create_grass_simulation
114114

115-
self.sim, self.tarr = create_simulation(
115+
self.sim, self.tarr = create_grass_simulation(
116116
sim_times=self.conf.sim_times,
117117
stats_file=self.conf.stats_file,
118118
input_maps=self.conf.input_map_names,

src/itzi/providers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
from itzi.providers.grass_output import GrassRasterOutputProvider as GrassRasterOutputProvider
2-
from itzi.providers.grass_output import GrassVectorOutputProvider as GrassVectorOutputProvider

src/itzi/providers/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ def initialize(self, config: Dict) -> Self:
3030
pass
3131

3232
@abstractmethod
33-
def write_array(
34-
self, array: np.ndarray, map_name: str, map_key: str, sim_time: datetime | timedelta
35-
) -> None:
33+
def write_array(self, array: np.ndarray, map_key: str, sim_time: datetime | timedelta) -> None:
3634
"""Write simulation data for current time step."""
3735
pass
3836

@@ -58,6 +56,6 @@ def write_vector(
5856
pass
5957

6058
@abstractmethod
61-
def finalize(self) -> None:
59+
def finalize(self, drainage_data: DrainageNetworkData) -> None:
6260
"""Finalize outputs and cleanup."""
6361
pass
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# coding=utf8
2+
"""
3+
Copyright (C) 2025 Laurent Courty
4+
5+
This program is free software; you can redistribute it and/or
6+
modify it under the terms of the GNU General Public License
7+
as published by the Free Software Foundation; either version 2
8+
of the License, or (at your option) any later version.
9+
10+
This program is distributed in the hope that it will be useful,
11+
but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
GNU General Public License for more details.
14+
"""
15+
16+
from datetime import datetime, timedelta
17+
from typing import Dict, Self
18+
from copy import deepcopy
19+
20+
import numpy as np
21+
22+
from itzi.providers.base import RasterOutputProvider, VectorOutputProvider
23+
from itzi.data_containers import SimulationData, DrainageNetworkData
24+
25+
26+
class MemoryRasterOutputProvider(RasterOutputProvider):
27+
"""Save rasters in memory as numpy arrays."""
28+
29+
def initialize(self, config: Dict) -> Self:
30+
"""Initialize output provider with simulation configuration."""
31+
self.out_map_names = config["out_map_names"]
32+
33+
self.output_maps_dict = {k: [] for k in self.out_map_names.keys()}
34+
return self
35+
36+
def write_array(self, array: np.ndarray, map_key: str, sim_time: datetime | timedelta) -> None:
37+
"""Save simulation data for current time step."""
38+
self.output_maps_dict[map_key].append((deepcopy(sim_time), array.copy()))
39+
40+
def finalize(self, final_data: SimulationData) -> None:
41+
"""Finalize outputs and cleanup."""
42+
pass
43+
44+
45+
class MemoryVectorOutputProvider(VectorOutputProvider):
46+
"""Save drainage simulation outputs in memory."""
47+
48+
def initialize(self, config: Dict | None = None) -> Self:
49+
"""Initialize output provider with simulation configuration."""
50+
self.drainage_data = []
51+
52+
return self
53+
54+
def write_vector(
55+
self, drainage_data: DrainageNetworkData, sim_time: datetime | timedelta
56+
) -> None:
57+
"""Save simulation data for current time step."""
58+
self.drainage_data.append((deepcopy(sim_time), deepcopy(drainage_data)))
59+
60+
def finalize(self, drainage_data: DrainageNetworkData) -> None:
61+
"""Finalize outputs and cleanup."""
62+
pass

src/itzi/rasterdomain.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,57 @@
1414
"""
1515

1616
from datetime import datetime
17+
from typing import Tuple
1718
import numpy as np
1819

1920
import itzi.flow as flow
20-
from itzi import rastermetrics
21+
22+
23+
class DomainData:
24+
"""Store raster domain information. Alike GRASS region."""
25+
26+
def __init__(self, north: float, south: float, east: float, west: float, rows: int, cols: int):
27+
self.north = north
28+
self.south = south
29+
self.east = east
30+
self.west = west
31+
self.rows = rows
32+
self.cols = cols
33+
34+
if self.north < self.south:
35+
raise ValueError(f"north must be superior to south. {self.north=}, {self.south=}")
36+
if self.east < self.west:
37+
raise ValueError(f"east must be superior to west. {self.east=}, {self.west=}")
38+
39+
self.nsres = (self.north - self.south) / self.rows
40+
self.ewres = (self.east - self.west) / self.cols
41+
self.cell_area = self.ewres * self.nsres
42+
self.cell_shape = (self.ewres, self.nsres)
43+
self.shape = (self.rows, self.cols)
44+
self.cells = self.rows * self.cols
45+
46+
def is_in_domain(self, *, x: float, y: float) -> bool:
47+
"""For a given coordinate pair(x, y),
48+
return True is inside the domain, False otherwise.
49+
"""
50+
bool_x = self.west < x < self.east
51+
bool_y = self.south < y < self.north
52+
return bool(bool_x and bool_y)
53+
54+
def coordinates_to_pixel(self, *, x: float, y: float) -> Tuple[float, float] | None:
55+
"""For a given coordinate pair(x, y),
56+
return True is inside the domain, False otherwise.
57+
"""
58+
if not self.is_in_domain(x=x, y=y):
59+
return None
60+
else:
61+
norm_row = (y - self.south) / (self.north - self.south)
62+
row = int(np.round((1 - norm_row) * (self.rows - 1)))
63+
64+
norm_col = (x - self.west) / (self.east - self.west)
65+
col = int(np.round(norm_col * (self.cols - 1)))
66+
67+
return (row, col)
2168

2269

2370
class TimedArray:

src/itzi/simulation.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,23 @@
1414
"""
1515

1616
from datetime import datetime, timedelta
17-
from typing import Self
17+
from typing import Self, Union, TYPE_CHECKING
1818
import copy
1919

2020
import numpy as np
2121

22-
from itzi.surfaceflow import SurfaceFlowSimulation
23-
import itzi.rasterdomain as rasterdomain
24-
from itzi.report import Report
2522
from itzi.data_containers import ContinuityData, SimulationData
26-
from itzi.drainage import DrainageSimulation
2723
import itzi.messenger as msgr
28-
from itzi.hydrology import Hydrology
2924
from itzi.itzi_error import NullError, MassBalanceError
3025
from itzi import rastermetrics
3126

27+
if TYPE_CHECKING:
28+
from itzi.drainage import DrainageSimulation
29+
from itzi.hydrology import Hydrology
30+
from itzi.surfaceflow import SurfaceFlowSimulation
31+
from itzi.rasterdomain import RasterDomain
32+
from itzi.report import Report
33+
3234

3335
class Simulation:
3436
""" """
@@ -72,12 +74,12 @@ def __init__(
7274
self,
7375
start_time: datetime,
7476
end_time: datetime,
75-
raster_domain: rasterdomain.RasterDomain,
76-
hydrology_model: Hydrology,
77-
surface_flow: SurfaceFlowSimulation,
78-
drainage_model: DrainageSimulation | None,
77+
raster_domain: "RasterDomain",
78+
hydrology_model: "Hydrology",
79+
surface_flow: "SurfaceFlowSimulation",
80+
drainage_model: Union["DrainageSimulation", None],
7981
nodes_list: list | None,
80-
report: Report,
82+
report: "Report",
8183
mass_balance_error_threshold: float,
8284
):
8385
self.raster_domain = raster_domain
@@ -123,7 +125,7 @@ def __init__(
123125
self.next_ts["drainage"] = self.end_time
124126
else:
125127
self.node_id_to_loc = {
126-
n.id: (n.row, n.col) for n in self.nodes_list if n.object.is_coupled()
128+
n.node_id: (n.row, n.col) for n in self.nodes_list if n.node_object.is_coupled()
127129
}
128130
# Grid spacing (for BMI)
129131
self.spacing = (self.raster_domain.dy, self.raster_domain.dx)

0 commit comments

Comments
 (0)