Skip to content

Commit b6a721c

Browse files
authored
Merge pull request #1613 from h-mayorquin/fix_overflow_2.0
Fix overflow of Plexon in numpy 2.0
2 parents 49534ce + 6f56bd0 commit b6a721c

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

neo/rawio/plexonrawio.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ def _parse_header(self):
209209
for index, pos in enumerate(positions):
210210
bl_header = data[pos : pos + 16].view(DataBlockHeader)[0]
211211

212+
# To avoid overflow errors when doing arithmetic operations on numpy scalars
213+
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
214+
bl_header = {key: np_scalar_to_python_scalar(bl_header[key]) for key in bl_header.dtype.names}
215+
212216
current_upper_byte_of_5_byte_timestamp = int(bl_header["UpperByteOf5ByteTimestamp"])
213217
current_bl_timestamp = int(bl_header["TimeStamp"])
214218
timestamp = current_upper_byte_of_5_byte_timestamp * 2**32 + current_bl_timestamp
@@ -255,24 +259,29 @@ def _parse_header(self):
255259
else:
256260
chan_loop = range(nb_sig_chan)
257261
for chan_index in chan_loop:
258-
h = slowChannelHeaders[chan_index]
259-
name = h["Name"].decode("utf8")
260-
chan_id = h["Channel"]
262+
slow_channel_headers = slowChannelHeaders[chan_index]
263+
264+
# To avoid overflow errors when doing arithmetic operations on numpy scalars
265+
np_scalar_to_python_scalar = lambda x: x.item() if isinstance(x, np.generic) else x
266+
slow_channel_headers = {key: np_scalar_to_python_scalar(slow_channel_headers[key]) for key in slow_channel_headers.dtype.names}
267+
268+
name = slow_channel_headers["Name"].decode("utf8")
269+
chan_id = slow_channel_headers["Channel"]
261270
length = self._data_blocks[5][chan_id]["size"].sum() // 2
262271
if length == 0:
263272
continue # channel not added
264-
source_id.append(h["SrcId"])
273+
source_id.append(slow_channel_headers["SrcId"])
265274
channel_num_samples.append(length)
266-
sampling_rate = float(h["ADFreq"])
275+
sampling_rate = float(slow_channel_headers["ADFreq"])
267276
sig_dtype = "int16"
268277
units = "" # I don't know units
269278
if global_header["Version"] in [100, 101]:
270-
gain = 5000.0 / (2048 * h["Gain"] * 1000.0)
279+
gain = 5000.0 / (2048 * slow_channel_headers["Gain"] * 1000.0)
271280
elif global_header["Version"] in [102]:
272-
gain = 5000.0 / (2048 * h["Gain"] * h["PreampGain"])
281+
gain = 5000.0 / (2048 * slow_channel_headers["Gain"] * slow_channel_headers["PreampGain"])
273282
elif global_header["Version"] >= 103:
274283
gain = global_header["SlowMaxMagnitudeMV"] / (
275-
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * h["Gain"] * h["PreampGain"]
284+
0.5 * (2 ** global_header["BitsPerSpikeSample"]) * slow_channel_headers["Gain"] * slow_channel_headers["PreampGain"]
276285
)
277286
offset = 0.0
278287

@@ -358,21 +367,21 @@ def _parse_header(self):
358367
unit_loop = enumerate(self.internal_unit_ids)
359368

360369
for unit_index, (chan_id, unit_id) in unit_loop:
361-
c = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
362-
h = dspChannelHeaders[c]
370+
channel_index = np.nonzero(dspChannelHeaders["Channel"] == chan_id)[0][0]
371+
dsp_channel_headers = dspChannelHeaders[channel_index]
363372

364-
name = h["Name"].decode("utf8")
373+
name = dsp_channel_headers["Name"].decode("utf8")
365374
_id = f"ch{chan_id}#{unit_id}"
366375
wf_units = ""
367376
if global_header["Version"] < 103:
368-
wf_gain = 3000.0 / (2048 * h["Gain"] * 1000.0)
377+
wf_gain = 3000.0 / (2048 * dsp_channel_headers["Gain"] * 1000.0)
369378
elif 103 <= global_header["Version"] < 105:
370379
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
371-
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * 1000.0
380+
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * 1000.0
372381
)
373382
elif global_header["Version"] >= 105:
374383
wf_gain = global_header["SpikeMaxMagnitudeMV"] / (
375-
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * h["Gain"] * global_header["SpikePreAmpGain"]
384+
0.5 * 2.0 ** (global_header["BitsPerSpikeSample"]) * dsp_channel_headers["Gain"] * global_header["SpikePreAmpGain"]
376385
)
377386
wf_offset = 0.0
378387
wf_left_sweep = -1 # DONT KNOWN
@@ -576,7 +585,7 @@ def read_as_dict(fid, dtype, offset=None):
576585
v = v.replace("\x03", "")
577586
v = v.replace("\x00", "")
578587

579-
info[k] = v
588+
info[k] = v.item() if isinstance(v, np.generic) else v
580589
return info
581590

582591

0 commit comments

Comments
 (0)