@@ -104,33 +104,8 @@ def __init__(
104104 self .check_post_attrs ('refractory' )
105105
106106 # connections and weights
107- self .conn_type = conn_type
108- if conn_type not in ['sparse' , 'dense' ]:
109- raise ValueError (f'"conn_type" must be in "sparse" and "dense", but we got { conn_type } ' )
110- if self .conn is None :
111- raise ValueError (f'Must provide "conn" when initialize the model { self .name } ' )
112- if isinstance (self .conn , One2One ):
113- self .weights = init_param (weights , (self .pre .num ,), allow_none = False )
114- self .weight_type = 'heter' if bm .size (self .weights ) != 1 else 'homo'
115- elif isinstance (self .conn , All2All ):
116- self .weights = init_param (weights , (self .pre .num , self .post .num ), allow_none = False )
117- if bm .size (self .weights ) != 1 :
118- self .weight_type = 'heter'
119- bm .fill_diagonal (self .weights , 0. )
120- else :
121- self .weight_type = 'homo'
122- else :
123- if conn_type == 'sparse' :
124- self .pre2post = self .conn .require ('pre2post' )
125- self .weights = init_param (weights , self .pre2post [1 ].shape , allow_none = False )
126- self .weight_type = 'heter' if bm .size (self .weights ) != 1 else 'homo'
127- elif conn_type == 'dense' :
128- self .weights = init_param (weights , (self .pre .num , self .post .num ), allow_none = False )
129- self .weight_type = 'heter' if bm .size (self .weights ) != 1 else 'homo'
130- if self .weight_type == 'homo' :
131- self .conn_mat = self .conn .require ('conn_mat' )
132- else :
133- raise ValueError (f'Unknown connection type: { conn_type } ' )
107+ self .weights = weights
108+ self .pre2post = self .conn .require ('pre2post' )
134109
135110 # variables
136111 self .delay_step = self .register_delay (f"{ self .pre .name } .spike" ,
@@ -144,33 +119,7 @@ def update(self, t, dt):
144119 # delays
145120 pre_spike = self .get_delay_data (f"{ self .pre .name } .spike" , delay_step = self .delay_step )
146121
147- # post values
148- assert self .weight_type in ['homo' , 'heter' ]
149- assert self .conn_type in ['sparse' , 'dense' ]
150- if isinstance (self .conn , All2All ):
151- pre_spike = pre_spike .astype (bm .float_ )
152- if self .weight_type == 'homo' :
153- post_vs = bm .sum (pre_spike )
154- if not self .conn .include_self :
155- post_vs = post_vs - pre_spike
156- post_vs *= self .weights
157- else :
158- post_vs = pre_spike @ self .weights
159- elif isinstance (self .conn , One2One ):
160- pre_spike = pre_spike .astype (bm .float_ )
161- post_vs = pre_spike * self .weights
162- else :
163- if self .conn_type == 'sparse' :
164- post_vs = bm .pre2post_event_sum (pre_spike ,
165- self .pre2post ,
166- self .post .num ,
167- self .weights )
168- else :
169- pre_spike = pre_spike .astype (bm .float_ )
170- if self .weight_type == 'homo' :
171- post_vs = self .weights * (pre_spike @ self .conn_mat )
172- else :
173- post_vs = pre_spike @ self .weights
122+ post_vs = bm .pre2post_event_sum (pre_spike , self .pre2post , self .post .num , self .weights )
174123
175124 # update outputs
176125 target = getattr (self .post , self .post_key )
@@ -299,33 +248,8 @@ def __init__(
299248 f'But we got { self .tau } ' )
300249
301250 # connections and weights
302- self .conn_type = conn_type
303- if conn_type not in ['sparse' , 'dense' ]:
304- raise ValueError (f'"conn_type" must be in "sparse" and "dense", but we got { conn_type } ' )
305- if self .conn is None :
306- raise ValueError (f'Must provide "conn" when initialize the model { self .name } ' )
307- if isinstance (self .conn , One2One ):
308- self .g_max = init_param (g_max , (self .pre .num ,), allow_none = False )
309- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
310- elif isinstance (self .conn , All2All ):
311- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
312- if bm .size (self .g_max ) != 1 :
313- self .weight_type = 'heter'
314- bm .fill_diagonal (self .g_max , 0. )
315- else :
316- self .weight_type = 'homo'
317- else :
318- if conn_type == 'sparse' :
319- self .pre2post = self .conn .require ('pre2post' )
320- self .g_max = init_param (g_max , self .pre2post [1 ].shape , allow_none = False )
321- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
322- elif conn_type == 'dense' :
323- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
324- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
325- if self .weight_type == 'homo' :
326- self .conn_mat = self .conn .require ('conn_mat' )
327- else :
328- raise ValueError (f'Unknown connection type: { conn_type } ' )
251+ self .pre2post = self .conn .require ('pre2post' )
252+ self .g_max = init_param (g_max , self .pre2post [1 ].shape , allow_none = False )
329253
330254 # variables
331255 self .g = bm .Variable (bm .zeros (self .post .num ))
@@ -344,33 +268,10 @@ def update(self, t, dt):
344268 pre_spike = self .get_delay_data (f"{ self .pre .name } .spike" , self .delay_step )
345269
346270 # post values
347- assert self .weight_type in ['homo' , 'heter' ]
348- assert self .conn_type in ['sparse' , 'dense' ]
349- if isinstance (self .conn , All2All ):
350- pre_spike = pre_spike .astype (bm .float_ )
351- if self .weight_type == 'homo' :
352- post_vs = bm .sum (pre_spike )
353- if not self .conn .include_self :
354- post_vs = post_vs - pre_spike
355- post_vs = self .g_max * post_vs
356- else :
357- post_vs = pre_spike @ self .g_max
358- elif isinstance (self .conn , One2One ):
359- pre_spike = pre_spike .astype (bm .float_ )
360- post_vs = pre_spike * self .g_max
361- else :
362- if self .conn_type == 'sparse' :
363- post_vs = bm .pre2post_event_sum (pre_spike ,
364- self .pre2post ,
365- self .post .num ,
366- self .g_max )
367- else :
368- pre_spike = pre_spike .astype (bm .float_ )
369- if self .weight_type == 'homo' :
370- post_vs = self .g_max * (pre_spike @ self .conn_mat )
371- else :
372- post_vs = pre_spike @ self .g_max
373-
271+ post_vs = bm .pre2post_event_sum (pre_spike ,
272+ self .pre2post ,
273+ self .post .num ,
274+ self .g_max )
374275 # updates
375276 self .g .value = self .integral (self .g .value , t , dt = dt ) + post_vs
376277 self .post .input += self .output (self .g )
@@ -619,33 +520,8 @@ def __init__(
619520 f'But we got { self .tau_decay } ' )
620521
621522 # connections
622- self .conn_type = conn_type
623- if conn_type not in ['sparse' , 'dense' ]:
624- raise ValueError (f'"conn_type" must be in "sparse" and "dense", but we got { conn_type } ' )
625- if self .conn is None :
626- raise ValueError (f'Must provide "conn" when initialize the model { self .name } ' )
627- if isinstance (self .conn , One2One ):
628- self .g_max = init_param (g_max , (self .pre .num ,), allow_none = False )
629- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
630- elif isinstance (self .conn , All2All ):
631- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
632- if bm .size (self .g_max ) != 1 :
633- self .weight_type = 'heter'
634- bm .fill_diagonal (self .g_max , 0. )
635- else :
636- self .weight_type = 'homo'
637- else :
638- if conn_type == 'sparse' :
639- self .pre_ids , self .post_ids = self .conn .require ('pre_ids' , 'post_ids' )
640- self .g_max = init_param (g_max , self .post_ids .shape , allow_none = False )
641- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
642- elif conn_type == 'dense' :
643- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
644- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
645- if self .weight_type == 'homo' :
646- self .conn_mat = self .conn .require ('conn_mat' )
647- else :
648- raise ValueError (f'Unknown connection type: { conn_type } ' )
523+ self .pre_ids , self .post_ids = self .conn .require ('pre_ids' , 'post_ids' )
524+ self .g_max = init_param (g_max , self .post_ids .shape , allow_none = False )
649525
650526 # variables
651527 self .h = bm .Variable (bm .zeros (self .pre .num ))
@@ -674,26 +550,7 @@ def update(self, t, dt):
674550 self .h += pre_spike
675551
676552 # post-synaptic values
677- assert self .weight_type in ['homo' , 'heter' ]
678- assert self .conn_type in ['sparse' , 'dense' ]
679- if isinstance (self .conn , All2All ):
680- if self .weight_type == 'homo' :
681- post_vs = bm .sum (self .g )
682- if not self .conn .include_self :
683- post_vs = post_vs - self .g
684- post_vs = self .g_max * post_vs
685- else :
686- post_vs = self .g @ self .g_max
687- elif isinstance (self .conn , One2One ):
688- post_vs = self .g_max * self .g
689- else :
690- if self .conn_type == 'sparse' :
691- post_vs = bm .pre2post_sum (self .g , self .post .num , self .post_ids , self .pre_ids )
692- else :
693- if self .weight_type == 'homo' :
694- post_vs = (self .g_max * self .g ) @ self .conn_mat
695- else :
696- post_vs = self .g @ self .g_max
553+ post_vs = bm .pre2post_sum (self .g , self .post .num , self .post_ids , self .pre_ids )
697554
698555 # output
699556 self .post .input += self .output (post_vs )
@@ -1199,33 +1056,8 @@ def __init__(
11991056 raise ValueError (f'"tau_rise" must be a scalar or a tensor with size of 1. But we got { tau_rise } ' )
12001057
12011058 # connections and weights
1202- self .conn_type = conn_type
1203- if conn_type not in ['sparse' , 'dense' ]:
1204- raise ValueError (f'"conn_type" must be in "sparse" and "dense", but we got { conn_type } ' )
1205- if self .conn is None :
1206- raise ValueError (f'Must provide "conn" when initialize the model { self .name } ' )
1207- if isinstance (self .conn , One2One ):
1208- self .g_max = init_param (g_max , (self .pre .num ,), allow_none = False )
1209- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
1210- elif isinstance (self .conn , All2All ):
1211- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
1212- if bm .size (self .g_max ) != 1 :
1213- self .weight_type = 'heter'
1214- bm .fill_diagonal (self .g_max , 0. )
1215- else :
1216- self .weight_type = 'homo'
1217- else :
1218- if conn_type == 'sparse' :
1219- self .pre_ids , self .post_ids = self .conn .require ('pre_ids' , 'post_ids' )
1220- self .g_max = init_param (g_max , self .post_ids .shape , allow_none = False )
1221- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
1222- elif conn_type == 'dense' :
1223- self .g_max = init_param (g_max , (self .pre .num , self .post .num ), allow_none = False )
1224- self .weight_type = 'heter' if bm .size (self .g_max ) != 1 else 'homo'
1225- if self .weight_type == 'homo' :
1226- self .conn_mat = self .conn .require ('conn_mat' )
1227- else :
1228- raise ValueError (f'Unknown connection type: { conn_type } ' )
1059+ self .pre_ids , self .post_ids = self .conn .require ('pre_ids' , 'post_ids' )
1060+ self .g_max = init_param (g_max , self .post_ids .shape , allow_none = False )
12291061
12301062 # variables
12311063 self .g = bm .Variable (bm .zeros (self .pre .num , dtype = bm .float_ ))
@@ -1254,26 +1086,7 @@ def update(self, t, dt):
12541086 self .x += delayed_pre_spike
12551087
12561088 # post-synaptic value
1257- assert self .weight_type in ['homo' , 'heter' ]
1258- assert self .conn_type in ['sparse' , 'dense' ]
1259- if isinstance (self .conn , All2All ):
1260- if self .weight_type == 'homo' :
1261- post_g = bm .sum (self .g )
1262- if not self .conn .include_self :
1263- post_g = post_g - self .g
1264- post_g = post_g * self .g_max
1265- else :
1266- post_g = self .g @ self .g_max
1267- elif isinstance (self .conn , One2One ):
1268- post_g = self .g_max * self .g
1269- else :
1270- if self .conn_type == 'sparse' :
1271- post_g = bm .pre2post_sum (self .g , self .post .num , self .post_ids , self .pre_ids )
1272- else :
1273- if self .weight_type == 'homo' :
1274- post_g = (self .g_max * self .g ) @ self .conn_mat
1275- else :
1276- post_g = self .g @ self .g_max
1089+ post_g = bm .pre2post_sum (self .g , self .post .num , self .post_ids , self .pre_ids )
12771090
12781091 # output
12791092 g_inf = 1 + self .cc_Mg / self .beta * bm .exp (- self .alpha * self .post .V )
0 commit comments