Skip to content

Commit aeee3f9

Browse files
committed
improve handling of numpy scalars in numpy 2.0
1 parent 0ae6e76 commit aeee3f9

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

neo/rawio/plexonrawio.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ def _parse_header(self):
209209
for index, pos in enumerate(positions):
210210
bl_header = data[pos : pos + 16].view(DataBlockHeader)[0]
211211

212+
# To avoid overflow errors when doing arithmetic operations on numpy scalars
213+
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
214+
bl_header = {key: np_scalar_to_python_scalar(bl_header[key]) for key in bl_header.dtype.names}
215+
212216
timestamp = bl_header["UpperByteOf5ByteTimestamp"] * 2**32 + bl_header["TimeStamp"]
213217
n1 = bl_header["NumberOfWaveforms"]
214218
n2 = bl_header["NumberOfWordsInWaveform"]
@@ -253,21 +257,26 @@ def _parse_header(self):
253257
else:
254258
chan_loop = range(nb_sig_chan)
255259
for chan_index in chan_loop:
256-
h = slowChannelHeaders[chan_index]
257-
name = h["Name"].decode("utf8")
258-
chan_id = h["Channel"]
260+
channel_headers = slowChannelHeaders[chan_index]
261+
262+
# To avoid overflow errors when doing arithmetic operations on numpy scalars
263+
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
264+
channel_headers = {key: np_scalar_to_python_scalar(channel_headers[key]) for key in channel_headers.dtype.names}
265+
266+
name = channel_headers["Name"].decode("utf8")
267+
chan_id = channel_headers["Channel"]
259268
length = self._data_blocks[5][chan_id]["size"].sum() // 2
260269
if length == 0:
261270
continue # channel not added
262-
source_id.append(h["SrcId"])
271+
source_id.append(channel_headers["SrcId"])
263272
channel_num_samples.append(length)
264-
sampling_rate = float(h["ADFreq"])
273+
sampling_rate = float(channel_headers["ADFreq"])
265274
sig_dtype = "int16"
266275
units = "" # I don't know units
267276
if global_header["Version"] in [100, 101]:
268-
gain = 5000.0 / (2048 * h["Gain"] * 1000.0)
277+
gain = 5000.0 / (2048 * channel_headers["Gain"] * 1000.0)
269278
elif global_header["Version"] in [102]:
270-
gain = 5000.0 / (2048 * h["Gain"] * h["PreampGain"])
279+
gain = 5000.0 / (2048 * channel_headers["Gain"] * channel_headers["PreampGain"])
271280
elif global_header["Version"] >= 103:
272281
gain = global_header["SlowMaxMagnitudeMV"] / (
273282
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
@@ -574,7 +583,7 @@ def read_as_dict(fid, dtype, offset=None):
574583
v = v.replace("\x03", "")
575584
v = v.replace("\x00", "")
576585

577-
info[k] = v
586+
info[k] = v.item() if isinstance(v, np.generic) else v
578587
return info
579588

580589

0 commit comments

Comments
 (0)