Commit 6d6c6fa

Radically simplify 'make_bytearray' by using a known, specified bytewidth.
1 parent 7baee94

File tree

1 file changed: +22, -54 lines

lib/iris/tests/unit/fileformats/netcdf/test_bytecoding_datasets.py

Lines changed: 22 additions & 54 deletions
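
Editor's note: the heart of the change, visible in the second hunk below, is a pad-or-truncate idiom: appending 'bytewidth' NUL fillers guarantees the byte list is long enough, and a slice then trims it back to exactly 'bytewidth' elements. A minimal standalone sketch of that idiom (the helper name here is ours, not the commit's):

    # Pad short inputs with NUL bytes, truncate long ones: one expression.
    def pad_or_truncate(single_bytes, bytewidth):
        return (single_bytes + [b"\0"] * bytewidth)[:bytewidth]

    assert pad_or_truncate([b"a", b"b"], 4) == [b"a", b"b", b"\0", b"\0"]
    assert pad_or_truncate([b"a", b"b", b"c"], 2) == [b"a", b"b"]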
@@ -72,7 +72,7 @@ def check_raw_content(path, varname, expected_byte_array):
     )
 
 
-def _make_bytearray_inner(data, encoding):
+def _make_bytearray_inner(data, bytewidth, encoding):
     # Convert to a (list of [lists of..]) strings or bytes to a
     # (list of [lists of..]) length-1 bytes with an extra dimension.
     if isinstance(data, str):
@@ -81,61 +81,25 @@ def _make_bytearray_inner(data, encoding):
     if isinstance(data, bytes):
         # iterate over bytes to get a sequence of length-1 bytes (what np.array wants)
         result = [data[i : i + 1] for i in range(len(data))]
+        # pad or truncate everything to the required bytewidth
+        result = (result + [b"\0"] * bytewidth)[:bytewidth]
     else:
         # If not string/bytes, expect the input to be a list.
         # N.B. the recursion is inefficient, but we don't care about that here
-        result = [_make_bytearray_inner(part, encoding) for part in data]
+        result = [_make_bytearray_inner(part, bytewidth, encoding) for part in data]
     return result
 
 
-def make_bytearray(data, encoding="ascii"):
+def make_bytearray(data, bytewidth, encoding="ascii"):
     """Convert bytes or lists of bytes into a numpy byte array.
 
     This is largely to avoid using "encode_stringarray_as_bytearray", since we don't
     want to depend on that when we should be testing it.
     So, it mostly replicates the function of that, but it does also support bytes in the
-    input, and it automatically finds + applies the maximum bytes-lengths in the input.
+    input.
     """
     # First, Convert to a (list of [lists of]..) length-1 bytes objects
-    data = _make_bytearray_inner(data, encoding)
-
-    # Numbers of bytes in the inner dimension are the lengths of bytes/strings input,
-    # so they aren't all the same.
-    # To enable array conversion, we fix that by expanding all to the max length
-
-    def get_maxlen(data):
-        # Find the maximum number of bytes in the inner dimension.
-        if not isinstance(data, list):
-            # Inner bytes object
-            assert isinstance(data, bytes)
-            longest = len(data)
-        else:
-            # We have a list: either a list of bytes, or a list of lists.
-            if len(data) == 0 or not isinstance(data[0], list):
-                # inner-most list, should contain bytes if anything
-                assert len(data) == 0 or isinstance(data[0], bytes)
-                # return n-bytes
-                longest = len(data)
-            else:
-                # list of lists: return max length of sub-lists
-                longest = max(get_maxlen(part) for part in data)
-        return longest
-
-    maxlen = get_maxlen(data)
-
-    def extend_all_to_maxlen(data, length, filler=b"\0"):
-        # Extend each "innermost" list (of single bytes) to the required length
-        if isinstance(data, list):
-            if len(data) == 0 or not isinstance(data[0], list):
-                # Pad all the inner-most lists to the required number of elements
-                n_extra = length - len(data)
-                if n_extra > 0:
-                    data = data + [filler] * n_extra
-            else:
-                data = [extend_all_to_maxlen(part, length, filler) for part in data]
-        return data
-
-    data = extend_all_to_maxlen(data, maxlen)
+    data = _make_bytearray_inner(data, bytewidth, encoding)
     # We should now be able to create an array of single bytes.
     result = np.array(data)
     assert result.dtype == "<S1"
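
Editor's note: with the above, 'make_bytearray' builds the array in a single pass, and the last dimension is always exactly 'bytewidth'. A self-contained illustration of the simplified logic (the '_inner' helper is a paraphrase for demonstration, not the module's actual code):

    import numpy as np

    def _inner(data, bytewidth, encoding):
        # Reduce strings/bytes to lists of length-1 bytes, each padded or
        # truncated to exactly 'bytewidth' elements; recurse into lists.
        if isinstance(data, str):
            data = data.encode(encoding)
        if isinstance(data, bytes):
            parts = [data[i : i + 1] for i in range(len(data))]
            return (parts + [b"\0"] * bytewidth)[:bytewidth]
        return [_inner(part, bytewidth, encoding) for part in data]

    arr = np.array(_inner(["1", "12345", "two"], 5, "ascii"))
    assert arr.shape == (3, 5)
    assert arr.dtype == np.dtype("S1")
    assert arr[1].tobytes() == b"12345"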
@@ -171,7 +135,7 @@ def test_write_strings(self, encoding, tempdir):
 
         # Close, re-open as an "ordinary" dataset, and check the raw content.
         ds_encoded.close()
-        expected_bytes = make_bytearray(writedata, write_encoding)
+        expected_bytes = make_bytearray(writedata, strlen, write_encoding)
         check_raw_content(path, "vxs", expected_bytes)
 
         # Check also that the "_Encoding" property is as expected
@@ -183,22 +147,24 @@ def test_scalar(self, tempdir):
         # Like 'test_write_strings', but the variable has *only* the string dimension.
         path = tempdir / "test_writestrings_scalar.nc"
 
-        ds_encoded = make_encoded_dataset(path, strlen=5)
+        strlen = 5
+        ds_encoded = make_encoded_dataset(path, strlen=strlen)
         v = ds_encoded.createVariable("v0_scalar", "S1", ("strlen",))
 
         # Checks that we *can* write a string
         v[:] = np.array("stuff", dtype=str)
 
         # Close, re-open as an "ordinary" dataset, and check the raw content.
         ds_encoded.close()
-        expected_bytes = make_bytearray(b"stuff")
+        expected_bytes = make_bytearray(b"stuff", strlen)
         check_raw_content(path, "v0_scalar", expected_bytes)
 
     def test_multidim(self, tempdir):
         # Like 'test_write_strings', but the variable has additional dimensions.
         path = tempdir / "test_writestrings_multidim.nc"
 
-        ds_encoded = make_encoded_dataset(path, strlen=5)
+        strlen = 5
+        ds_encoded = make_encoded_dataset(path, strlen=strlen)
         ds_encoded.createDimension("y", 2)
         v = ds_encoded.createVariable(
             "vyxn",
@@ -219,7 +185,7 @@ def test_multidim(self, tempdir):
 
         # Close, re-open as an "ordinary" dataset, and check the raw content.
         ds_encoded.close()
-        expected_bytes = make_bytearray(test_data)
+        expected_bytes = make_bytearray(test_data, strlen)
         check_raw_content(path, "vyxn", expected_bytes)
 
     def test_write_encoding_failure(self, tempdir):
@@ -236,16 +202,18 @@ def test_write_encoding_failure(self, tempdir):
     def test_overlength(self, tempdir):
         # Check expected behaviour with over-length data
         path = tempdir / "test_writestrings_overlength.nc"
-        ds = make_encoded_dataset(path, strlen=5, encoding="ascii")
+        strlen = 5
+        ds = make_encoded_dataset(path, strlen=strlen, encoding="ascii")
         v = ds.variables["vxs"]
         v[:] = ["1", "123456789", "two"]
-        expected_bytes = make_bytearray(["1", "12345", "two"])
+        expected_bytes = make_bytearray(["1", "12345", "two"], strlen)
         check_raw_content(path, "vxs", expected_bytes)
 
     def test_overlength_splitcoding(self, tempdir):
         # Check expected behaviour when non-ascii multibyte coding gets truncated
         path = tempdir / "test_writestrings_overlength_splitcoding.nc"
-        ds = make_encoded_dataset(path, strlen=5, encoding="utf-8")
+        strlen = 5
+        ds = make_encoded_dataset(path, strlen=strlen, encoding="utf-8")
         v = ds.variables["vxs"]
         v[:] = ["1", "1234ü", "two"]
         # This creates a problem: it won't read back
@@ -263,7 +231,7 @@ def test_overlength_splitcoding(self, tempdir):
             b"1234\xc3",  # NOTE: truncated encoding
             b"two",
         ]
-        expected_bytearray = make_bytearray(expected_bytes)
+        expected_bytearray = make_bytearray(expected_bytes, strlen)
         check_raw_content(path, "vxs", expected_bytearray)
 
 
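Editor's note on the b"1234\xc3" expectation above: 'ü' is U+00FC, which utf-8 encodes as the two bytes 0xC3 0xBC, so truncating the encoded string to the 5-byte 'strlen' splits the character in two. A quick check of the arithmetic:

    encoded = "1234ü".encode("utf-8")
    assert encoded == b"1234\xc3\xbc"
    # Truncation leaves a dangling utf-8 lead byte, which is exactly why
    # the data "won't read back" as a decodable string.
    assert encoded[:5] == b"1234\xc3"
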
@@ -272,9 +240,9 @@ class TestWriteChars:
     def test_write_chars(self, tempdir, write_form):
         encoding = "utf-8"
         write_strings = samples_3_nonascii
-        write_bytes = make_bytearray(write_strings, encoding=encoding)
+        strlen = strings_maxbytes(write_strings, encoding)
+        write_bytes = make_bytearray(write_strings, strlen, encoding=encoding)
         # NOTE: 'flexi' form util decides the width needs to be 7 !!
-        strlen = write_bytes.shape[-1]
         path = tempdir / f"test_writechars_{write_form}.nc"
         ds = make_encoded_dataset(path, encoding=encoding, strlen=strlen)
         v = ds.variables["vxs"]
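
Editor's note: 'strings_maxbytes' is presumably a helper defined elsewhere in the test module, returning the longest *encoded* length of the inputs (non-ascii characters occupy several bytes in utf-8, so character counts are not enough). A plausible sketch, offered as an assumption only:

    def strings_maxbytes(strings, encoding):
        # Hypothetical reconstruction: widest string measured in encoded
        # bytes, not in characters.
        return max(len(s.encode(encoding)) for s in strings)

    # e.g. a 5-character string containing 'ü' needs 6 bytes in utf-8:
    assert strings_maxbytes(["abc", "abcdü"], "utf-8") == 6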
