Skip to content

Commit 9cc1a53

Browse files
committed
read values dynamically based on flag
1 parent 7900ef5 commit 9cc1a53

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

datajoint/blob.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(self, squeeze=False):
6969
self._blob = None
7070
self._pos = 0
7171
self.protocol = None
72+
self.is_32_bit = False
7273

7374
def set_dj0(self):
7475
if not config.get('enable_python_native_blobs'):
@@ -96,7 +97,7 @@ def unpack(self, blob):
9697
pass # assume uncompressed but could be unrecognized compression
9798
else:
9899
self._pos += len(prefix)
99-
blob_size = self.read_value('uint64')
100+
blob_size = self.read_value()
100101
blob = compression[prefix](self._blob[self._pos:])
101102
assert len(blob) == blob_size
102103
self._blob = blob
@@ -191,8 +192,8 @@ def pack_blob(self, obj):
191192
raise DataJointError("Packing object of type %s currently not supported!" % type(obj))
192193

193194
def read_array(self):
194-
n_dims = int(self.read_value('uint64'))
195-
shape = self.read_value('uint64', count=n_dims)
195+
n_dims = int(self.read_value())
196+
shape = self.read_value(count=n_dims)
196197
n_elem = np.prod(shape, dtype=int)
197198
dtype_id, is_complex = self.read_value('uint32', 2)
198199
dtype = dtype_list[dtype_id]
@@ -365,7 +366,7 @@ def read_struct(self):
365366
return np.array(None) # empty array
366367
field_names = [self.read_zero_terminated_string() for _ in range(n_fields)]
367368
raw_data = [
368-
tuple(self.read_blob(n_bytes=int(self.read_value('uint64'))) for _ in range(n_fields))
369+
tuple(self.read_blob(n_bytes=int(self.read_value())) for _ in range(n_fields))
369370
for __ in range(n_elem)]
370371
data = np.array(raw_data, dtype=list(zip(field_names, repeat(object))))
371372
return self.squeeze(data.reshape(shape, order="F"), convert_to_scalar=False).view(MatStruct)
@@ -431,7 +432,9 @@ def read_zero_terminated_string(self):
431432
self._pos = target + 1
432433
return data
433434

434-
def read_value(self, dtype='uint64', count=1):
435+
def read_value(self, dtype=None, count=1):
436+
if dtype is None:
437+
dtype = 'uint32' if self.is_32_bit else 'uint64'
435438
data = np.frombuffer(self._blob, dtype=dtype, count=count, offset=self._pos)
436439
self._pos += data.dtype.itemsize * data.size
437440
return data[0] if count == 1 else data

0 commit comments

Comments
 (0)