|
5 | 5 |
|
6 | 6 | import os # noqa: I001 |
7 | 7 | from enum import Enum |
8 | | -from typing import Optional |
| 8 | +from typing import Any, Optional |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import pandas as pd |
@@ -93,6 +93,26 @@ def _create_cf_time_axis(data: xr.DataArray, timestep_str: str) -> tuple[np.ndar |
93 | 93 | units = f"{timestep_str} since {formatted_string_with_tz}" |
94 | 94 | return axis, units, calendar |
95 | 95 |
|
def _validate_station_id_for_int32(station_id: np.ndarray, intdata_type: str) -> None:
    """Validate that station_id values can be safely stored as int32.

    This is a deliberate guard against possibly silent data corruption when
    station identifiers are written to an int32 variable, as observed in
    https://github.com/csiro-hydroinformatics/efts-io/issues/17

    The check is only performed when ``intdata_type`` is ``"i4"``; any other
    value skips validation entirely. An empty ``station_id`` contains nothing
    to validate and is accepted.

    Args:
        station_id: Array of station ID values to validate
        intdata_type: The intended integer data type (e.g., 'i4' for int32)

    Raises:
        TypeError: If station_id values are not integers
        OverflowError: If station_id values are outside the int32 range
    """
    if intdata_type != "i4":
        return

    # np.max/np.min raise an unrelated ValueError on zero-size input, so an
    # empty collection of identifiers is handled explicitly.
    if np.size(station_id) == 0:
        return

    # Only the extremes need inspecting: if both fit, every value in between fits.
    max_station_id = np.max(station_id)
    min_station_id = np.min(station_id)

    # The type test is done on the reduced scalars (not on the container dtype)
    # so that inputs holding plain Python integers are still accepted.
    if not np.issubdtype(type(max_station_id), np.integer) or not np.issubdtype(type(min_station_id), np.integer):
        raise TypeError("station_id values must be integers to be stored in STF2.0 format.")

    int32_info = np.iinfo(np.int32)
    if max_station_id > int32_info.max or min_station_id < int32_info.min:
        raise OverflowError(
            f"station_id values must be in the int32 range [{int32_info.min}, {int32_info.max}] to be stored in STF2.0 format.",
        )
96 | 116 |
|
97 | 117 | def write_nc_stf2( |
98 | 118 | out_nc_file: str, |
@@ -242,18 +262,7 @@ def _check_optional_var_attr(dataset: xr.Dataset, var_id: str) -> None: |
242 | 262 |
|
243 | 263 | # station_id |
244 | 264 |
|
245 | | - # we check that station_id can be safely stored as int32 |
246 | | - # I add this deliberately as a check to avoid possibly silent data corruption as observed in |
247 | | - # https://github.com/csiro-hydroinformatics/efts-io/issues/17 |
248 | | - if intdata_type == "i4": |
249 | | - max_station_id = np.max(station_id) |
250 | | - min_station_id = np.min(station_id) |
251 | | - if not np.issubdtype(type(max_station_id), np.integer) or not np.issubdtype(type(min_station_id), np.integer): |
252 | | - raise TypeError("station_id values must be integers to be stored in STF2.0 format.") |
253 | | - if max_station_id > np.iinfo(np.int32).max or min_station_id < np.iinfo(np.int32).min: |
254 | | - raise OverflowError( |
255 | | - f"station_id values must be in the int32 range [{np.iinfo(np.int32).min}, {np.iinfo(np.int32).max}] to be stored in STF2.0 format.", |
256 | | - ) |
| 265 | + _validate_station_id_for_int32(station_id, intdata_type) |
257 | 266 |
|
258 | 267 | station_id_var = ncfile.createVariable(STATION_ID_VARNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999) |
259 | 268 | station_id_var.setncattr(LONG_NAME_ATTR_KEY, "station or node identification code") |
@@ -370,14 +379,7 @@ def add_optional_variables(data: xr.DataArray, ncfile: Dataset, var_id: str) -> |
370 | 379 | d_type[0] = "der" |
371 | 380 | d_type_long[0] = "derived (from observations)" |
372 | 381 |
|
373 | | - if int(stf_nc_vers) == 1: |
374 | | - d_type[1] = "fcast" |
375 | | - d_type_long[1] = "forecast" |
376 | | - elif int(stf_nc_vers) == 2: # noqa: PLR2004 |
377 | | - d_type[1] = "fct" |
378 | | - d_type_long[1] = "forecast" |
379 | | - else: |
380 | | - raise ValueError("Version not recognised: Currently only version 1.X or 2.X are supported") |
| 382 | + _get_stationid_data_types(stf_nc_vers, d_type, d_type_long) |
381 | 383 |
|
382 | 384 | d_type[2] = "obs" |
383 | 385 | d_type_long[2] = "observed" |
@@ -479,6 +481,16 @@ def add_optional_variables(data: xr.DataArray, ncfile: Dataset, var_id: str) -> |
479 | 481 | # This prevents double-close in the exception handler |
480 | 482 | ncfile.close() |
481 | 483 |
|
def _get_stationid_data_types(stf_nc_vers: Any, d_type: np.ndarray, d_type_long: np.ndarray) -> None:
    """Write the forecast entry (index 1) of the data-type label arrays in place.

    The version is reduced to an integer with ``int()``. Version 1 uses the
    short label ``"fcast"`` and version 2 uses ``"fct"``; in both cases the
    long label is ``"forecast"``.

    Args:
        stf_nc_vers: STF version indicator; anything ``int()`` accepts
        d_type: Array of short labels; element 1 is overwritten
        d_type_long: Array of long labels; element 1 is overwritten

    Raises:
        ValueError: If the integer version is neither 1 nor 2
    """
    short_label_by_version = {1: "fcast", 2: "fct"}
    major_version = int(stf_nc_vers)

    # Reject unknown versions before touching either output array.
    if major_version not in short_label_by_version:
        raise ValueError("Version not recognised: Currently only version 1.X or 2.X are supported")

    d_type[1] = short_label_by_version[major_version]
    d_type_long[1] = "forecast"
| 493 | + |
482 | 494 |
|
483 | 495 | def make_ready_for_saving(data: xr.DataArray, dataset: xr.Dataset, dimensions_order: tuple) -> xr.DataArray: |
484 | 496 | """Transform an xarray DataArray to ensure it has all required dimensions in the correct order for saving to NetCDF. |
|
0 commit comments