Skip to content

Commit 42b0537

Browse files
committed
refactor
1 parent bea1781 commit 42b0537

File tree

2 files changed

+47
-55
lines changed

2 files changed

+47
-55
lines changed

datajoint/blob.py

Lines changed: 46 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -14,53 +14,42 @@
1414
from .settings import config
1515

1616

17-
scalar_type = dict(
18-
(
19-
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
20-
("UNKNOWN", None),
21-
("CELL", None),
22-
("STRUCT", None),
23-
("LOGICAL", np.dtype("bool")),
24-
("CHAR", np.dtype("c")),
25-
("VOID", np.dtype("O")),
26-
("DOUBLE", np.dtype("float64")),
27-
("SINGLE", np.dtype("float32")),
28-
("INT8", np.dtype("int8")),
29-
("UINT8", np.dtype("uint8")),
30-
("INT16", np.dtype("int16")),
31-
("UINT16", np.dtype("uint16")),
32-
("INT32", np.dtype("int32")),
33-
("UINT32", np.dtype("uint32")),
34-
("INT64", np.dtype("int64")),
35-
("UINT64", np.dtype("uint64")),
36-
("FUNCTION", None),
37-
("DATETIME64", np.dtype("<M8[us]")),
38-
)
39-
)
40-
41-
scalar_codes = {
42-
np.dtype("bool"): 3, # LOGICAL
43-
np.dtype("c"): 4, # CHAR
44-
np.dtype("O"): 5, # VOID
45-
np.dtype("float64"): 6, # DOUBLE
46-
np.dtype("float32"): 7, # SINGLE
47-
np.dtype("int8"): 8, # INT8
48-
np.dtype("uint8"): 9, # UINT8
49-
np.dtype("int16"): 10, # INT16
50-
np.dtype("uint16"): 11, # UINT16
51-
np.dtype("int32"): 12, # INT32
52-
np.dtype("uint32"): 13, # UINT32
53-
np.dtype("int64"): 14, # INT64
54-
np.dtype("uint64"): 15, # UINT64
55-
np.dtype(
56-
"<M8[us]"
57-
): 50, # Datetime[us], skipped to 50 to accommodate more matlab types
17+
deserialize_lookup = {
18+
0: {"dtype": None, "scalar_type": "UNKNOWN"},
19+
1: {"dtype": None, "scalar_type": "CELL"},
20+
2: {"dtype": None, "scalar_type": "STRUCT"},
21+
3: {"dtype": np.dtype("bool"), "scalar_type": "LOGICAL"},
22+
4: {"dtype": np.dtype("c"), "scalar_type": "CHAR"},
23+
5: {"dtype": np.dtype("O"), "scalar_type": "VOID"},
24+
6: {"dtype": np.dtype("float64"), "scalar_type": "DOUBLE"},
25+
7: {"dtype": np.dtype("float32"), "scalar_type": "SINGLE"},
26+
8: {"dtype": np.dtype("int8"), "scalar_type": "INT8"},
27+
9: {"dtype": np.dtype("uint8"), "scalar_type": "UINT8"},
28+
10: {"dtype": np.dtype("int16"), "scalar_type": "INT16"},
29+
11: {"dtype": np.dtype("uint16"), "scalar_type": "UINT16"},
30+
12: {"dtype": np.dtype("int32"), "scalar_type": "INT32"},
31+
13: {"dtype": np.dtype("uint32"), "scalar_type": "UINT32"},
32+
14: {"dtype": np.dtype("int64"), "scalar_type": "INT64"},
33+
15: {"dtype": np.dtype("uint64"), "scalar_type": "UINT64"},
34+
16: {"dtype": None, "scalar_type": "FUNCTION"},
35+
128: {"dtype": np.dtype("<M8[Y]"), "scalar_type": "DATETIME64[Y]"},
36+
129: {"dtype": np.dtype("<M8[M]"), "scalar_type": "DATETIME64[M]"},
37+
130: {"dtype": np.dtype("<M8[W]"), "scalar_type": "DATETIME64[W]"},
38+
131: {"dtype": np.dtype("<M8[D]"), "scalar_type": "DATETIME64[D]"},
39+
132: {"dtype": np.dtype("<M8[h]"), "scalar_type": "DATETIME64[h]"},
40+
133: {"dtype": np.dtype("<M8[m]"), "scalar_type": "DATETIME64[m]"},
41+
134: {"dtype": np.dtype("<M8[s]"), "scalar_type": "DATETIME64[s]"},
42+
135: {"dtype": np.dtype("<M8[ms]"), "scalar_type": "DATETIME64[ms]"},
43+
136: {"dtype": np.dtype("<M8[us]"), "scalar_type": "DATETIME64[us]"},
44+
137: {"dtype": np.dtype("<M8[ps]"), "scalar_type": "DATETIME64[ps]"},
45+
138: {"dtype": np.dtype("<M8[fs]"), "scalar_type": "DATETIME64[fs]"},
46+
139: {"dtype": np.dtype("<M8[as]"), "scalar_type": "DATETIME64[as]"},
47+
}
48+
serialize_lookup = {
49+
v["dtype"]: {"type_id": k, "scalar_type": v["scalar_type"]}
50+
for k, v in deserialize_lookup.items()
51+
if v["dtype"] is not None
5852
}
59-
60-
# Lookup dict for quickly getting a scalar type from its code
61-
scalar_code_lookup = dict((v, k) for k, v in scalar_codes.items())
62-
# Lookup dict for quickly getting a scalar name from its type
63-
scalar_name_lookup = dict((v, k) for k, v in scalar_type.items())
6453

6554

6655
compression = {b"ZL123\0": zlib.decompress}
@@ -235,16 +224,16 @@ def read_array(self):
235224
dtype_id, is_complex = self.read_value("uint32", 2)
236225

237226
# Get dtype from type id
238-
dtype = scalar_code_lookup[dtype_id]
227+
dtype = deserialize_lookup[dtype_id]["dtype"]
239228

240229
# Check if name is void
241-
if scalar_name_lookup[dtype] == "VOID":
230+
if deserialize_lookup[dtype_id]["scalar_type"] == "VOID":
242231
data = np.array(
243232
list(self.read_blob(self.read_value()) for _ in range(n_elem)),
244233
dtype=np.dtype("O"),
245234
)
246235
# Check if name is char
247-
elif scalar_name_lookup[dtype] == "CHAR":
236+
elif deserialize_lookup[dtype_id]["scalar_type"] == "CHAR":
248237
# compensate for MATLAB packing of char arrays
249238
data = self.read_value(dtype, count=2 * n_elem)
250239
data = data[::2].astype("U1")
@@ -267,7 +256,6 @@ def pack_array(self, array):
267256
Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0.
268257
"""
269258
if "datetime64" in array.dtype.name:
270-
array = array.astype("datetime64[us]")
271259
self.set_dj0()
272260
blob = (
273261
b"A"
@@ -278,23 +266,27 @@ def pack_array(self, array):
278266
if is_complex:
279267
array, imaginary = np.real(array), np.imag(array)
280268
try:
281-
type_id = scalar_codes[array.dtype]
269+
type_id = serialize_lookup[array.dtype]["type_id"]
282270
except KeyError:
283271
if array.dtype.char == "U":
284-
type_id = scalar_codes[np.dtype("O")]
272+
type_id = serialize_lookup[np.dtype("O")]["type_id"]
285273
pass
286274
else:
287275
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)
288276

289277
blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes()
290278
# array of dtype('O'), U is for unicode string
291-
if array.dtype.char == "U" or scalar_name_lookup[array.dtype] == "VOID":
279+
if (
280+
array.dtype.char == "U"
281+
or serialize_lookup[array.dtype]["scalar_type"] == "VOID"
282+
):
292283
blob += b"".join(
293284
len_u64(it) + it
294285
for it in (self.pack_blob(e) for e in array.flatten(order="F"))
295286
)
296287
self.set_dj0() # not supported by original mym
297-
elif scalar_name_lookup[array.dtype] == "CHAR": # array of dtype('c')
288+
# array of dtype('c')
289+
elif serialize_lookup[array.dtype]["scalar_type"] == "CHAR":
298290
blob += (
299291
array.view(np.uint8).astype(np.uint16).tobytes()
300292
) # convert to 16-bit chars for MATLAB

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pyparsing
44
ipython
55
pandas
66
tqdm
7-
networkx<2.8.3
7+
networkx<=2.8.2
88
pydot
99
minio>=7.0.0
1010
matplotlib

0 commit comments

Comments
 (0)