|
| 1 | +""" |
| 2 | +Copyright (C) 2025 Laurent Courty |
| 3 | +
|
| 4 | +This program is free software; you can redistribute it and/or |
| 5 | +modify it under the terms of the GNU General Public License |
| 6 | +as published by the Free Software Foundation; either version 2 |
| 7 | +of the License, or (at your option) any later version. |
| 8 | +
|
| 9 | +This program is distributed in the hope that it will be useful, |
| 10 | +but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | +GNU General Public License for more details. |
| 13 | +""" |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | +from datetime import datetime, timedelta |
| 17 | +from typing import TypedDict, TYPE_CHECKING, Tuple, List |
| 18 | +from io import StringIO |
| 19 | +import csv |
| 20 | +import dataclasses |
| 21 | + |
| 22 | +import pandas as pd |
| 23 | + |
| 24 | +from itzi.providers.base import VectorOutputProvider |
| 25 | +from itzi.data_containers import DrainageLinkData, DrainageLinkAttributes |
| 26 | +from itzi.data_containers import DrainageNodeData, DrainageNodeAttributes |
| 27 | + |
| 28 | +if TYPE_CHECKING: |
| 29 | + from itzi.data_containers import DrainageNetworkData |
| 30 | + |
| 31 | +try: |
| 32 | + import obstore |
| 33 | + import pyproj |
| 34 | +except ImportError: |
| 35 | + raise ImportError( |
| 36 | + "To use the CSV backend, install itzi with: " |
| 37 | + "'uv tool install itzi[cloud]' " |
| 38 | + "or 'pip install itzi[cloud]'" |
| 39 | + ) |
| 40 | + |
| 41 | + |
class CSVVectorOutputConfig(TypedDict):
    """Configuration mapping consumed by CSVVectorOutputProvider."""

    # Coordinate reference system of the output geometries; None if unknown.
    crs: pyproj.CRS | None
    # Object store (cloud or local) where the CSV result files are written.
    store: obstore.store.ObjectStore
    # Path prefix inside the store under which the result files are placed.
    results_prefix: str
    # Base name of the result files; "_nodes.csv" / "_links.csv" is appended.
    drainage_results_name: str
    # If True, existing result files are replaced instead of appended to.
    overwrite: bool
