@@ -24,6 +24,7 @@ def event_sum_op(outs, ins):
2424
2525
2626event_sum = bm .register_op (name = 'event_sum' , cpu_func = event_sum_op , eval_shape = abs_eval )
27+ event_sum2 = bm .XLACustomOp (name = 'event_sum' , cpu_func = event_sum_op , eval_shape = abs_eval )
2728event_sum = bm .jit (event_sum )
2829
2930
@@ -83,6 +84,36 @@ def update(self, tdi):
8384 self .post .input += self .g * (self .E - self .post .V )
8485
8586
87+ class ExponentialSyn3 (bp .dyn .TwoEndConn ):
88+ def __init__ (self , pre , post , conn , g_max = 1. , delay = 0. , tau = 8.0 , E = 0. ,
89+ method = 'exp_auto' ):
90+ super (ExponentialSyn3 , self ).__init__ (pre = pre , post = post , conn = conn )
91+ self .check_pre_attrs ('spike' )
92+ self .check_post_attrs ('input' , 'V' )
93+
94+ # parameters
95+ self .E = E
96+ self .tau = tau
97+ self .delay = delay
98+ self .g_max = g_max
99+ self .pre2post = self .conn .require ('pre2post' )
100+
101+ # variables
102+ self .g = bm .Variable (bm .zeros (self .post .num ))
103+
104+ # function
105+ self .integral = bp .odeint (lambda g , t : - g / self .tau , method = method )
106+
107+ def update (self , tdi ):
108+ self .g .value = self .integral (self .g , tdi ['t' ], tdi ['dt' ])
109+ # Customized operator
110+ # ------------------------------------------------------------------------------------------------------------
111+ post_val = bm .zeros (self .post .num )
112+ self .g += event_sum2 (self .pre .spike , self .pre2post [0 ], self .pre2post [1 ], post_val , self .g_max )
113+ # ------------------------------------------------------------------------------------------------------------
114+ self .post .input += self .g * (self .E - self .post .V )
115+
116+
86117class EINet (bp .dyn .Network ):
87118 def __init__ (self , syn_class , scale = 1.0 , method = 'exp_auto' , ):
88119 super (EINet , self ).__init__ ()
@@ -111,7 +142,7 @@ def __init__(self, syn_class, scale=1.0, method='exp_auto', ):
111142class TestOpRegister (unittest .TestCase ):
112143 def test_op (self ):
113144
114- fig , gs = bp .visualize .get_figure (1 , 2 , 4 , 5 )
145+ fig , gs = bp .visualize .get_figure (1 , 3 , 4 , 5 )
115146
116147 net = EINet (ExponentialSyn , scale = 1. , method = 'euler' )
117148 runner = bp .dyn .DSRunner (
@@ -133,5 +164,16 @@ def test_op(self):
133164 t , _ = runner2 .run (100. , eval_time = True )
134165 print (t )
135166 ax = fig .add_subplot (gs [0 , 1 ])
136- bp .visualize .raster_plot (runner .mon .ts , runner .mon ['E.spike' ], ax = ax , show = True )
167+ bp .visualize .raster_plot (runner2 .mon .ts , runner2 .mon ['E.spike' ], ax = ax )
168+
169+ net3 = EINet (ExponentialSyn3 , scale = 1. , method = 'euler' )
170+ runner3 = bp .dyn .DSRunner (
171+ net3 ,
172+ inputs = [(net3 .E .input , 20. ), (net3 .I .input , 20. )],
173+ monitors = {'E.spike' : net3 .E .spike },
174+ )
175+ t , _ = runner3 .run (100. , eval_time = True )
176+ print (t )
177+ ax = fig .add_subplot (gs [0 , 2 ])
178+ bp .visualize .raster_plot (runner3 .mon .ts , runner3 .mon ['E.spike' ], ax = ax , show = True )
137179 plt .close ()
0 commit comments