Commit 4b17638

Add read tests.
1 parent: 6d6c6fa

2 files changed: +184 additions, -19 deletions


lib/iris/fileformats/netcdf/_bytecoding_datasets.py
Lines changed: 31 additions & 7 deletions

@@ -175,7 +175,7 @@ def __getitem__(self, keys):
         if DECODE_TO_STRINGS_ON_READ and self._is_chardata():
             encoding = self._get_encoding() or DEFAULT_READ_ENCODING
             # N.B. typically, read encoding default is UTF-8 --> a "usually safe" choice
-            strlen = self._get_string_length()
+            strlen = self._get_string_width()
             try:
                 data = decode_bytesarray_to_stringarray(data, encoding, strlen)
             except UnicodeDecodeError as err:
@@ -194,11 +194,11 @@ def __setitem__(self, keys, data):
         # N.B. we never need to UNset this, as we totally control it
         self._contained_instance.set_auto_chartostring(False)

-        encoding = self._get_encoding() or DEFAULT_WRITE_ENCODING
         # N.B. typically, write encoding default is "ascii" --> fails bad content
         if data.dtype.kind == "U":
             try:
-                strlen = self._get_string_length()
+                encoding = self._get_encoding() or DEFAULT_WRITE_ENCODING
+                strlen = self._get_byte_width()
                 data = encode_stringarray_as_bytearray(data, encoding, strlen)
             except UnicodeEncodeError as err:
                 msg = (
@@ -230,12 +230,36 @@ def _get_encoding(self) -> str | None:

         return result

-    def _get_string_length(self):
+    def _get_byte_width(self) -> int | None:
+        if not hasattr(self, "_bytewidth"):
+            n_bytes = self.group().dimensions[self.dimensions[-1]].size
+            # Cache this length control on the variable -- but not as a netcdf attribute
+            self.__dict__["_bytewidth"] = n_bytes
+
+        return self.__dict__["_bytewidth"]
+
+    def _get_string_width(self):
         """Return the string-length defined for this variable."""
         if not hasattr(self, "_strlen"):
-            # Work out the string length from the parent dataset dimensions.
-            strlen = self.group().dimensions[self.dimensions[-1]].size
-            # Cache this on the variable -- but not as a netcdf attribute (!)
+            if hasattr(self, "iris_string_width"):
+                strlen = self.get_ncattr("iris_string_width")
+            else:
+                # Work out the actual byte width from the parent dataset dimensions.
+                strlen = self._get_byte_width()
+                # Convert the string dimension length (i.e. bytes) to a sufficiently-long
+                # string width, depending on the encoding used.
+                encoding = self._get_encoding() or DEFAULT_READ_ENCODING
+                # regularise the name for comparison with recognised ones
+                encoding = codecs.lookup(encoding).name
+                if "utf-16" in encoding:
+                    # Each char needs at least 2 bytes -- including a terminator char
+                    strlen = (strlen // 2) - 1
+                elif "utf-32" in encoding:
+                    # Each char needs exactly 4 bytes -- including a terminator char
+                    strlen = (strlen // 4) - 1
+                # "ELSE": assume there can be (at most) as many chars as bytes
+
+            # Cache this length control on the variable -- but not as a netcdf attribute
            self.__dict__["_strlen"] = strlen

         return self._strlen
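
Note on the new '_get_string_width' logic: the bytes-to-characters arithmetic can be sanity-checked in isolation. The following is a minimal standalone sketch of that conversion, not code from this commit -- the helper name 'string_width_from_byte_width' is invented for illustration, and the one-character allowance corresponds to the BOM/terminator overhead mentioned in the diff's comments.

import codecs

def string_width_from_byte_width(n_bytes: int, encoding: str) -> int:
    # Regularise the encoding name, as the diff above does.
    name = codecs.lookup(encoding).name
    if "utf-16" in name:
        # At least 2 bytes per char, less one char's worth of overhead.
        return (n_bytes // 2) - 1
    if "utf-32" in name:
        # Exactly 4 bytes per char, less one char's worth of overhead.
        return (n_bytes // 4) - 1
    # Otherwise: at most one char per byte -- exact for ascii, and a
    # safe over-estimate for multi-byte utf-8 content.
    return n_bytes

# Python's utf-16/utf-32 codecs prepend a BOM, which is where one
# character's worth of bytes goes:
assert len("abc".encode("utf-16")) == 8    # 2 (BOM) + 3 chars * 2 bytes
assert string_width_from_byte_width(8, "utf-16") == 3
assert len("abc".encode("utf-32")) == 16   # 4 (BOM) + 3 chars * 4 bytes
assert string_width_from_byte_width(16, "utf-32") == 3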

lib/iris/tests/unit/fileformats/netcdf/test_bytecoding_datasets.py
Lines changed: 153 additions & 12 deletions

@@ -9,7 +9,10 @@
 import numpy as np
 import pytest

-from iris.fileformats.netcdf._bytecoding_datasets import EncodedDataset
+from iris.fileformats.netcdf._bytecoding_datasets import (
+    DECODE_TO_STRINGS_ON_READ,
+    EncodedDataset,
+)
 from iris.fileformats.netcdf._thread_safe_nc import DatasetWrapper

 encoding_options = [None, "ascii", "utf-8", "utf-32"]
@@ -62,14 +65,17 @@ def fetch_undecoded_var(path, varname):
     return v


+def check_array_matching(arr1, arr2):
+    """Check for arrays matching shape, dtype and content."""
+    assert (
+        arr1.shape == arr2.shape and arr1.dtype == arr2.dtype and np.all(arr1 == arr2)
+    )
+
+
 def check_raw_content(path, varname, expected_byte_array):
     v = fetch_undecoded_var(path, varname)
     bytes_result = v[:]
-    assert (
-        bytes_result.shape == expected_byte_array.shape
-        and bytes_result.dtype == expected_byte_array.dtype
-        and np.all(bytes_result == expected_byte_array)
-    )
+    check_array_matching(bytes_result, expected_byte_array)


 def _make_bytearray_inner(data, bytewidth, encoding):
@@ -102,7 +108,7 @@ def make_bytearray(data, bytewidth, encoding="ascii"):
     data = _make_bytearray_inner(data, bytewidth, encoding)
     # We should now be able to create an array of single bytes.
     result = np.array(data)
-    assert result.dtype == "<S1"
+    assert result.dtype == "S1"
     return result


@@ -113,7 +119,7 @@ class TestWriteStrings:
     which is separately tested -- see 'TestReadStrings'.
     """

-    def test_write_strings(self, encoding, tempdir):
+    def test_encodings(self, encoding, tempdir):
         # Create a dataset with the variable
         path = tempdir / f"test_writestrings_encoding_{encoding!s}.nc"

@@ -258,8 +264,143 @@ def test_write_chars(self, tempdir, write_form):
         check_raw_content(path, "vxs", write_bytes)


-class TestReadStrings:
-    """Test how character data is read and converted to strings."""
+class TestRead:
+    """Test how character data is read and converted to strings.
+
+    N.B. many testcases here parallel the 'TestWriteStrings' : we are creating test
+    datafiles with 'make_dataset' and assigning raw bytes, as-per 'TestWriteChars'.
+
+    We are mostly checking here that reading back produces string arrays as expected.
+    However, it is simple + convenient to also check the 'DECODE_TO_STRINGS_ON_READ'
+    function here, i.e. "raw" bytes reads. So that is also done in this class.
+    """
+
+    @pytest.fixture(params=["strings", "bytes"])
+    def readmode(self, request):
+        return request.param
+
+    def test_encodings(self, encoding, tempdir, readmode):
+        # Create a dataset with the variable
+        path = tempdir / f"test_read_encodings_{encoding!s}_{readmode}.nc"
+
+        if encoding in [None, "ascii"]:
+            write_strings = samples_3_ascii
+            write_encoding = "ascii"
+        else:
+            write_strings = samples_3_nonascii
+            write_encoding = encoding
+
+        write_strings = write_strings.copy()  # just for safety?
+        strlen = strings_maxbytes(write_strings, write_encoding)
+        write_bytes = make_bytearray(write_strings, strlen, encoding=write_encoding)
+
+        ds_encoded = make_encoded_dataset(path, strlen, encoding)
+        v = ds_encoded.variables["vxs"]
+        v[:] = write_bytes
+
+        if readmode == "strings":
+            # Test "normal" read --> string array
+            result = v[:]
+            expected = write_strings
+            if encoding == "utf-8":
+                # In this case, with the given non-ascii sample data, the
+                # "default minimum string length" is overestimated.
+                assert strlen == 7 and result.dtype == "U7"
+                # correct the result dtype to pass the write_strings comparison below
+                truncated_result = result.astype("U4")
+                # Also check that content is the same (i.e. not actually truncated)
+                assert np.all(truncated_result == result)
+                result = truncated_result
+        else:
+            # Test "raw" read --> byte array
+            with DECODE_TO_STRINGS_ON_READ.context(False):
+                result = v[:]
+            expected = write_bytes

+        check_array_matching(result, expected)
+
+    def test_scalar(self, tempdir, readmode):
+        # Like 'test_write_strings', but the variable has *only* the string dimension.
+        path = tempdir / f"test_read_scalar_{readmode}.nc"
+
+        strlen = 5
+        ds_encoded = make_encoded_dataset(path, strlen=strlen)
+        v = ds_encoded.createVariable("v0_scalar", "S1", ("strlen",))
+
+        data_string = "stuff"
+        data_bytes = make_bytearray(data_string, 5)
+
+        # Checks that we *can* write a string
+        v[:] = data_bytes
+
+        if readmode == "strings":
+            # Test "normal" read --> string array
+            result = v[:]
+            expected = np.array(data_string)
+        else:
+            # Test "raw" read --> byte array
+            with DECODE_TO_STRINGS_ON_READ.context(False):
+                result = v[:]
+            expected = data_bytes
+
+        check_array_matching(result, expected)
+
+    def test_multidim(self, tempdir, readmode):
+        # Like 'test_write_strings', but the variable has additional dimensions.
+        path = tempdir / f"test_read_multidim_{readmode}.nc"
+
+        strlen = 5
+        ds_encoded = make_encoded_dataset(path, strlen=strlen)
+        ds_encoded.createDimension("y", 2)
+        v = ds_encoded.createVariable(
+            "vyxn",
+            "S1",
+            (
+                "y",
+                "x",
+                "strlen",
+            ),
+        )
+
+        # Check that we *can* write a multidimensional string array
+        test_strings = [
+            ["one", "n", ""],
+            ["two", "xxxxx", "four"],
+        ]
+        test_bytes = make_bytearray(test_strings, strlen)
+        v[:] = test_bytes
+
+        if readmode == "strings":
+            # Test "normal" read --> string array
+            result = v[:]
+            expected = np.array(test_strings)
+        else:
+            # Test "raw" read --> byte array
+            with DECODE_TO_STRINGS_ON_READ.context(False):
+                result = v[:]
+            expected = test_bytes
+
+        check_array_matching(result, expected)
+
+    def test_read_encoding_failure(self, tempdir, readmode):
+        path = tempdir / f"test_read_encoding_failure_{readmode}.nc"
+        strlen = 10
+        ds = make_encoded_dataset(path, strlen=strlen, encoding="ascii")
+        v = ds.variables["vxs"]
+        test_utf8_bytes = make_bytearray(
+            samples_3_nonascii, bytewidth=strlen, encoding="utf-8"
+        )
+        v[:] = test_utf8_bytes
+
+        if readmode == "strings":
+            msg = (
+                "Character data in variable 'vxs' could not be decoded "
+                "with the 'ascii' encoding."
+            )
+            with pytest.raises(ValueError, match=msg):
+                v[:]
+        else:
+            with DECODE_TO_STRINGS_ON_READ.context(False):
+                result = v[:]  # this ought to be ok!

-    def test_encodings(self, encoding):
-        pass
+            assert np.all(result == test_utf8_bytes)
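
Note on the two read forms exercised by 'TestRead': a plain 'v[:]' read yields a decoded string array, while reading under 'DECODE_TO_STRINGS_ON_READ.context(False)' yields the raw "S1" char array. The round trip between those forms can be illustrated standalone with numpy -- a sketch with invented names, not the repo's 'make_bytearray'/'check_array_matching' helpers:

import numpy as np

strlen = 5
strings = np.array(["one", "n", "stuff"], dtype=f"U{strlen}")

# Raw form: a (3, strlen) array of single bytes, NUL-padded -- the
# layout stored in a netCDF char variable.
padded = [s.encode("ascii").ljust(strlen, b"\x00") for s in strings]
char_array = np.frombuffer(b"".join(padded), dtype="S1").reshape(-1, strlen)
assert char_array.shape == (3, strlen)

# Decoded form: view the trailing char dimension as whole byte-strings
# (numpy drops trailing NULs from "S"-dtype items), then decode each.
joined = char_array.view(f"S{strlen}")[:, 0]
decoded = np.array([b.decode("ascii") for b in joined], dtype=f"U{strlen}")
assert np.all(decoded == strings)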
