|
38 | 38 | )
|
39 | 39 | )
|
40 | 40 |
|
41 |
| -# Matlab numeric codes |
42 |
| -matlab_scalar_mapping = { |
| 41 | +scalar_codes = { |
43 | 42 | np.dtype("bool"): 3, # LOGICAL
|
44 | 43 | np.dtype("c"): 4, # CHAR
|
45 | 44 | np.dtype("O"): 5, # VOID
|
|
53 | 52 | np.dtype("uint32"): 13, # UINT32
|
54 | 53 | np.dtype("int64"): 14, # INT64
|
55 | 54 | np.dtype("uint64"): 15, # UINT64
|
| 55 | + np.dtype( |
| 56 | + "<M8[us]" |
| 57 | + ): 50, # Datetime[us], skipped to 50 to accommodate more matlab types |
56 | 58 | }
|
57 | 59 |
|
58 |
| -dtype_list = list(scalar_type.values()) |
59 |
| -type_names = list(scalar_type) |
| 60 | +# Lookup dict for quickly getting a scalar type from its code |
| 61 | +scalar_code_lookup = dict((v, k) for k, v in scalar_codes.items()) |
| 62 | +# Lookup dict for quickly getting a scalar name from its type |
| 63 | +scalar_name_lookup = dict((v, k) for k, v in scalar_type.items()) |
| 64 | + |
60 | 65 |
|
61 | 66 | compression = {b"ZL123\0": zlib.decompress}
|
62 | 67 |
|
@@ -231,14 +236,18 @@ def read_array(self):
|
231 | 236 | shape = self.read_value(count=n_dims)
|
232 | 237 | n_elem = np.prod(shape, dtype=int)
|
233 | 238 | dtype_id, is_complex = self.read_value("uint32", 2)
|
234 |
| - dtype = dtype_list[dtype_id] |
235 | 239 |
|
236 |
| - if type_names[dtype_id] == "VOID": |
| 240 | + # Get dtype from type id |
| 241 | + dtype = scalar_code_lookup[dtype_id] |
| 242 | + |
| 243 | + # Check if name is void |
| 244 | + if scalar_name_lookup[dtype] == "VOID": |
237 | 245 | data = np.array(
|
238 | 246 | list(self.read_blob(self.read_value()) for _ in range(n_elem)),
|
239 | 247 | dtype=np.dtype("O"),
|
240 | 248 | )
|
241 |
| - elif type_names[dtype_id] == "CHAR": |
| 249 | + # Check if name is char |
| 250 | + elif scalar_name_lookup[dtype] == "CHAR": |
242 | 251 | # compensate for MATLAB packing of char arrays
|
243 | 252 | data = self.read_value(dtype, count=2 * n_elem)
|
244 | 253 | data = data[::2].astype("U1")
|
@@ -271,26 +280,24 @@ def pack_array(self, array):
|
271 | 280 | is_complex = np.iscomplexobj(array)
|
272 | 281 | if is_complex:
|
273 | 282 | array, imaginary = np.real(array), np.imag(array)
|
274 |
| - type_id = ( |
275 |
| - matlab_scalar_mapping[np.dtype("O")] |
276 |
| - if array.dtype not in matlab_scalar_mapping |
277 |
| - else ( |
278 |
| - matlab_scalar_mapping[array.dtype] |
279 |
| - if array.dtype.char != "U" |
280 |
| - else matlab_scalar_mapping[np.dtype("O")] |
281 |
| - ) |
282 |
| - ) |
283 |
| - if dtype_list[type_id] is None: |
284 |
| - raise DataJointError("Type %s is ambiguous or unknown" % array.dtype) |
| 283 | + try: |
| 284 | + type_id = scalar_codes[array.dtype] |
| 285 | + except KeyError: |
| 286 | + if array.dtype.char == "U": |
| 287 | + type_id = scalar_codes[np.dtype("O")] |
| 288 | + pass |
| 289 | + else: |
| 290 | + raise DataJointError("Type %s is ambiguous or unknown" % array.dtype) |
285 | 291 |
|
286 | 292 | blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes()
|
287 |
| - if type_names[type_id] == "VOID": # array of dtype('O') |
| 293 | + # array of dtype('O'), U is for unicode string |
| 294 | + if array.dtype.char == "U" or scalar_name_lookup[array.dtype] == "VOID": |
288 | 295 | blob += b"".join(
|
289 | 296 | len_u64(it) + it
|
290 | 297 | for it in (self.pack_blob(e) for e in array.flatten(order="F"))
|
291 | 298 | )
|
292 | 299 | self.set_dj0() # not supported by original mym
|
293 |
| - elif type_names[type_id] == "CHAR": # array of dtype('c') |
| 300 | + elif scalar_name_lookup[array.dtype] == "CHAR": # array of dtype('c') |
294 | 301 | blob += (
|
295 | 302 | array.view(np.uint8).astype(np.uint16).tobytes()
|
296 | 303 | ) # convert to 16-bit chars for MATLAB
|
|
0 commit comments