Skip to content

Commit 4b5b778

Browse files
authored
Fix 64bit cast for parse_netcdf (#92)
- Fix cast for specific emvorado files - Add pytest
1 parent 31df3e0 commit 4b5b778

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
This module contains unit tests for the `model_output_parser.py` module.
3+
"""
4+
5+
import numpy as np
6+
import xarray as xr
7+
8+
from util.model_output_parser import parse_netcdf
9+
10+
11+
def test_parse_netcdf_only_floats_converted(tmp_path):
    """
    Ensure parse_netcdf converts only float variables to float64 and
    does not attempt to convert string/bytes variables.

    A small dataset holding one float32, one float64, one int32 and one
    byte-string variable is written to a temporary NetCDF file, parsed with
    parse_netcdf, and the dtype of the first column of every returned
    DataFrame is checked.
    """

    ds = xr.Dataset(
        {
            "float32_var": ("t", np.array([1.0, 2.0], dtype=np.float32)),
            "float64_var": ("t", np.array([1.0, 2.0], dtype=np.float64)),
            "int_var": ("t", np.array([1, 2], dtype=np.int32)),
            "str_var": ("t", np.array([b"A", b"B"], dtype="S1")),
        }
    )

    # Save to a temporary NetCDF file
    filename = tmp_path / "test.nc"
    ds.to_netcdf(filename)

    # Specification for parse_netcdf
    specification = {
        "time_dim": "t",
        "horizontal_dims": [],
        "fill_value_key": None,
    }

    var_dfs = parse_netcdf("test_file", str(filename), specification)

    # Check dtypes
    # NOTE(review): pairing by position assumes parse_netcdf returns one
    # DataFrame per variable, in the order listed here -- confirm against
    # parse_netcdf; zip() silently stops at the shorter sequence.
    var_names = ["float32_var", "float64_var", "int_var", "str_var"]
    for name, df in zip(var_names, var_dfs):
        # df.dtypes is a Series indexed by column label, so the first entry
        # has to be fetched positionally via .iloc; an integer key passed to
        # [] is treated as a label and the positional fallback is deprecated.
        dtype = df.dtypes.iloc[0]
        if name.startswith("float"):
            assert dtype == np.float64, f"{name}: expected float64, got {dtype}"
        elif name.startswith("int"):
            assert np.issubdtype(
                dtype, np.integer
            ), f"{name}: expected an integer dtype, got {dtype}"
        elif name.startswith("str"):
            # pandas converts bytes -> object
            assert np.issubdtype(
                dtype, np.object_
            ), f"{name}: expected object dtype, got {dtype}"

util/model_output_parser.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import sys
2727
from collections.abc import Iterable
28+
from typing import Any, Dict, List
2829

2930
import numpy as np
3031
import pandas as pd
@@ -36,15 +37,23 @@
3637
from util.xarray_ops import statistics_over_horizontal_dim
3738

3839

39-
def parse_netcdf(file_id, filename, specification):
40+
def parse_netcdf(
41+
file_id: str, filename: str, specification: Dict[str, Any]
42+
) -> List[pd.DataFrame]:
43+
"""
44+
Parse a NetCDF file into pandas DataFrames.
45+
"""
46+
4047
logger.debug("parse NetCDF file %s", filename)
4148
time_dim = specification["time_dim"]
4249
horizontal_dims = specification["horizontal_dims"]
4350
fill_value_key = specification.get("fill_value_key", None)
4451
ds = xarray.open_dataset(filename, decode_cf=False)
52+
4553
# Convert all float variables to float64
46-
float_vars = [v for v in ds.data_vars if np.issubdtype(ds[v].dtype, np.floating)]
47-
ds[float_vars] = ds[float_vars].astype(np.float64)
54+
for v in ds.data_vars:
55+
if np.issubdtype(ds[v].dtype, np.floating):
56+
ds[v] = ds[v].astype(np.float64)
4857

4958
var_tmp = __get_variables(ds, time_dim, horizontal_dims)
5059

0 commit comments

Comments
 (0)