Skip to content

Commit 1ed4cba

Browse files
committed
ci: refactor to pass quality checks
1 parent bb0fd0d commit 1ed4cba

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

src/efts_io/_ncdf_stf2.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import os # noqa: I001
77
from enum import Enum
8-
from typing import Optional
8+
from typing import Any, Optional
99

1010
import numpy as np
1111
import pandas as pd
@@ -93,6 +93,26 @@ def _create_cf_time_axis(data: xr.DataArray, timestep_str: str) -> tuple[np.ndar
9393
units = f"{timestep_str} since {formatted_string_with_tz}"
9494
return axis, units, calendar
9595

96+
def _validate_station_id_for_int32(station_id: np.ndarray, intdata_type: str) -> None:
    """Check that every station identifier fits in a signed 32-bit integer.

    The check only applies when the target NetCDF integer type is ``"i4"``;
    for any other ``intdata_type`` the function returns without inspecting
    ``station_id``. It exists to turn a potentially silent corruption of
    station identifiers on write into an explicit error.

    Args:
        station_id: Array of station ID values to validate.
        intdata_type: The intended integer data type code (``"i4"`` means int32).

    Raises:
        TypeError: If the extreme values of ``station_id`` are not integers.
        OverflowError: If any value lies outside the int32 range.
    """
    if intdata_type != "i4":
        return

    # Only the two extremes need inspecting: if they fit, everything fits.
    highest = np.max(station_id)
    lowest = np.min(station_id)

    extremes_are_integral = all(np.issubdtype(type(bound), np.integer) for bound in (highest, lowest))
    if not extremes_are_integral:
        raise TypeError("station_id values must be integers to be stored in STF2.0 format.")

    int32_limits = np.iinfo(np.int32)
    if highest > int32_limits.max or lowest < int32_limits.min:
        raise OverflowError(
            f"station_id values must be in the int32 range [{int32_limits.min}, {int32_limits.max}] to be stored in STF2.0 format.",
        )
96116

97117
def write_nc_stf2(
98118
out_nc_file: str,
@@ -242,18 +262,7 @@ def _check_optional_var_attr(dataset: xr.Dataset, var_id: str) -> None:
242262

243263
# station_id
244264

245-
# we check that station_id can be safely stored as int32
246-
# I add this deliberately as a check to avoid possibly silent data corruption as observed in
247-
# https://github.com/csiro-hydroinformatics/efts-io/issues/17
248-
if intdata_type == "i4":
249-
max_station_id = np.max(station_id)
250-
min_station_id = np.min(station_id)
251-
if not np.issubdtype(type(max_station_id), np.integer) or not np.issubdtype(type(min_station_id), np.integer):
252-
raise TypeError("station_id values must be integers to be stored in STF2.0 format.")
253-
if max_station_id > np.iinfo(np.int32).max or min_station_id < np.iinfo(np.int32).min:
254-
raise OverflowError(
255-
f"station_id values must be in the int32 range [{np.iinfo(np.int32).min}, {np.iinfo(np.int32).max}] to be stored in STF2.0 format.",
256-
)
265+
_validate_station_id_for_int32(station_id, intdata_type)
257266

258267
station_id_var = ncfile.createVariable(STATION_ID_VARNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999)
259268
station_id_var.setncattr(LONG_NAME_ATTR_KEY, "station or node identification code")
@@ -370,14 +379,7 @@ def add_optional_variables(data: xr.DataArray, ncfile: Dataset, var_id: str) ->
370379
d_type[0] = "der"
371380
d_type_long[0] = "derived (from observations)"
372381

373-
if int(stf_nc_vers) == 1:
374-
d_type[1] = "fcast"
375-
d_type_long[1] = "forecast"
376-
elif int(stf_nc_vers) == 2: # noqa: PLR2004
377-
d_type[1] = "fct"
378-
d_type_long[1] = "forecast"
379-
else:
380-
raise ValueError("Version not recognised: Currently only version 1.X or 2.X are supported")
382+
_get_stationid_data_types(stf_nc_vers, d_type, d_type_long)
381383

382384
d_type[2] = "obs"
383385
d_type_long[2] = "observed"
@@ -479,6 +481,16 @@ def add_optional_variables(data: xr.DataArray, ncfile: Dataset, var_id: str) ->
479481
# This prevents double-close in the exception handler
480482
ncfile.close()
481483

484+
def _get_stationid_data_types(stf_nc_vers: Any, d_type: np.ndarray, d_type_long: np.ndarray) -> None:
    """Fill in the forecast entry of the data-type code tables, in place.

    Index 1 of ``d_type`` receives the short forecast code for the given STF
    version (``"fcast"`` for version 1, ``"fct"`` for version 2), and index 1
    of ``d_type_long`` receives the long name ``"forecast"``. No other element
    is touched.

    Args:
        stf_nc_vers: STF NetCDF version; anything ``int()`` accepts. Only the
            integer part is used, so ``2.0`` is treated as version 2.
        d_type: Indexable container of short data-type codes, mutated in place.
        d_type_long: Indexable container of long data-type names, mutated in place.

    Raises:
        ValueError: If the integer version is neither 1 nor 2.
    """
    short_forecast_codes = {1: "fcast", 2: "fct"}
    major_version = int(stf_nc_vers)

    # Reject unknown versions before mutating either container.
    if major_version not in short_forecast_codes:
        raise ValueError("Version not recognised: Currently only version 1.X or 2.X are supported")

    d_type[1] = short_forecast_codes[major_version]
    d_type_long[1] = "forecast"
493+
482494

483495
def make_ready_for_saving(data: xr.DataArray, dataset: xr.Dataset, dimensions_order: tuple) -> xr.DataArray:
484496
"""Transform an xarray DataArray to ensure it has all required dimensions in the correct order for saving to NetCDF.

0 commit comments

Comments
 (0)