1- from __future__ import absolute_import , division , print_function , \
2- unicode_literals
1+ import os
2+ import subprocess
3+ import sys
4+ from datetime import datetime
5+ from pathlib import Path
36
4- import numpy
57import netCDF4
6- from datetime import datetime
7- import sys
8+ import numpy
89
10+ from mpas_tools .logging import check_call
911
1012default_format = 'NETCDF3_64BIT'
1113default_engine = None
1214default_char_dim_name = 'StrLen'
1315default_fills = netCDF4 .default_fillvals
1416
1517
16- def write_netcdf (ds , fileName , fillValues = None , format = None , engine = None ,
17- char_dim_name = None ):
18+ def write_netcdf (
19+ ds ,
20+ fileName ,
21+ fillValues = None ,
22+ format = None ,
23+ engine = None ,
24+ char_dim_name = None ,
25+ logger = None ,
26+ ):
1827 """
1928 Write an xarray.Dataset to a file with NetCDF4 fill values and the given
2029 name of the string dimension. Also adds the time and command-line to the
2130 history attribute.
2231
32+ Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
33+ because xarray output with this format is not performant. First, the file
34+ is written in `NETCDF4` format, which supports larger files and variables.
35+ Then, the `ncks` command is used to convert the file to the
36+ `NETCDF3_64BIT_DATA` format.
37+
38+ Note: All int64 variables are automatically converted to int32 for MPAS
39+ compatibility.
40+
2341 Parameters
2442 ----------
2543 ds : xarray.Dataset
@@ -50,7 +68,11 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
5068 ``mpas_tools.io.default_char_dim_name``, which can be modified but
5169 which defaults to ``'StrLen'``
5270
53- """
71+ logger : logging.Logger, optional
72+ A logger to which the output of the `ncks` conversion call is
73+ written. If None, `ncks` output is suppressed. This is only
74+ relevant if `format` is 'NETCDF3_64BIT_DATA'.
75+ """ # noqa: E501
5476 if format is None :
5577 format = default_format
5678
@@ -63,6 +85,13 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
6385 if char_dim_name is None :
6486 char_dim_name = default_char_dim_name
6587
88+ # Convert int64 variables to int32 for MPAS compatibility
89+ for var in list (ds .data_vars .keys ()) + list (ds .coords .keys ()):
90+ if ds [var ].dtype == numpy .int64 :
91+ attrs = ds [var ].attrs .copy ()
92+ ds [var ] = ds [var ].astype (numpy .int32 )
93+ ds [var ].attrs = attrs
94+
6695 encodingDict = {}
6796 variableNames = list (ds .data_vars .keys ()) + list (ds .coords .keys ())
6897 for variableName in variableNames :
@@ -71,8 +100,9 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
71100 dtype = ds [variableName ].dtype
72101 for fillType in fillValues :
73102 if dtype == numpy .dtype (fillType ):
74- encodingDict [variableName ] = \
75- {'_FillValue' : fillValues [fillType ]}
103+ encodingDict [variableName ] = {
104+ '_FillValue' : fillValues [fillType ]
105+ }
76106 break
77107 else :
78108 encodingDict [variableName ] = {'_FillValue' : None }
@@ -88,14 +118,54 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
88118 # reading Time otherwise
89119 ds .encoding ['unlimited_dims' ] = {'Time' }
90120
91- ds .to_netcdf (fileName , encoding = encodingDict , format = format , engine = engine )
121+ # for performance, we have to handle this as a special case
122+ convert = format == 'NETCDF3_64BIT_DATA'
123+
124+ if convert :
125+ out_path = Path (fileName )
126+ out_filename = (
127+ out_path .parent / f'_tmp_{ out_path .stem } .netcdf4{ out_path .suffix } '
128+ )
129+ format = 'NETCDF4'
130+ if engine == 'scipy' :
131+ # that's not going to work
132+ engine = 'netcdf4'
133+ else :
134+ out_filename = fileName
135+
136+ ds .to_netcdf (
137+ out_filename , encoding = encodingDict , format = format , engine = engine
138+ )
139+
140+ if convert :
141+ args = [
142+ 'ncks' ,
143+ '-O' ,
144+ '-5' ,
145+ out_filename ,
146+ fileName ,
147+ ]
148+ if logger is None :
149+ subprocess .run (
150+ args ,
151+ check = True ,
152+ stdout = subprocess .DEVNULL ,
153+ stderr = subprocess .DEVNULL ,
154+ )
155+ else :
156+ check_call (args , logger = logger )
157+ # delete the temporary NETCDF4 file
158+ os .remove (out_filename )
92159
93160
94161def update_history (ds ):
95- ''' Add or append history to attributes of a data set'''
162+ """ Add or append history to attributes of a data set"""
96163
97- thiscommand = datetime .now ().strftime ("%a %b %d %H:%M:%S %Y" ) + ": " + \
98- " " .join (sys .argv [:])
164+ thiscommand = (
165+ datetime .now ().strftime ('%a %b %d %H:%M:%S %Y' )
166+ + ': '
167+ + ' ' .join (sys .argv [:])
168+ )
99169 if 'history' in ds .attrs :
100170 newhist = '\n ' .join ([thiscommand , ds .attrs ['history' ]])
101171 else :
0 commit comments