Skip to content

Commit 94b2b21

Browse files
committed
Put encoding information into separate converter class, for use in proxies.
1 parent 042028e commit 94b2b21

File tree

2 files changed

+161
-156
lines changed

2 files changed

+161
-156
lines changed

lib/iris/fileformats/netcdf/_bytecoding_datasets.py

Lines changed: 144 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import codecs
4444
import contextlib
45+
import dataclasses
4546
import threading
4647
import warnings
4748

@@ -80,55 +81,6 @@ def decode_bytesarray_to_stringarray(
8081
return result
8182

8283

83-
#
84-
# TODO: remove?
85-
# this older version is "overly flexible", less efficient and not needed here.
86-
#
87-
def flexi_encode_stringarray_as_bytearray(
88-
data: np.ndarray, encoding=None, string_dimension_length: int | None = None
89-
) -> np.ndarray:
90-
"""Encode strings as bytearray.
91-
92-
Note: if 'string_dimension_length' is not given (None), it is set to the longest
93-
encoded bytes element, **OR** the dtype size, if that is greater.
94-
If 'string_dimension_length' is specified, the last array
95-
dimension is set to this and content strings are truncated or extended as required.
96-
"""
97-
if np.ma.isMaskedArray(data):
98-
# netCDF4-python sees zeros as "missing" -- we don't need or want that
99-
data = data.data
100-
element_shape = data.shape
101-
# Encode all the strings + see which is longest
102-
max_length = 1 # this is a MINIMUM - i.e. not zero!
103-
data_elements = np.zeros(element_shape, dtype=object)
104-
for index in np.ndindex(element_shape):
105-
data_element = data[index].encode(encoding=encoding)
106-
element_length = len(data_element)
107-
data_elements[index] = data_element
108-
if element_length > max_length:
109-
max_length = element_length
110-
111-
if string_dimension_length is None:
112-
# If the string length was not specified, it is the maximum encoded length
113-
# (n-bytes), **or** the dtype string-length, if greater.
114-
string_dimension_length = max_length
115-
array_string_length = int(str(data.dtype)[2:]) # Yuck. No better public way?
116-
if array_string_length > string_dimension_length:
117-
string_dimension_length = array_string_length
118-
119-
# We maybe *already* encoded all the strings above, but stored them in an
120-
# object-array as we didn't yet know the fixed byte-length to convert to.
121-
# Now convert to a fixed-width byte array with an extra string-length dimension
122-
result = np.zeros(element_shape + (string_dimension_length,), dtype="S1")
123-
right_pad = b"\0" * string_dimension_length
124-
for index in np.ndindex(element_shape):
125-
bytes = data_elements[index]
126-
bytes = (bytes + right_pad)[:string_dimension_length]
127-
result[index] = [bytes[i : i + 1] for i in range(string_dimension_length)]
128-
129-
return result
130-
131-
13284
def encode_stringarray_as_bytearray(
13385
data: np.typing.ArrayLike, encoding: str, string_dimension_length: int
13486
) -> np.ndarray:
@@ -158,6 +110,114 @@ def encode_stringarray_as_bytearray(
158110
return result
159111

160112

113+
@dataclasses.dataclass
class VariableEncoder:
    """A record of encoding details which can apply them to variable data."""

    varname: str  # just for the error messages
    dtype: np.dtype
    is_chardata: bool  # just a shortcut for the dtype test
    read_encoding: str  # *always* a valid encoding from the codecs package
    write_encoding: str  # *always* a valid encoding from the codecs package
    n_chars_dim: int  # length of associated character dimension (0 if no dimensions)
    string_width: int  # string lengths when viewing as strings (i.e. "Uxx")

    def __init__(self, cf_var):
        """Get all the info from a netCDF4 variable (or similar wrapper object).

        Most importantly, we do *not* store 'cf_var' : instead we extract the
        necessary information and store it in this object.
        So, this object has static state + is serialisable.
        """
        self.varname = cf_var.name
        self.dtype = cf_var.dtype
        self.is_chardata = np.issubdtype(self.dtype, np.bytes_)
        self.read_encoding = self._get_encoding(cf_var, writing=False)
        self.write_encoding = self._get_encoding(cf_var, writing=True)
        if cf_var.dimensions:
            # Byte width = length of the variable's last (character) dimension.
            self.n_chars_dim = cf_var.group().dimensions[cf_var.dimensions[-1]].size
        else:
            # Fix: a scalar variable has no dimensions, so indexing
            # 'cf_var.dimensions[-1]' would raise IndexError.  No character
            # dimension applies in that case.
            self.n_chars_dim = 0
        self.string_width = self._get_string_width()

    @staticmethod
    def _get_encoding(cf_var, writing=False) -> str:
        """Get the byte encoding defined for this variable (or a default)."""
        result = getattr(cf_var, "_Encoding", None)
        if result is not None:
            try:
                # Accept + normalise naming of encodings
                result = codecs.lookup(result).name
                # NOTE: if encoding does not suit data, errors can occur.
                # For example, _Encoding = "ascii", with non-ascii content.
            except LookupError:
                # Unrecognised encoding name : handle this as just a warning
                msg = (
                    f"Ignoring unknown encoding for variable {cf_var.name!r}: "
                    f"_Encoding = {result!r}."
                )
                warntype = IrisCfSaveWarning if writing else IrisCfLoadWarning
                warnings.warn(msg, category=warntype)
                # Proceed as if there is no specified encoding
                result = None

        if result is None:
            if writing:
                result = DEFAULT_WRITE_ENCODING
            else:
                result = DEFAULT_READ_ENCODING
        return result

    def _get_string_width(self) -> int:
        """Return the string-length defined for this variable.

        NOTE: derives purely from 'self.n_chars_dim' and 'self.read_encoding',
        so must be called after those are set.
        """
        # Start from the byte width of the character dimension.
        strlen = self.n_chars_dim
        # Convert the string dimension length (i.e. bytes) to a sufficiently-long
        # string width, depending on the (read) encoding used.
        encoding = self.read_encoding
        if "utf-16" in encoding:
            # Each char needs at least 2 bytes -- including a terminator char
            strlen = (strlen // 2) - 1
        elif "utf-32" in encoding:
            # Each char needs exactly 4 bytes -- including a terminator char
            strlen = (strlen // 4) - 1
        # "ELSE": assume there can be (at most) as many chars as bytes
        return strlen

    def decode_bytes_to_stringarray(self, data: np.ndarray) -> np.ndarray:
        """Convert a byte array to a string array, if enabled and applicable.

        Non-character data is passed through unchanged.
        Raises ValueError if the bytes cannot be decoded with 'read_encoding'.
        """
        if self.is_chardata and DECODE_TO_STRINGS_ON_READ:
            # N.B. read encoding default is UTF-8 --> a "usually safe" choice
            encoding = self.read_encoding
            strlen = self.string_width
            try:
                data = decode_bytesarray_to_stringarray(data, encoding, strlen)
            except UnicodeDecodeError as err:
                msg = (
                    f"Character data in variable {self.varname!r} could not be decoded "
                    f"with the {encoding!r} encoding. This can be fixed by setting the "
                    "variable '_Encoding' attribute to suit the content."
                )
                raise ValueError(msg) from err

        return data

    def encode_strings_as_bytearray(self, data: np.ndarray) -> np.ndarray:
        """Convert a string (unicode) array to a byte array for writing.

        Non-string data (e.g. an "S1" byte array) is passed through unchanged.
        Raises ValueError if the strings cannot be encoded with 'write_encoding'.
        """
        if data.dtype.kind == "U":
            # N.B. it is also possible to pass a byte array (dtype "S1"),
            # to be written directly, without processing.
            try:
                # N.B. write encoding *default* is "ascii" --> fails bad content
                encoding = self.write_encoding
                strlen = self.n_chars_dim
                data = encode_stringarray_as_bytearray(data, encoding, strlen)
            except UnicodeEncodeError as err:
                msg = (
                    f"String data written to netcdf character variable {self.varname!r} "
                    f"could not be represented in encoding {self.write_encoding!r}. "
                    "This can be fixed by setting a suitable variable '_Encoding' "
                    'attribute, e.g. <variable>._Encoding="UTF-8".'
                )
                raise ValueError(msg) from err
        return data
220+
161221
class NetcdfStringDecodeSetting(threading.local):
162222
def __init__(self, perform_encoding: bool = True):
163223
self.set(perform_encoding)
@@ -184,109 +244,24 @@ def context(self, perform_encoding: bool):
184244
class EncodedVariable(VariableWrapper):
    """A variable wrapper that translates variable data according to byte encodings."""

    # NOTE: the no-op "__init__(self, *args, **kwargs): super().__init__(...)"
    # passthrough has been removed -- it added nothing over plain inheritance.

    def __getitem__(self, keys):
        # Always fetch raw byte data; any conversion is applied by the encoder.
        self._contained_instance.set_auto_chartostring(False)
        data = super().__getitem__(keys)
        # Create a coding spec : redo every time in case "_Encoding" has changed
        encoding_spec = VariableEncoder(self._contained_instance)
        data = encoding_spec.decode_bytes_to_stringarray(data)
        return data

    def __setitem__(self, keys, data):
        data = np.asanyarray(data)
        # Create a coding spec : redo every time in case "_Encoding" has changed
        encoding_spec = VariableEncoder(self._contained_instance)
        data = encoding_spec.encode_strings_as_bytearray(data)
        super().__setitem__(keys, data)

    def set_auto_chartostring(self, onoff: bool):
        """Always raise: this wrapper manages char/string conversion itself."""
        msg = "auto_chartostring is not supported by Iris 'EncodedVariable' type."
        raise TypeError(msg)
@@ -297,14 +272,37 @@ class EncodedDataset(DatasetWrapper):
297272

298273
VAR_WRAPPER_CLS = EncodedVariable
299274

275+
def __init__(self, *args, **kwargs):
276+
super().__init__(*args, **kwargs)
277+
300278
def set_auto_chartostring(self, onoff: bool):
301279
msg = "auto_chartostring is not supported by Iris 'EncodedDataset' type."
302280
raise TypeError(msg)
303281

304282

305283
class EncodedNetCDFDataProxy(NetCDFDataProxy):
    """A data proxy which decodes raw character data to strings on read."""

    __slots__ = NetCDFDataProxy.__slots__ + ("encoding_details",)

    def __init__(self, cf_var, *args, **kwargs):
        # Force the base proxy to fetch raw (unmasked) byte data.
        kwargs["use_byte_data"] = True
        super().__init__(cf_var, *args, **kwargs)
        # Capture the encoding details now, so 'cf_var' itself need not be kept.
        self.encoding_details = VariableEncoder(cf_var)

    def __getitem__(self, keys):
        # Fetch raw data, then apply the optional bytes-to-strings conversion.
        result = super().__getitem__(keys)
        return self.encoding_details.decode_bytes_to_stringarray(result)
307297

308298

309299
class EncodedNetCDFWriteProxy(NetCDFWriteProxy):
    """A write proxy which encodes string data to bytes before writing."""

    def __init__(self, filepath, cf_var, file_write_lock):
        # BUGFIX: was "super.__init__(...)" -- i.e. an attribute of the bare
        # 'super' type, missing the call "super()", which raises TypeError.
        super().__init__(filepath, cf_var, file_write_lock)
        # Capture the encoding details now, so 'cf_var' itself need not be kept.
        self.encoding_details = VariableEncoder(cf_var)

    def __setitem__(self, key, data):
        data = np.asanyarray(data)
        # Apply the optional strings-to-bytes conversion
        data = self.encoding_details.encode_strings_as_bytearray(data)
        # BUGFIX: was "super.__setitem__(...)" -- see __init__ above.
        super().__setitem__(key, data)

lib/iris/fileformats/netcdf/_thread_safe_nc.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -314,15 +314,22 @@ def fromcdl(cls, *args, **kwargs):
314314
class NetCDFDataProxy:
315315
"""A reference to the data payload of a single NetCDF file variable."""
316316

317-
__slots__ = ("shape", "dtype", "path", "variable_name", "fill_value")
318-
DATASET_CLASS = netCDF4.Dataset
319-
320-
def __init__(self, shape, dtype, path, variable_name, fill_value):
321-
self.shape = shape
317+
__slots__ = (
318+
"shape",
319+
"dtype",
320+
"path",
321+
"variable_name",
322+
"fill_value",
323+
"use_byte_data",
324+
)
325+
326+
def __init__(self, cf_var, dtype, path, fill_value, *, use_byte_data=False):
327+
self.shape = cf_var.shape
328+
self.variable_name = cf_var.name
322329
self.dtype = dtype
323330
self.path = path
324-
self.variable_name = variable_name
325331
self.fill_value = fill_value
332+
self.use_byte_data = use_byte_data
326333

327334
@property
328335
def ndim(self):
@@ -338,9 +345,11 @@ def __getitem__(self, keys):
338345
# netCDF4 library, presumably because __getitem__ gets called so many
339346
# times by Dask. Use _GLOBAL_NETCDF4_LOCK directly instead.
340347
with _GLOBAL_NETCDF4_LOCK:
341-
dataset = self.DATASET_CLASS(self.path)
348+
dataset = netCDF4.Dataset(self.path)
342349
try:
343350
variable = dataset.variables[self.variable_name]
351+
if self.use_byte_data:
352+
variable.set_auto_mask(False)
344353
# Get the NetCDF variable data and slice.
345354
var = variable[keys]
346355
finally:
@@ -375,8 +384,6 @@ class NetCDFWriteProxy:
375384
TODO: could be improved with a caching scheme, but this just about works.
376385
"""
377386

378-
DATASET_CLASS = netCDF4.Dataset
379-
380387
def __init__(self, filepath, cf_var, file_write_lock):
381388
self.path = filepath
382389
self.varname = cf_var.name
@@ -404,7 +411,7 @@ def __setitem__(self, keys, array_data):
404411
# investigation needed.
405412
for attempt in range(5):
406413
try:
407-
dataset = self.DATASET_CLASS(self.path, "r+")
414+
dataset = netCDF4.Dataset(self.path, "r+")
408415
break
409416
except OSError:
410417
if attempt < 4:

0 commit comments

Comments
 (0)