Skip to content

Commit c034591

Browse files
authored
Merge pull request #1625 from alejoe91/final-2.0-fixes
Fix failing tests due to Numpy 2.0
2 parents dbe2e95 + 70d040f commit c034591

File tree

7 files changed

+22
-24
lines changed

7 files changed

+22
-24
lines changed

.github/workflows/io-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: true
2020
matrix:
21-
python-version: ['3.9', '3.12']
21+
python-version: ['3.9', '3.13']
2222
defaults:
2323
# by default run in bash mode (required for conda usage)
2424
run:

environment_testing.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,3 @@ channels:
44
dependencies:
55
- datalad
66
- pip
7-
# temporary have this here for IO testing while we decide how to deal with
8-
# external packages not 2.0 ready
9-
- numpy=1.26.4

neo/io/klustakwikio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _load_spike_times(self, fetfilename):
197197
names.append("spike_time")
198198

199199
# Load into recarray
200-
data = np.recfromtxt(fetfilename, names=names, skip_header=1, delimiter=" ")
200+
data = np.genfromtxt(fetfilename, names=names, skip_header=1, delimiter=" ")
201201

202202
# get features
203203
features = np.array([data[f"fet{n}"] for n in range(nbFeatures)])

neo/rawio/blackrockrawio.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,7 @@ def __match_nsx_and_nev_segment_ids(self, nsx_nb):
13651365

13661366
# Show warning if spikes do not fit any segment (+- 1 sampling 'tick')
13671367
# Spike should belong to segment before
1368-
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - nsx_offset - nsx_period)
1368+
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period))
13691369

13701370
if len(data[mask_outside]) > 0:
13711371
warnings.warn(f"Spikes outside any segment. Detected on segment #{i}")
@@ -1995,6 +1995,7 @@ def __get_nsx_param_variant_a(self, nsx_nb):
19951995
else:
19961996
units = "uV"
19971997