| 48 | + |
| 49 | + |
class CSVVectorOutputProvider(VectorOutputProvider):
    """Save drainage simulation outputs in CSV files hosted on a cloud object storage.

    Write two files:
      - one for nodes, with suffix *_nodes.csv
      - one for links, with suffix *_links.csv
    If a file already exists at the prefix and overwrite is False, the results are
    appended if possible (headers must match; object ids and sim_time type are
    validated on the first write).
    """

    def __init__(self, config: CSVVectorOutputConfig) -> None:
        """Initialize output provider with provider configuration."""
        try:
            # CRS.to_epsg() may return None when no EPSG code matches;
            # fall back to 0 (unknown SRID) in that case too.
            self.srid = config["crs"].to_epsg() or 0
        except AttributeError:
            # config["crs"] is None
            self.srid = 0
        self.store = config["store"]
        results_prefix = config["results_prefix"]

        # Per-geometry bookkeeping, keyed by "link" / "node".
        self.existing_ids = {"link": None, "node": None}  # Object ids already in the file
        self.existing_max_time = {"link": None, "node": None}  # Max of sim_time in existing file
        self.number_of_writes = {"link": 0, "node": 0}
        self.file_paths = {"link": None, "node": None}
        self.headers = {"link": None, "node": None}
        self.append_mode = {"link": True, "node": True}
        if config["overwrite"]:
            self.append_mode = {"link": False, "node": False}

        for geom_type, obj in [("node", DrainageNodeAttributes), ("link", DrainageLinkAttributes)]:
            # CSV columns: sim_time, the dataclass attributes, then srid and geometry (WKT).
            base_headers = [field.name for field in dataclasses.fields(obj)]
            self.headers[geom_type] = ["sim_time"] + list(base_headers) + ["srid", "geometry"]

            results_name = f"{config['drainage_results_name']}_{geom_type}s.csv"
            self.file_paths[geom_type] = results_prefix + "/" + results_name
            # No need to check if we overwrite
            if not config["overwrite"]:
                self._check_existing_csv(geom_type)
            # Create the CSV files (headers only) when not appending to an existing one.
            if not self.append_mode[geom_type]:
                self._write_headers(geom_type)

    def write_vector(
        self, drainage_data: DrainageNetworkData, sim_time: datetime | timedelta
    ) -> None:
        """Save simulation data for current time step.

        sim_time is serialized to ISO8601: datetimes with isoformat(),
        timedeltas as a duration string "PT{seconds}S".
        """
        # Validate time on first write
        self._validate_time_on_first_write(sim_time)
        # Convert sim_time to ISO8601 format
        if isinstance(sim_time, timedelta):
            # ISO8601 duration format: PT{seconds}S
            sim_time_str = f"PT{sim_time.total_seconds()}S"
        else:
            sim_time_str = sim_time.isoformat()
        # Nodes
        self._update_csv(sim_time_str, "node", drainage_data.nodes)
        self.number_of_writes["node"] += 1
        # Links
        self._update_csv(sim_time_str, "link", drainage_data.links)
        self.number_of_writes["link"] += 1

    def finalize(self, drainage_data: DrainageNetworkData) -> None:
        """Finalize outputs and cleanup. Nothing to do for the CSV backend:
        every write_vector() call already persists to the object store."""
        pass

    def _check_existing_csv(self, geom_type: str) -> None:
        """Check that an existing CSV file is compatible with appending.

        In order to be compatible, an existing CSV should have:
        - Same headers
        Other compatibility issues, like:
        - Same sim_time type
        - new sim_time < existing
        - new object ID ≠ existing ones
        could not be checked without drainage network data; the existing ids and
        max sim_time are stored here and validated on the first write.
        """
        try:
            existing_csv = StringIO(
                bytes(obstore.get(self.store, self.file_paths[geom_type]).bytes()).decode("utf-8")
            )
        except FileNotFoundError:
            # No existing file: a new one will be created with headers.
            self.append_mode[geom_type] = False
            return
        df_csv = pd.read_csv(existing_csv)
        existing_headers = list(df_csv.columns)
        expected_headers = self.headers[geom_type]
        if existing_headers != expected_headers:
            raise ValueError(f"Headers mismatch in existing file {self.file_paths[geom_type]}.")
        # A compatible file exists: keep append mode so headers are NOT rewritten
        # (rewriting them would truncate the existing results).
        if df_csv.empty:
            # Headers-only file: nothing to validate, simply append.
            return
        id_col = f"{geom_type}_id"
        # Store existing ids
        self.existing_ids[geom_type] = set(df_csv[id_col])
        # Store maximum sim_time value. Convert the whole column before taking
        # the max: a lexicographic max() on the raw strings would be wrong for
        # durations (e.g. "PT9.0S" > "PT10.0S").
        try:
            self.existing_max_time[geom_type] = (
                pd.to_timedelta(df_csv["sim_time"]).max().to_pytimedelta()
            )
        except ValueError:
            try:
                self.existing_max_time[geom_type] = (
                    pd.to_datetime(df_csv["sim_time"]).max().to_pydatetime()
                )
            except ValueError:
                raise ValueError(
                    f"Unknown sim_time column in existing file {self.file_paths[geom_type]}."
                )

    def _write_headers(self, geom_type: str) -> None:
        """Create an in-memory CSV file with headers only and save it in the store,
        replacing any existing object at that path."""
        f_obj = StringIO()
        writer = csv.writer(f_obj)
        writer.writerow(self.headers[geom_type])
        csv_content = f_obj.getvalue()
        obstore.put(self.store, self.file_paths[geom_type], file=csv_content.encode("utf-8"))

    def _validate_time_on_first_write(self, sim_time: datetime | timedelta) -> None:
        """Validate sim_time type matches existing files on first write.

        Raises ValueError if the time type differs from the existing file's
        (relative vs absolute) or if sim_time does not increase past the
        existing maximum.
        """
        for geom_type in ["node", "link"]:
            # Only validate on first write, and only when an existing file was read.
            if self.number_of_writes[geom_type] > 0 or self.existing_max_time[geom_type] is None:
                continue
            # Type must match
            if type(self.existing_max_time[geom_type]) is not type(sim_time):
                time_type_name = (
                    "relative (timedelta)"
                    if isinstance(sim_time, timedelta)
                    else "absolute (datetime)"
                )
                existing_type_name = (
                    "relative"
                    if isinstance(self.existing_max_time[geom_type], timedelta)
                    else "absolute"
                )
                raise ValueError(
                    f"Time type mismatch for {geom_type}: "
                    f"attempting to write {time_type_name} but existing file has {existing_type_name}"
                )
            # Time must increase
            if not sim_time > self.existing_max_time[geom_type]:
                raise ValueError(
                    f"Time not increasing for {geom_type}: attempting to write {sim_time} but "
                    f"existing file has a max sim_time value of {self.existing_max_time[geom_type]}"
                )

    def _update_csv(
        self,
        sim_time_str: str,
        geom_type: str,
        drainage_elements: Tuple[DrainageNodeData | DrainageLinkData, ...],
    ) -> None:
        """Append one row per drainage element to the adequate CSV in the object store.

        NOTE(review): this is a read-modify-write of the whole object; it is not
        atomic, so concurrent writers could lose rows — presumably single-writer
        use is assumed, verify against callers.
        """
        # Check compatibility on first write
        if 0 == self.number_of_writes[geom_type] and self.existing_ids[geom_type]:
            # IDs must match
            new_ids = set(
                dataclasses.asdict(drainage_elem.attributes)[f"{geom_type}_id"]
                for drainage_elem in drainage_elements
            )
            if new_ids != self.existing_ids[geom_type]:
                raise ValueError(
                    f"Object ids mismatch for {geom_type}: "
                    f"attempting to write {new_ids} but existing file has {self.existing_ids[geom_type]}"
                )
        f_obj = StringIO()
        csv_writer = csv.writer(f_obj)
        for drainage_elem in drainage_elements:
            data_line_list = [sim_time_str] + self._attrs_line(drainage_elem)
            csv_writer.writerow(data_line_list)
        new_rows = f_obj.getvalue()
        # Get the file from the store as bytes and decode it
        existing_csv = bytes(obstore.get(self.store, self.file_paths[geom_type]).bytes()).decode(
            "utf-8"
        )
        updated_csv = existing_csv + new_rows
        obstore.put(self.store, self.file_paths[geom_type], file=updated_csv.encode("utf-8"))

    def _attrs_line(self, drainage_element: DrainageNodeData | DrainageLinkData) -> List[str]:
        """Return one CSV row (as a list of strings): the element's attributes
        followed by the SRID and the geometry as WKT ("" when no geometry)."""
        # Convert attributes to list
        attributes = [str(a) for a in dataclasses.asdict(drainage_element.attributes).values()]
        # Create geometry WKT
        if isinstance(drainage_element, DrainageNodeData):
            if drainage_element.coordinates is not None:
                x, y = drainage_element.coordinates
                geom_wkt = f"POINT({x} {y})"
            else:
                geom_wkt = ""
        elif isinstance(drainage_element, DrainageLinkData):
            if drainage_element.vertices is not None and len(drainage_element.vertices) > 0:
                # Filter out None vertices
                valid_vertices = [v for v in drainage_element.vertices if v is not None]
                if len(valid_vertices) >= 2:
                    coords_str = ", ".join([f"{x} {y}" for x, y in valid_vertices])
                    geom_wkt = f"LINESTRING({coords_str})"
                else:
                    # Not enough valid vertices to create a linestring
                    geom_wkt = ""
            else:
                geom_wkt = ""
        else:
            raise RuntimeError(f"Unknown drainage_element: {type(drainage_element)}")
        return attributes + [str(self.srid), geom_wkt]
0 commit comments