Skip to content

Commit 9f552c8

Browse files
committed
fix bugs in math.operator
1 parent 3399cd9 commit 9f552c8

File tree

2 files changed

+51
-9
lines changed

2 files changed

+51
-9
lines changed

brainpy/math/operators/op_register.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,11 @@ def __init__(
7777
gpu_func = None
7878

7979
# register OP
80-
_, self.op = brainpylib.register_op(self.name,
81-
cpu_func=cpu_func,
82-
gpu_func=gpu_func,
83-
out_shapes=eval_shape,
84-
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
85-
return_primitive=True)
80+
self.op = brainpylib.register_op(self.name,
81+
cpu_func=cpu_func,
82+
gpu_func=gpu_func,
83+
out_shapes=eval_shape,
84+
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
8685

8786
def __call__(self, *args, **kwargs):
8887
args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
@@ -131,6 +130,7 @@ def register_op(
131130

132131
def fixed_op(*inputs):
133132
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
134-
return f(*inputs)
133+
res = f.bind(*inputs)
134+
return res[0] if len(res) == 1 else res
135135

136136
return fixed_op

brainpy/math/operators/tests/test_op_register.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def event_sum_op(outs, ins):
2424

2525

2626
event_sum = bm.register_op(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
27+
event_sum2 = bm.XLACustomOp(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
2728
event_sum = bm.jit(event_sum)
2829

2930

@@ -83,6 +84,36 @@ def update(self, tdi):
8384
self.post.input += self.g * (self.E - self.post.V)
8485

8586

87+
class ExponentialSyn3(bp.dyn.TwoEndConn):
88+
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
89+
method='exp_auto'):
90+
super(ExponentialSyn3, self).__init__(pre=pre, post=post, conn=conn)
91+
self.check_pre_attrs('spike')
92+
self.check_post_attrs('input', 'V')
93+
94+
# parameters
95+
self.E = E
96+
self.tau = tau
97+
self.delay = delay
98+
self.g_max = g_max
99+
self.pre2post = self.conn.require('pre2post')
100+
101+
# variables
102+
self.g = bm.Variable(bm.zeros(self.post.num))
103+
104+
# function
105+
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)
106+
107+
def update(self, tdi):
108+
self.g.value = self.integral(self.g, tdi['t'], tdi['dt'])
109+
# Customized operator
110+
# ------------------------------------------------------------------------------------------------------------
111+
post_val = bm.zeros(self.post.num)
112+
self.g += event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max)
113+
# ------------------------------------------------------------------------------------------------------------
114+
self.post.input += self.g * (self.E - self.post.V)
115+
116+
86117
class EINet(bp.dyn.Network):
87118
def __init__(self, syn_class, scale=1.0, method='exp_auto', ):
88119
super(EINet, self).__init__()
@@ -111,7 +142,7 @@ def __init__(self, syn_class, scale=1.0, method='exp_auto', ):
111142
class TestOpRegister(unittest.TestCase):
112143
def test_op(self):
113144

114-
fig, gs = bp.visualize.get_figure(1, 2, 4, 5)
145+
fig, gs = bp.visualize.get_figure(1, 3, 4, 5)
115146

116147
net = EINet(ExponentialSyn, scale=1., method='euler')
117148
runner = bp.dyn.DSRunner(
@@ -133,5 +164,16 @@ def test_op(self):
133164
t, _ = runner2.run(100., eval_time=True)
134165
print(t)
135166
ax = fig.add_subplot(gs[0, 1])
136-
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax, show=True)
167+
bp.visualize.raster_plot(runner2.mon.ts, runner2.mon['E.spike'], ax=ax)
168+
169+
net3 = EINet(ExponentialSyn3, scale=1., method='euler')
170+
runner3 = bp.dyn.DSRunner(
171+
net3,
172+
inputs=[(net3.E.input, 20.), (net3.I.input, 20.)],
173+
monitors={'E.spike': net3.E.spike},
174+
)
175+
t, _ = runner3.run(100., eval_time=True)
176+
print(t)
177+
ax = fig.add_subplot(gs[0, 2])
178+
bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)
137179
plt.close()

0 commit comments

Comments
 (0)