Skip to content

Commit b3e4c5a

Browse files
authored
allow CSV write to receive None. Update type hint accordingly (#207)
1 parent 6502754 commit b3e4c5a

File tree

3 files changed

+21
-14
lines changed

3 files changed

+21
-14
lines changed

src/itzi/providers/base.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
GNU General Public License for more details.
1313
"""
1414

15+
from __future__ import annotations
16+
1517
from abc import ABC, abstractmethod
1618
from typing import Mapping, TYPE_CHECKING, Union, Tuple
1719

@@ -36,14 +38,14 @@ def get_origin(self) -> Tuple[float, float]:
3638
return (domain_data.north, domain_data.west)
3739

3840
@abstractmethod
39-
def get_domain_data(self) -> "DomainData":
41+
def get_domain_data(self) -> DomainData:
4042
"""Return a DomainData object."""
4143
pass
4244

4345
@abstractmethod
4446
def get_array(
45-
self, map_key: str, current_time: "datetime"
46-
) -> Tuple["np.ndarray", "datetime", "datetime"]:
47+
self, map_key: str, current_time: datetime
48+
) -> Tuple[np.ndarray, datetime, datetime]:
4749
"""Take a given map key and current time
4850
return a numpy array associated with its start and end time
4951
if no map is found, return None instead of an array
@@ -61,13 +63,13 @@ def __init__(self, config: Mapping) -> None:
6163

6264
@abstractmethod
6365
def write_arrays(
64-
self, array_dict: Mapping[str, "np.ndarray"], sim_time: Union["datetime", "timedelta"]
66+
self, array_dict: Mapping[str, np.ndarray], sim_time: Union[datetime, timedelta]
6567
) -> None:
6668
"""Write all arrays for the current time step."""
6769
pass
6870

6971
@abstractmethod
70-
def finalize(self, final_data: "SimulationData") -> None:
72+
def finalize(self, final_data: SimulationData) -> None:
7173
"""Finalize outputs and cleanup."""
7274
pass
7375

@@ -82,12 +84,12 @@ def __init__(self, config: Mapping) -> None:
8284

8385
@abstractmethod
8486
def write_vector(
85-
self, drainage_data: "DrainageNetworkData", sim_time: Union["datetime", "timedelta"]
87+
self, drainage_data: DrainageNetworkData | None, sim_time: Union[datetime, timedelta]
8688
) -> None:
8789
"""Write simulation data for current time step."""
8890
pass
8991

9092
@abstractmethod
91-
def finalize(self, drainage_data: "DrainageNetworkData") -> None:
93+
def finalize(self, drainage_data: DrainageNetworkData | None) -> None:
9294
"""Finalize outputs and cleanup."""
9395
pass

src/itzi/providers/csv_output.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
from __future__ import annotations
16+
1617
from datetime import datetime, timedelta
1718
from typing import TypedDict, TYPE_CHECKING, Tuple, List
1819
from io import StringIO
@@ -88,9 +89,11 @@ def __init__(self, config: CSVVectorOutputConfig) -> None:
8889
print(self.existing_max_time)
8990

9091
def write_vector(
91-
self, drainage_data: DrainageNetworkData, sim_time: datetime | timedelta
92+
self, drainage_data: DrainageNetworkData | None, sim_time: datetime | timedelta
9293
) -> None:
9394
"""Save simulation data for current time step."""
95+
if not drainage_data:
96+
return
9497
# Validate time on first write
9598
self._validate_time_on_first_write(sim_time)
9699
# Convert sim_time to ISO8601 format
@@ -106,7 +109,7 @@ def write_vector(
106109
self._update_csv(sim_time_str, "link", drainage_data.links)
107110
self.number_of_writes["link"] += 1
108111

109-
def finalize(self, drainage_data: DrainageNetworkData) -> None:
112+
def finalize(self, drainage_data: DrainageNetworkData | None) -> None:
110113
"""Finalize outputs and cleanup."""
111114
pass
112115

src/itzi/providers/grass_output.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
GNU General Public License for more details.
1313
"""
1414

15+
from __future__ import annotations
16+
1517
from typing import Dict, TypedDict, Union, TYPE_CHECKING
1618

1719
import numpy as np
@@ -51,7 +53,7 @@ def __init__(self, config: GrassRasterOutputConfig) -> None:
5153
self.output_maplist = {k: [] for k in self.out_map_names.keys()}
5254

5355
def _write_array(
54-
self, array: np.ndarray, map_key: str, sim_time: Union["datetime", "timedelta"]
56+
self, array: np.ndarray, map_key: str, sim_time: Union[datetime, timedelta]
5557
) -> None:
5658
"""Write simulation data for current time step."""
5759
suffix = str(self.record_counter[map_key]).zfill(4)
@@ -66,7 +68,7 @@ def _write_array(
6668
self.record_counter[map_key] += 1
6769

6870
def write_arrays(
69-
self, array_dict: Dict[str, np.ndarray], sim_time: Union["datetime", "timedelta"]
71+
self, array_dict: Dict[str, np.ndarray], sim_time: Union[datetime, timedelta]
7072
) -> None:
7173
for arr_key, arr in array_dict.items():
7274
if isinstance(arr, np.ndarray):
@@ -76,7 +78,7 @@ def _write_max_array(self, arr_max, map_key):
7678
map_max_name = f"{self.out_map_names[map_key]}_max"
7779
self.grass_interface.write_raster_map(arr_max, map_max_name, map_key, hmin=0.0)
7880

79-
def finalize(self, final_data: "SimulationData") -> None:
81+
def finalize(self, final_data: SimulationData) -> None:
8082
"""Finalize outputs and cleanup."""
8183

8284
# Write the final raster maps
@@ -109,7 +111,7 @@ def __init__(self, config: GrassVectorOutputConfig) -> None:
109111
self.vector_drainage_maplist = []
110112

111113
def write_vector(
112-
self, drainage_data: "DrainageNetworkData", sim_time: Union["datetime", "timedelta"]
114+
self, drainage_data: DrainageNetworkData | None, sim_time: Union[datetime, timedelta]
113115
) -> None:
114116
"""Write drainage simulation data for current time step."""
115117
if self.drainage_map_name and drainage_data:
@@ -122,7 +124,7 @@ def write_vector(
122124
self.vector_drainage_maplist.append((map_name, sim_time))
123125
self.record_counter += 1
124126

125-
def finalize(self, drainage_data: "DrainageNetworkData") -> None:
127+
def finalize(self, drainage_data: DrainageNetworkData | None) -> None:
126128
"""Finalize outputs and cleanup."""
127129
if self.drainage_map_name and drainage_data:
128130
self.grass_interface.register_maps_in_stds(

0 commit comments

Comments
 (0)