Skip to content

Commit 4008f07

Browse files
committed
fix test bugs
1 parent c877600 commit 4008f07

File tree

5 files changed

+29
-23
lines changed

5 files changed

+29
-23
lines changed

.github/workflows/Windows_CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ jobs:
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
3131
python -m pip install numpy>=1.21.0
32-
python -m pip install "jaxlib==0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
33-
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
32+
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
33+
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
3434
python -m pip install -r requirements-dev.txt
3535
python -m pip install tqdm brainpylib
3636
pip uninstall brainpy -y

brainpy/analysis/highdim/tests/test_slow_points.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,24 @@ def ds2(s2, t, s1, coh=0.5, mu=20.):
102102

103103
def step(s):
104104
return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])])
105-
105+
106+
rng = bm.random.RandomState(123)
106107
finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
107-
finder.find_fps_with_opt_solver(bm.random.random((100, 2)))
108+
finder.find_fps_with_opt_solver(rng.random((100, 2)))
108109
bm.clear_buffer_memory()
109110

110111
def test_opt_solver_for_ds1(self):
111112
hh = HH(1)
112113
finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike])
114+
rng = bm.random.RandomState(123)
113115

114116
with self.assertRaises(ValueError):
115-
finder.find_fps_with_opt_solver(bm.random.random((100, 4)))
117+
finder.find_fps_with_opt_solver(rng.random((100, 4)))
116118

117-
finder.find_fps_with_opt_solver({'V': bm.random.random((100, 1)),
118-
'm': bm.random.random((100, 1)),
119-
'h': bm.random.random((100, 1)),
120-
'n': bm.random.random((100, 1))})
119+
finder.find_fps_with_opt_solver({'V': rng.random((100, 1)),
120+
'm': rng.random((100, 1)),
121+
'h': rng.random((100, 1)),
122+
'n': rng.random((100, 1))})
121123
bm.clear_buffer_memory()
122124

123125
def test_gd_method_for_func1(self):
@@ -149,21 +151,23 @@ def ds2(s2, t, s1, coh=0.5, mu=20.):
149151
def step(s):
150152
return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])])
151153

154+
rng = bm.random.RandomState(123)
152155
finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS)
153-
finder.find_fps_with_gd_method(bm.random.random((100, 2)), num_opt=100)
156+
finder.find_fps_with_gd_method(rng.random((100, 2)), num_opt=100)
154157
bm.clear_buffer_memory()
155158

156159
def test_gd_method_for_func2(self):
157160
hh = HH(1)
158161
finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike])
159-
162+
rng = bm.random.RandomState(123)
163+
160164
with self.assertRaises(ValueError):
161-
finder.find_fps_with_opt_solver(bm.random.random((100, 4)))
165+
finder.find_fps_with_opt_solver(rng.random((100, 4)))
162166

163-
finder.find_fps_with_gd_method({'V': bm.random.random((100, 1)),
164-
'm': bm.random.random((100, 1)),
165-
'h': bm.random.random((100, 1)),
166-
'n': bm.random.random((100, 1))},
167+
finder.find_fps_with_gd_method({'V': rng.random((100, 1)),
168+
'm': rng.random((100, 1)),
169+
'h': rng.random((100, 1)),
170+
'n': rng.random((100, 1))},
167171
num_opt=100)
168172
bm.clear_buffer_memory()
169173

brainpy/math/object_transform/tests/test_controls.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,16 @@ def update(x):
6262
return update
6363

6464
bp.math.random.seed()
65-
_v1 = bm.random.normal(size=10)
66-
_v2 = bm.random.random(size=10)
65+
_v1 = bm.Variable(bm.random.normal(size=10))
66+
_v2 = bm.as_variable(bm.random.random(size=10))
6767
_xs = bm.random.uniform(size=(4, 10))
6868

6969
scan_f = bm.make_loop(make_node(_v1, _v2),
7070
dyn_vars=(_v1, _v2),
7171
out_vars=(_v1,),
7272
has_return=True)
73-
with self.assertRaises(bp.errors.MathError):
74-
outs, returns = scan_f(_xs)
73+
# with self.assertRaises(bp.errors.MathError):
74+
outs, returns = scan_f(_xs)
7575

7676
@parameterized.named_parameters(
7777
{"testcase_name": "_jit_scan={}_jit_f={}_unroll={}".format(jit_scan, jit_f, unroll),

brainpy/math/operators/tests/test_differential_spike.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
def test_sp_sigmoid_grad():
1111
f_grad = bm.vector_grad(lambda a: bm.spike_with_sigmoid_grad(a, 1.))
12-
x = bm.random.random(10) - 0.5
12+
rng = bm.random.RandomState()
13+
x = rng.random(10) - 0.5
1314
print(f_grad(x))
1415

1516

brainpy/math/operators/tests/test_op_register.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ def __init__(self, syn_class, scale=1.0, method='exp_auto', ):
9696
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
9797
self.E = bp.neurons.LIF(num_exc, **pars, method=method)
9898
self.I = bp.neurons.LIF(num_inh, **pars, method=method)
99-
self.E.V[:] = bp.math.random.randn(num_exc) * 2 - 55.
100-
self.I.V[:] = bp.math.random.randn(num_inh) * 2 - 55.
99+
rng = bm.random.RandomState()
100+
self.E.V[:] = rng.randn(num_exc) * 2 - 55.
101+
self.I.V[:] = rng.randn(num_inh) * 2 - 55.
101102

102103
# synapses
103104
we = 0.6 / scale # excitatory synaptic weight (voltage)

0 commit comments

Comments
 (0)