Skip to content

Commit adbc54e

Browse files
authored
Merge pull request #633 from xylar/add-cdf5-handling-to-write-netcdf
Add support for `NETCDF3_64BIT_DATA` in `write_netcdf()`
2 parents 353136b + 2fab2c8 commit adbc54e

File tree

7 files changed

+249
-25
lines changed

7 files changed

+249
-25
lines changed

conda_package/dev-spec.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ hdf5
1212
inpoly
1313
libnetcdf
1414
matplotlib-base>=3.9.0
15+
nco
1516
netcdf4
1617
networkx
1718
numpy>=2.0,<3.0

conda_package/docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ analyzing simulations, and in other MPAS-related workflows.
2727

2828
config
2929

30+
io
31+
3032
logging
3133

3234
transects

conda_package/docs/io.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
.. _io:
2+
3+
*********
4+
I/O Tools
5+
*********
6+
7+
The :py:mod:`mpas_tools.io` module provides utilities for reading and writing
8+
NetCDF files, especially for compatibility with MPAS mesh and data conventions.
9+
10+
write_netcdf
11+
============
12+
13+
The :py:func:`mpas_tools.io.write_netcdf()` function writes an
14+
``xarray.Dataset`` to a NetCDF file, ensuring MPAS compatibility (e.g.,
15+
converting int64 to int32, handling fill values, and updating the history
16+
attribute). It also supports writing in various NetCDF formats, including
17+
conversion to ``NETCDF3_64BIT_DATA`` using ``ncks`` if needed.
18+
19+
Example usage:
20+
21+
.. code-block:: python
22+
23+
import xarray as xr
24+
from mpas_tools.io import write_netcdf
25+
26+
# Create a simple dataset
27+
ds = xr.Dataset({'foo': (('x',), [1, 2, 3])})
28+
write_netcdf(ds, 'output.nc')

conda_package/mpas_tools/io.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,43 @@
1-
from __future__ import absolute_import, division, print_function, \
2-
unicode_literals
1+
import os
2+
import subprocess
3+
import sys
4+
from datetime import datetime
5+
from pathlib import Path
36

4-
import numpy
57
import netCDF4
6-
from datetime import datetime
7-
import sys
8+
import numpy
89

10+
from mpas_tools.logging import check_call
911

1012
default_format = 'NETCDF3_64BIT'
1113
default_engine = None
1214
default_char_dim_name = 'StrLen'
1315
default_fills = netCDF4.default_fillvals
1416

1517