1998+
19981999
nsx_parameters = {
19992000
"nb_data_points": int(
20002001
(self.__get_file_size(filename) - bytes_in_headers)
@@ -2003,8 +2004,8 @@ def __get_nsx_param_variant_a(self, nsx_nb):
20032004
),
20042005
"labels": labels,
20052006
"units": np.array([units] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
2006-
"min_analog_val": -1 * np.array(dig_factor),
2007-
"max_analog_val": np.array(dig_factor),
2007+
"min_analog_val": -1 * np.array(dig_factor, dtype="float"),
2008+
"max_analog_val": np.array(dig_factor, dtype="float"),
20082009
"min_digital_val": np.array([-1000] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
20092010
"max_digital_val": np.array([1000] * self.__nsx_basic_header[nsx_nb]["channel_count"]),
20102011
"timestamp_resolution": 30000,

neo/test/iotest/test_asciisignalio.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ def test_skiprows(self):
166166
self.assertEqual(signal.units, pq.V)
167167
assert_array_equal(signal.times, [0.0, 1.0, 2.0, 3.0] * pq.s)
168168
assert_array_equal(signal.times.magnitude, [0.0, 1.0, 2.0, 3.0])
169-
assert_array_equal(signal[0].magnitude, -64.8)
170-
assert_array_equal(signal[1].magnitude, -64.6)
171-
assert_array_equal(signal[2].magnitude, -64.3)
172-
assert_array_equal(signal[3].magnitude, -66)
169+
assert_array_almost_equal(signal[0].magnitude, -64.8, decimal=5)
170+
assert_array_almost_equal(signal[1].magnitude, -64.6, decimal=5)
171+
assert_array_almost_equal(signal[2].magnitude, -64.3, decimal=5)
172+
assert_array_almost_equal(signal[3].magnitude, -66, decimal=5)
173173
assert_array_almost_equal(np.asarray(signal).flatten(), np.array([-64.8, -64.6, -64.3, -66]), decimal=5)
174174

175175
os.remove(filename)
@@ -195,11 +195,11 @@ def test_usecols(self):
195195
self.assertEqual(signal.units, pq.V)
196196
assert_array_equal(signal.times, [0.0, 1.0, 2.0, 3.0, 4.0] * pq.s)
197197
assert_array_equal(signal.times.magnitude, [0.0, 1.0, 2.0, 3.0, 4.0])
198-
assert_array_equal(signal[0].magnitude, 0.5)
199-
assert_array_equal(signal[1].magnitude, 0.6)
200-
assert_array_equal(signal[2].magnitude, 0.7)
201-
assert_array_equal(signal[3].magnitude, 0.8)
202-
assert_array_equal(signal[4].magnitude, 1.4)
198+
assert_array_almost_equal(signal[0].magnitude, 0.5, decimal=5)
199+
assert_array_almost_equal(signal[1].magnitude, 0.6, decimal=5)
200+
assert_array_almost_equal(signal[2].magnitude, 0.7, decimal=5)
201+
assert_array_almost_equal(signal[3].magnitude, 0.8, decimal=5)
202+
assert_array_almost_equal(signal[4].magnitude, 1.4, decimal=5)
203203
assert_array_almost_equal(np.asarray(signal).flatten(), np.array([0.5, 0.6, 0.7, 0.8, 1.4]), decimal=5)
204204

205205
os.remove(filename)

neo/test/iotest/test_axographio.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from neo.test.iotest.common_io_test import BaseTestIO
1010

1111
import numpy as np
12-
from numpy.testing import assert_equal
12+
from numpy.testing import assert_equal, assert_almost_equal
1313
import quantities as pq
1414
from neo.test.rawiotest.test_axographrawio import TestAxographRawIO
1515

@@ -35,9 +35,9 @@ def test_version_1(self):
3535
target = np.array([[-5.5078130], [-3.1171880], [+1.6640626], [+1.6640626], [+4.0546880]], dtype=np.float32)
3636
assert_equal(arr, target)
3737

38-
assert_equal(sig.t_start, 0.0005000000237487257 * pq.s)
38+
assert_almost_equal(sig.t_start, 0.0005000000237487257 * pq.s, decimal=9)
3939

40-
assert_equal(sig.sampling_period, 0.0005000010132789612 * pq.s)
40+
assert_almost_equal(sig.sampling_period, 0.0005000010132789612 * pq.s, decimal=9)
4141

4242
def test_version_2(self):
4343
"""Test reading a version 2 AxoGraph file"""
@@ -87,9 +87,9 @@ def test_version_2(self):
8787
target = np.array([[0.3125], [9.6875], [9.6875], [9.6875], [9.3750]], dtype=np.float32)
8888
assert_equal(arr, target)
8989

90-
assert_equal(sig.t_start, 0.00009999999747378752 * pq.s)
90+
assert_almost_equal(sig.t_start, 0.00009999999747378752 * pq.s, decimal=9)
9191

92-
assert_equal(sig.sampling_period, 0.00009999999747378750 * pq.s)
92+
assert_almost_equal(sig.sampling_period, 0.00009999999747378750 * pq.s, decimal=9)
9393

9494
def test_version_5(self):
9595
"""Test reading a version 5 AxoGraph file"""
@@ -169,7 +169,7 @@ def test_file_written_by_axographio_package_without_linearsequence(self):
169169

170170
assert_equal(sig.t_start, 0 * pq.s)
171171

172-
assert_equal(sig.sampling_period, 0.009999999999999787 * pq.s)
172+
assert_almost_equal(sig.sampling_period, 0.009999999999999787 * pq.s, decimal=9)
173173

174174
def test_file_with_corrupt_comment(self):
175175
"""Test reading a file with a corrupt comment"""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ docs = [
8181
"nixio",
8282
"pynwb",
8383
"igor2",
84-
"numpy<2.0" # https://github.com/NeuralEnsemble/python-neo/pull/1612
84+
"numpy>=2.0"
8585
]
8686

8787
dev = [

0 commit comments

Comments
 (0)