Skip to content

Commit c6ee565

Browse files
committed
fix test
1 parent 289f27a commit c6ee565

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

brainpy/math/operators/tests/test_op_register.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212

1313
def abs_eval(events, indices, indptr, post_val, values):
14-
return post_val
14+
return [post_val]
1515

1616

1717
def event_sum_op(outs, ins):
1818
events, indices, indptr, post, values = ins
1919
v = values[()]
20+
outs, = outs
2021
outs.fill(0)
2122
for i in range(len(events)):
2223
if events[i]:
@@ -79,7 +80,8 @@ def update(self, tdi):
7980
# Customized operator
8081
# ------------------------------------------------------------------------------------------------------------
8182
post_val = bm.zeros(self.post.num)
82-
self.g += event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max)
83+
r = event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max)
84+
self.g += r[0]
8385
# ------------------------------------------------------------------------------------------------------------
8486
self.post.input += self.g * (self.E - self.post.V)
8587

@@ -112,12 +114,13 @@ def __init__(self, syn_class, scale=1.0, method='exp_auto', ):
112114
class TestOpRegister(unittest.TestCase):
113115
def test_op(self):
114116

115-
fig, gs = bp.visualize.get_figure(1, 3, 4, 5)
117+
fig, gs = bp.visualize.get_figure(1, 1, 4, 5)
116118

117119
net = EINet(ExponentialSyn, scale=1., method='euler')
118120
runner = bp.dyn.DSRunner(
119121
net,
120-
inputs=[(net.E.input, 20.), (net.I.input, 20.)],
122+
inputs=[(net.E.input, 20.),
123+
(net.I.input, 20.)],
121124
monitors={'E.spike': net.E.spike},
122125
)
123126
t, _ = runner.run(100., eval_time=True)
@@ -128,12 +131,13 @@ def test_op(self):
128131
net3 = EINet(ExponentialSyn3, scale=1., method='euler')
129132
runner3 = bp.dyn.DSRunner(
130133
net3,
131-
inputs=[(net3.E.input, 20.), (net3.I.input, 20.)],
134+
inputs=[(net3.E.input, 20.),
135+
(net3.I.input, 20.)],
132136
monitors={'E.spike': net3.E.spike},
133137
)
134138
t, _ = runner3.run(100., eval_time=True)
135139
print(t)
136-
ax = fig.add_subplot(gs[0, 2])
140+
ax = fig.add_subplot(gs[0, 1])
137141
bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)
138142

139143
# clear

brainpy/math/tests/test_delay_vars.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def test_dim3(self):
6969
def test1(self):
7070
print()
7171
delay = bm.TimeDelay(jnp.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t)
72-
self.assertTrue(bm.allclose(delay(-0.2), bm.ones(3) * 0.2))
72+
self.assertTrue(bm.allclose(delay(-0.2), bm.ones(3) * -0.2))
7373
delay = bm.TimeDelay(jnp.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t)
74-
self.assertTrue(bm.allclose(delay(-0.6), bm.zeros((3, 2)) * -0.6))
74+
self.assertTrue(bm.allclose(delay(-0.6), bm.ones((3, 2)) * -0.6))
7575
delay = bm.TimeDelay(jnp.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t)
76-
self.assertTrue(bm.allclose(delay(-0.8), jnp.zeros(3, 2, 1) * -0.8))
76+
self.assertTrue(bm.allclose(delay(-0.8), jnp.ones((3, 2, 1)) * -0.8))
7777

7878
def test_current_time2(self):
7979
print()

0 commit comments

Comments
 (0)