@@ -72,7 +72,7 @@ def check_raw_content(path, varname, expected_byte_array):
7272 )
7373
7474
75- def _make_bytearray_inner (data , encoding ):
75+ def _make_bytearray_inner (data , bytewidth , encoding ):
7676 # Convert a (list of [lists of..]) strings or bytes to a
7777 # (list of [lists of..]) length-1 bytes with an extra dimension.
7878 if isinstance (data , str ):
@@ -81,61 +81,25 @@ def _make_bytearray_inner(data, encoding):
8181 if isinstance (data , bytes ):
8282 # iterate over bytes to get a sequence of length-1 bytes (what np.array wants)
8383 result = [data [i : i + 1 ] for i in range (len (data ))]
84+ # pad or truncate everything to the required bytewidth
85+ result = (result + [b"\0 " ] * bytewidth )[:bytewidth ]
8486 else :
8587 # If not string/bytes, expect the input to be a list.
8688 # N.B. the recursion is inefficient, but we don't care about that here
87- result = [_make_bytearray_inner (part , encoding ) for part in data ]
89+ result = [_make_bytearray_inner (part , bytewidth , encoding ) for part in data ]
8890 return result
8991
9092
91- def make_bytearray (data , encoding = "ascii" ):
93+ def make_bytearray (data , bytewidth , encoding = "ascii" ):
9294 """Convert bytes or lists of bytes into a numpy byte array.
9395
9496 This is largely to avoid using "encode_stringarray_as_bytearray", since we don't
9597 want to depend on that when we should be testing it.
9698 So, it mostly replicates the function of that, but it does also support bytes in the
97- input, and it automatically finds + applies the maximum bytes-lengths in the input .
99+ input, and it pads (with null bytes) or truncates each item to exactly 'bytewidth' bytes.
98100 """
99101 # First, convert to a (list of [lists of]..) length-1 bytes objects
100- data = _make_bytearray_inner (data , encoding )
101-
102- # Numbers of bytes in the inner dimension are the lengths of bytes/strings input,
103- # so they aren't all the same.
104- # To enable array conversion, we fix that by expanding all to the max length
105-
106- def get_maxlen (data ):
107- # Find the maximum number of bytes in the inner dimension.
108- if not isinstance (data , list ):
109- # Inner bytes object
110- assert isinstance (data , bytes )
111- longest = len (data )
112- else :
113- # We have a list: either a list of bytes, or a list of lists.
114- if len (data ) == 0 or not isinstance (data [0 ], list ):
115- # inner-most list, should contain bytes if anything
116- assert len (data ) == 0 or isinstance (data [0 ], bytes )
117- # return n-bytes
118- longest = len (data )
119- else :
120- # list of lists: return max length of sub-lists
121- longest = max (get_maxlen (part ) for part in data )
122- return longest
123-
124- maxlen = get_maxlen (data )
125-
126- def extend_all_to_maxlen (data , length , filler = b"\0 " ):
127- # Extend each "innermost" list (of single bytes) to the required length
128- if isinstance (data , list ):
129- if len (data ) == 0 or not isinstance (data [0 ], list ):
130- # Pad all the inner-most lists to the required number of elements
131- n_extra = length - len (data )
132- if n_extra > 0 :
133- data = data + [filler ] * n_extra
134- else :
135- data = [extend_all_to_maxlen (part , length , filler ) for part in data ]
136- return data
137-
138- data = extend_all_to_maxlen (data , maxlen )
102+ data = _make_bytearray_inner (data , bytewidth , encoding )
139103 # We should now be able to create an array of single bytes.
140104 result = np .array (data )
141105 assert result .dtype == "<S1"
@@ -171,7 +135,7 @@ def test_write_strings(self, encoding, tempdir):
171135
172136 # Close, re-open as an "ordinary" dataset, and check the raw content.
173137 ds_encoded .close ()
174- expected_bytes = make_bytearray (writedata , write_encoding )
138+ expected_bytes = make_bytearray (writedata , strlen , write_encoding )
175139 check_raw_content (path , "vxs" , expected_bytes )
176140
177141 # Check also that the "_Encoding" property is as expected
@@ -183,22 +147,24 @@ def test_scalar(self, tempdir):
183147 # Like 'test_write_strings', but the variable has *only* the string dimension.
184148 path = tempdir / "test_writestrings_scalar.nc"
185149
186- ds_encoded = make_encoded_dataset (path , strlen = 5 )
150+ strlen = 5
151+ ds_encoded = make_encoded_dataset (path , strlen = strlen )
187152 v = ds_encoded .createVariable ("v0_scalar" , "S1" , ("strlen" ,))
188153
189154 # Checks that we *can* write a string
190155 v [:] = np .array ("stuff" , dtype = str )
191156
192157 # Close, re-open as an "ordinary" dataset, and check the raw content.
193158 ds_encoded .close ()
194- expected_bytes = make_bytearray (b"stuff" )
159+ expected_bytes = make_bytearray (b"stuff" , strlen )
195160 check_raw_content (path , "v0_scalar" , expected_bytes )
196161
197162 def test_multidim (self , tempdir ):
198163 # Like 'test_write_strings', but the variable has additional dimensions.
199164 path = tempdir / "test_writestrings_multidim.nc"
200165
201- ds_encoded = make_encoded_dataset (path , strlen = 5 )
166+ strlen = 5
167+ ds_encoded = make_encoded_dataset (path , strlen = strlen )
202168 ds_encoded .createDimension ("y" , 2 )
203169 v = ds_encoded .createVariable (
204170 "vyxn" ,
@@ -219,7 +185,7 @@ def test_multidim(self, tempdir):
219185
220186 # Close, re-open as an "ordinary" dataset, and check the raw content.
221187 ds_encoded .close ()
222- expected_bytes = make_bytearray (test_data )
188+ expected_bytes = make_bytearray (test_data , strlen )
223189 check_raw_content (path , "vyxn" , expected_bytes )
224190
225191 def test_write_encoding_failure (self , tempdir ):
@@ -236,16 +202,18 @@ def test_write_encoding_failure(self, tempdir):
236202 def test_overlength (self , tempdir ):
237203 # Check expected behaviour with over-length data
238204 path = tempdir / "test_writestrings_overlength.nc"
239- ds = make_encoded_dataset (path , strlen = 5 , encoding = "ascii" )
205+ strlen = 5
206+ ds = make_encoded_dataset (path , strlen = strlen , encoding = "ascii" )
240207 v = ds .variables ["vxs" ]
241208 v [:] = ["1" , "123456789" , "two" ]
242- expected_bytes = make_bytearray (["1" , "12345" , "two" ])
209+ expected_bytes = make_bytearray (["1" , "12345" , "two" ], strlen )
243210 check_raw_content (path , "vxs" , expected_bytes )
244211
245212 def test_overlength_splitcoding (self , tempdir ):
246213 # Check expected behaviour when non-ascii multibyte coding gets truncated
247214 path = tempdir / "test_writestrings_overlength_splitcoding.nc"
248- ds = make_encoded_dataset (path , strlen = 5 , encoding = "utf-8" )
215+ strlen = 5
216+ ds = make_encoded_dataset (path , strlen = strlen , encoding = "utf-8" )
249217 v = ds .variables ["vxs" ]
250218 v [:] = ["1" , "1234ü" , "two" ]
251219 # This creates a problem: it won't read back
@@ -263,7 +231,7 @@ def test_overlength_splitcoding(self, tempdir):
263231 b"1234\xc3 " , # NOTE: truncated encoding
264232 b"two" ,
265233 ]
266- expected_bytearray = make_bytearray (expected_bytes )
234+ expected_bytearray = make_bytearray (expected_bytes , strlen )
267235 check_raw_content (path , "vxs" , expected_bytearray )
268236
269237
@@ -272,9 +240,9 @@ class TestWriteChars:
272240 def test_write_chars (self , tempdir , write_form ):
273241 encoding = "utf-8"
274242 write_strings = samples_3_nonascii
275- write_bytes = make_bytearray (write_strings , encoding = encoding )
243+ strlen = strings_maxbytes (write_strings , encoding )
244+ write_bytes = make_bytearray (write_strings , strlen , encoding = encoding )
276245 # NOTE: 'flexi' form util decides the width needs to be 7 !!
277- strlen = write_bytes .shape [- 1 ]
278246 path = tempdir / f"test_writechars_{ write_form } .nc"
279247 ds = make_encoded_dataset (path , encoding = encoding , strlen = strlen )
280248 v = ds .variables ["vxs" ]
0 commit comments