Skip to content

Commit c2cc50b

Browse files
committed
update matlab type codes to be explicit
1 parent fbe9eaf commit c2cc50b

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

datajoint/blob.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,23 @@
3838
)
3939
)
4040

41-
rev_class_id = {dtype: i for i, dtype in enumerate(scalar_id.values())}
41+
# Matlab numeric codes
42+
matlab_scalar_mapping = {
43+
np.dtype("bool"): 3, # LOGICAL
44+
np.dtype("c"): 4, # CHAR
45+
np.dtype("O"): 5, # VOID
46+
np.dtype("float64"): 6, # DOUBLE
47+
np.dtype("float32"): 7, # SINGLE
48+
np.dtype("int8"): 8, # INT8
49+
np.dtype("uint8"): 9, # UINT8
50+
np.dtype("int16"): 10, # INT16
51+
np.dtype("uint16"): 11, # UINT16
52+
np.dtype("int32"): 12, # INT32
53+
np.dtype("uint32"): 13, # UINT32
54+
np.dtype("int64"): 14, # INT64
55+
np.dtype("uint64"): 15, # UINT64
56+
}
57+
4258
dtype_list = list(scalar_id.values())
4359
type_names = list(scalar_id)
4460

@@ -256,9 +272,13 @@ def pack_array(self, array):
256272
if is_complex:
257273
array, imaginary = np.real(array), np.imag(array)
258274
type_id = (
259-
rev_class_id[array.dtype]
260-
if array.dtype.char != "U"
261-
else rev_class_id[np.dtype("O")]
275+
matlab_scalar_mapping[np.dtype("O")]
276+
if array.dtype not in matlab_scalar_mapping
277+
else (
278+
matlab_scalar_mapping[array.dtype]
279+
if array.dtype.char != "U"
280+
else matlab_scalar_mapping[np.dtype("O")]
281+
)
262282
)
263283
if dtype_list[type_id] is None:
264284
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)

tests/test_blob.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,13 @@ def test_datetime_serialization_speed():
233233
# np arrays of np.datetime64 types is now slower than regular arrays of datetime64
234234

235235
np_array_dt_exe_time = timeit.timeit(
236-
setup='myarr=pack(np.array([np.datetime64(f"{x}") for x in range(1900, 2000)]))',
236+
setup="myarr=pack(np.array([np.datetime64(f'{x}') for x in range(1900, 2000)]))",
237237
stmt="unpack(myarr)",
238238
number=10,
239239
globals=globals(),
240240
)
241241
python_array_dt_exe_time = timeit.timeit(
242-
setup='myarr2=pack([np.datetime64(f"{x}") for x in range(1900, 2000)])',
242+
setup="myarr2=pack([datetime.now() for x in range (1900, 2000)])",
243243
stmt="unpack(myarr2)",
244244
number=10,
245245
globals=globals(),

0 commit comments

Comments
 (0)