Skip to content

Commit 8596c44

Browse files
committed
refactor blob.py
1 parent c0f9af4 commit 8596c44

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

datajoint/blob.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
3838
)
3939
)
4040

41-
# Matlab numeric codes
42-
matlab_scalar_mapping = {
41+
scalar_codes = {
4342
np.dtype("bool"): 3, # LOGICAL
4443
np.dtype("c"): 4, # CHAR
4544
np.dtype("O"): 5, # VOID
@@ -53,10 +52,16 @@
5352
np.dtype("uint32"): 13, # UINT32
5453
np.dtype("int64"): 14, # INT64
5554
np.dtype("uint64"): 15, # UINT64
55+
np.dtype(
56+
"<M8[us]"
57+
): 50, # Datetime[us], skipped to 50 to accommodate more matlab types
5658
}
5759

58-
dtype_list = list(scalar_type.values())
59-
type_names = list(scalar_type)
60+
# Reverse lookup: numeric type code -> numpy dtype (inverse of scalar_codes)
61+
scalar_code_lookup = dict((v, k) for k, v in scalar_codes.items())
62+
# Reverse lookup: numpy dtype -> scalar type name (inverse of scalar_type)
63+
scalar_name_lookup = dict((v, k) for k, v in scalar_type.items())
64+
6065

6166
compression = {b"ZL123\0": zlib.decompress}
6267

@@ -231,14 +236,18 @@ def read_array(self):
231236
shape = self.read_value(count=n_dims)
232237
n_elem = np.prod(shape, dtype=int)
233238
dtype_id, is_complex = self.read_value("uint32", 2)
234-
dtype = dtype_list[dtype_id]
235239

236-
if type_names[dtype_id] == "VOID":
240+
# Get dtype from type id
241+
dtype = scalar_code_lookup[dtype_id]
242+
243+
# VOID marks an object array: every element is read as its own nested blob
244+
if scalar_name_lookup[dtype] == "VOID":
237245
data = np.array(
238246
list(self.read_blob(self.read_value()) for _ in range(n_elem)),
239247
dtype=np.dtype("O"),
240248
)
241-
elif type_names[dtype_id] == "CHAR":
249+
# Check if name is char
250+
elif scalar_name_lookup[dtype] == "CHAR":
242251
# compensate for MATLAB packing of char arrays
243252
data = self.read_value(dtype, count=2 * n_elem)
244253
data = data[::2].astype("U1")
@@ -271,26 +280,24 @@ def pack_array(self, array):
271280
is_complex = np.iscomplexobj(array)
272281
if is_complex:
273282
array, imaginary = np.real(array), np.imag(array)
274-
type_id = (
275-
matlab_scalar_mapping[np.dtype("O")]
276-
if array.dtype not in matlab_scalar_mapping
277-
else (
278-
matlab_scalar_mapping[array.dtype]
279-
if array.dtype.char != "U"
280-
else matlab_scalar_mapping[np.dtype("O")]
281-
)
282-
)
283-
if dtype_list[type_id] is None:
284-
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)
283+
try:
284+
type_id = scalar_codes[array.dtype]
285+
except KeyError:
286+
if array.dtype.char == "U":
287+
type_id = scalar_codes[np.dtype("O")]
288+
pass
289+
else:
290+
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)
285291

286292
blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes()
287-
if type_names[type_id] == "VOID": # array of dtype('O')
293+
# Object arrays (dtype('O')) and unicode string arrays (dtype char "U") are packed element by element
294+
if array.dtype.char == "U" or scalar_name_lookup[array.dtype] == "VOID":
288295
blob += b"".join(
289296
len_u64(it) + it
290297
for it in (self.pack_blob(e) for e in array.flatten(order="F"))
291298
)
292299
self.set_dj0() # not supported by original mym
293-
elif type_names[type_id] == "CHAR": # array of dtype('c')
300+
elif scalar_name_lookup[array.dtype] == "CHAR": # array of dtype('c')
294301
blob += (
295302
array.view(np.uint8).astype(np.uint16).tobytes()
296303
) # convert to 16-bit chars for MATLAB

0 commit comments

Comments
 (0)