16-
def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
17-
char_dim_name=None):
18+
def write_netcdf(
19+
ds,
20+
fileName,
21+
fillValues=None,
22+
format=None,
23+
engine=None,
24+
char_dim_name=None,
25+
logger=None,
26+
):
1827
"""
1928
Write an xarray.Dataset to a file with NetCDF4 fill values and the given
2029
name of the string dimension. Also adds the time and command-line to the
2130
history attribute.
2231
32+
Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
33+
because xarray output with this format is not performant. First, the file
34+
is written in `NETCDF4` format, which supports larger files and variables.
35+
Then, the `ncks` command is used to convert the file to the
36+
`NETCDF3_64BIT_DATA` format.
37+
38+
Note: All int64 variables are automatically converted to int32 for MPAS
39+
compatibility.
40+
2341
Parameters
2442
----------
2543
ds : xarray.Dataset
@@ -50,7 +68,11 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
5068
``mpas_tools.io.default_char_dim_name``, which can be modified but
5169
which defaults to ``'StrLen'``
5270
53-
"""
71+
logger : logging.Logger, optional
72+
A logger to which the output of `ncks` conversion calls is
73+
written. If None, `ncks` output is suppressed. This is only
74+
relevant if `format` is 'NETCDF3_64BIT_DATA'
75+
""" # noqa: E501
5476
if format is None:
5577
format = default_format
5678

@@ -63,6 +85,13 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
6385
if char_dim_name is None:
6486
char_dim_name = default_char_dim_name
6587

88+
# Convert int64 variables to int32 for MPAS compatibility
89+
for var in list(ds.data_vars.keys()) + list(ds.coords.keys()):
90+
if ds[var].dtype == numpy.int64:
91+
attrs = ds[var].attrs.copy()
92+
ds[var] = ds[var].astype(numpy.int32)
93+
ds[var].attrs = attrs
94+
6695
encodingDict = {}
6796
variableNames = list(ds.data_vars.keys()) + list(ds.coords.keys())
6897
for variableName in variableNames:
@@ -71,8 +100,9 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
71100
dtype = ds[variableName].dtype
72101
for fillType in fillValues:
73102
if dtype == numpy.dtype(fillType):
74-
encodingDict[variableName] = \
75-
{'_FillValue': fillValues[fillType]}
103+
encodingDict[variableName] = {
104+
'_FillValue': fillValues[fillType]
105+
}
76106
break
77107
else:
78108
encodingDict[variableName] = {'_FillValue': None}
@@ -88,14 +118,54 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
88118
# reading Time otherwise
89119
ds.encoding['unlimited_dims'] = {'Time'}
90120

91-
ds.to_netcdf(fileName, encoding=encodingDict, format=format, engine=engine)
121+
# for performance, we have to handle this as a special case
122+
convert = format == 'NETCDF3_64BIT_DATA'
123+
124+
if convert:
125+
out_path = Path(fileName)
126+
out_filename = (
127+
out_path.parent / f'_tmp_{out_path.stem}.netcdf4{out_path.suffix}'
128+
)
129+
format = 'NETCDF4'
130+
if engine == 'scipy':
131+
# that's not going to work
132+
engine = 'netcdf4'
133+
else:
134+
out_filename = fileName
135+
136+
ds.to_netcdf(
137+
out_filename, encoding=encodingDict, format=format, engine=engine
138+
)
139+
140+
if convert:
141+
args = [
142+
'ncks',
143+
'-O',
144+
'-5',
145+
out_filename,
146+
fileName,
147+
]
148+
if logger is None:
149+
subprocess.run(
150+
args,
151+
check=True,
152+
stdout=subprocess.DEVNULL,
153+
stderr=subprocess.DEVNULL,
154+
)
155+
else:
156+
check_call(args, logger=logger)
157+
# delete the temporary NETCDF4 file
158+
os.remove(out_filename)
92159

93160

94161
def update_history(ds):
95-
'''Add or append history to attributes of a data set'''
162+
"""Add or append history to attributes of a data set"""
96163

97-
thiscommand = datetime.now().strftime("%a %b %d %H:%M:%S %Y") + ": " + \
98-
" ".join(sys.argv[:])
164+
thiscommand = (
165+
datetime.now().strftime('%a %b %d %H:%M:%S %Y')
166+
+ ': '
167+
+ ' '.join(sys.argv[:])
168+
)
99169
if 'history' in ds.attrs:
100170
newhist = '\n'.join([thiscommand, ds.attrs['history']])
101171
else:

conda_package/mpas_tools/mesh/mask.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,13 @@ def entry_point_compute_mpas_region_masks():
251251
subdivisionThreshold=args.subdivision,
252252
)
253253

254-
write_netcdf(
255-
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
256-
)
254+
write_netcdf(
255+
dsMasks,
256+
args.mask_file_name,
257+
format=args.format,
258+
engine=args.engine,
259+
logger=logger,
260+
)
257261

258262

259263
def compute_mpas_transect_masks(
@@ -516,9 +520,13 @@ def entry_point_compute_mpas_transect_masks():
516520
addEdgeSign=args.add_edge_sign,
517521
)
518522

519-
write_netcdf(
520-
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
521-
)
523+
write_netcdf(
524+
dsMasks,
525+
args.mask_file_name,
526+
format=args.format,
527+
engine=args.engine,
528+
logger=logger,
529+
)
522530

523531

524532
def compute_mpas_flood_fill_mask(
@@ -641,9 +649,13 @@ def entry_point_compute_mpas_flood_fill_mask():
641649
dsMesh=dsMesh, fcSeed=fcSeed, logger=logger
642650
)
643651

644-
write_netcdf(
645-
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
646-
)
652+
write_netcdf(
653+
dsMasks,
654+
args.mask_file_name,
655+
format=args.format,
656+
engine=args.engine,
657+
logger=logger,
658+
)
647659

648660

649661
def compute_lon_lat_region_masks(
@@ -868,7 +880,11 @@ def entry_point_compute_lon_lat_region_masks():
868880
)
869881

870882
write_netcdf(
871-
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
883+
dsMasks,
884+
args.mask_file_name,
885+
format=args.format,
886+
engine=args.engine,
887+
logger=logger,
872888
)
873889

874890

@@ -1101,7 +1117,11 @@ def entry_point_compute_projection_grid_region_masks():
11011117
)
11021118

11031119
write_netcdf(
1104-
dsMasks, args.mask_file_name, format=args.format, engine=args.engine
1120+
dsMasks,
1121+
args.mask_file_name,
1122+
format=args.format,
1123+
engine=args.engine,
1124+
logger=logger,
11051125
)
11061126

11071127

conda_package/recipe/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ requirements:
4242
- networkx
4343
- netcdf-fortran
4444
- matplotlib-base >=3.9.0
45+
- nco
4546
- netcdf4
4647
- numpy >=2.0,<3.0
4748
- progressbar2

conda_package/tests/test_io.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
2+
import subprocess
3+
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
8+
from mpas_tools.io import write_netcdf
9+
10+
from .util import get_test_data_file
11+
12+
TEST_MESH = get_test_data_file('mesh.QU.1920km.151026.nc')
13+
14+
15+
@pytest.mark.skipif(
16+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
17+
)
18+
def test_write_netcdf_basic(tmp_path):
19+
ds = xr.open_dataset(TEST_MESH)
20+
out_file = tmp_path / 'test_basic.nc'
21+
write_netcdf(ds, str(out_file))
22+
ds2 = xr.open_dataset(out_file)
23+
# Should have same dimensions and variables
24+
assert set(ds.dims) == set(ds2.dims)
25+
for var in ds.data_vars:
26+
assert var in ds2.data_vars
27+
ds2.close()
28+
29+
30+
@pytest.mark.skipif(
31+
not os.path.exists(TEST_MESH), reason='Test mesh not available'
32+
)
33+
def test_write_netcdf_cdf5_format(tmp_path):
34+
ds = xr.open_dataset(TEST_MESH)
35+
out_file = tmp_path / 'test_cdf5.nc'
36+
write_netcdf(ds, str(out_file), format='NETCDF3_64BIT_DATA')
37+
# Use ncdump -k to check format
38+
result = subprocess.run(
39+
['ncdump', '-k', str(out_file)],
40+
capture_output=True,
41+
text=True,
42+
check=True,
43+
)
44+
# Should be cdf5 for NETCDF3_64BIT_DATA
45+
assert result.stdout.strip() == 'cdf5'
46+
# Check that the temporary file was deleted
47+
tmp_file = (
48+
out_file.parent / f'_tmp_{out_file.stem}.netcdf4{out_file.suffix}'
49+
)
50+
assert not os.path.exists(tmp_file)
51+
52+
53+
def test_write_netcdf_int64_conversion_and_attr(tmp_path):
54+
# Create a dataset with int64 variable and an attribute
55+
arr = np.array([1, 2, 3], dtype=np.int64)
56+
ds = xr.Dataset({'foo': (('x',), arr)})
57+
ds['foo'].attrs['myattr'] = 'testattr'
58+
out_file = tmp_path / 'test_int64.nc'
59+
write_netcdf(ds, str(out_file))
60+
ds2 = xr.open_dataset(out_file)
61+
# Should be int32, not int64
62+
assert ds2['foo'].dtype == np.int32
63+
# Attribute should be preserved
64+
assert ds2['foo'].attrs['myattr'] == 'testattr'
65+
ds2.close()
66+
67+
68+
def test_write_netcdf_fill_value(tmp_path):
69+
# Test that NaN values are written with correct fill value
70+
arr = np.array([1.0, np.nan, 3.0], dtype=np.float32)
71+
ds = xr.Dataset({'bar': (('x',), arr)})
72+
out_file = tmp_path / 'test_fill.nc'
73+
write_netcdf(ds, str(out_file))
74+
ds2 = xr.open_dataset(out_file)
75+
# The second value should be the default fill value for float32
76+
fill_value = ds2['bar'].encoding.get('_FillValue', None)
77+
assert fill_value is not None
78+
assert np.isnan(ds2['bar'].values[1])
79+
ds2.close()
80+
81+
82+
def test_write_netcdf_string_dim_name(tmp_path):
83+
# Test that custom char_dim_name is used in encoding
84+
arr = np.array([b'abc', b'def'])
85+
ds = xr.Dataset({'baz': (('x',), arr)})
86+
out_file = tmp_path / 'test_strdim.nc'
87+
write_netcdf(ds, str(out_file), char_dim_name='CustomStrLen')
88+
ds2 = xr.open_dataset(out_file)
89+
# Should have the variable and correct shape
90+
assert 'baz' in ds2.variables
91+
ds2.close()
92+
arr = np.array([1, 2, 3], dtype=np.int64)
93+
ds = xr.Dataset({'foo': (('x',), arr)})
94+
ds['foo'].attrs['myattr'] = 'testattr'
95+
out_file = tmp_path / 'test_int64.nc'
96+
write_netcdf(ds, str(out_file))
97+
ds2 = xr.open_dataset(out_file)
98+
# Should be int32, not int64
99+
assert ds2['foo'].dtype == np.int32
100+
# Attribute should be preserved
101+
assert ds2['foo'].attrs['myattr'] == 'testattr'
102+
ds2.close()

0 commit comments

Comments
 (0)