 from brainpy._src.context import share
 from brainpy._src.dyn import synapses
 from brainpy._src.dyn.base import NeuDyn
+from brainpy._src.dnn import linear
 from brainpy._src.dynold.synouts import MgBlock, CUBA
 from brainpy._src.initialize import Initializer, variable_
 from brainpy._src.integrators.ode.generic import odeint
+from brainpy._src.dyn.projections.aligns import _pre_delay_repr, _init_delay
 from brainpy.types import ArrayType
-from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _DelayedSyn, _init_stp
+from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre
 
 __all__ = [
   'Delta',
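
Note on the import hunk: `_DelayedSyn` and `_init_stp` go away because each synapse now composes a dynamics module from `brainpy._src.dyn.synapses` with a communication layer from `brainpy._src.dnn.linear`. A minimal sketch of that composition pattern, using what I take to be the public aliases (`bp.dnn.AllToAll` and `bp.dyn.Expon` are assumptions here, sizes and weights illustrative):

```python
# Sketch of the comm/syn composition this PR adopts; aliases and
# argument defaults are assumptions, not taken from this diff.
import brainpy as bp
import brainpy.math as bm

comm = bp.dnn.AllToAll(10, 10, 0.5)   # pre -> post communication (weights)
syn = bp.dyn.Expon(10, tau=8.0)       # exponential conductance dynamics

spikes = (bm.random.rand(10) < 0.2).astype(bm.float_)  # illustrative input
g = syn(comm(spikes))                 # project first, then filter in time
```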
@@ -100,12 +102,12 @@ def __init__(
       stop_spike_gradient: bool = False,
   ):
     super().__init__(name=name,
-                     pre=pre,
-                     post=post,
-                     conn=conn,
-                     output=output,
-                     stp=stp,
-                     mode=mode)
+                     pre=pre,
+                     post=post,
+                     conn=conn,
+                     output=output,
+                     stp=stp,
+                     mode=mode)
 
     # parameters
     self.stop_spike_gradient = stop_spike_gradient
@@ -298,29 +300,40 @@ def __init__(
                      mode=mode)
     # parameters
     self.stop_spike_gradient = stop_spike_gradient
-    self.comp_method = comp_method
-    self.tau = tau
-    if bm.size(self.tau) != 1:
-      raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
 
-    # connections and weights
-    self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
+    # synapse dynamics
+    self.syn = synapses.Expon(post.varshape, tau=tau, method=method)
+
+    # Projection
+    if isinstance(conn, All2All):
+      self.comm = linear.AllToAll(pre.num, post.num, g_max)
+    elif isinstance(conn, One2One):
+      assert post.num == pre.num
+      self.comm = linear.OneToOne(pre.num, g_max)
+    else:
+      if comp_method == 'dense':
+        self.comm = linear.MaskedLinear(conn, g_max)
+      elif comp_method == 'sparse':
+        if self.stp is None:
+          self.comm = linear.EventCSRLinear(conn, g_max)
+        else:
+          self.comm = linear.CSRLinear(conn, g_max)
+      else:
+        raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')
 
     # variables
-    self.g = variable_(bm.zeros, self.post.num, self.mode)
-    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
+    self.g = self.syn.g
 
-    # function
-    self.integral = odeint(lambda g, t: -g / self.tau, method=method)
+    # delay
+    self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
 
   def reset_state(self, batch_size=None):
-    self.g.value = variable_(bm.zeros, self.post.num, batch_size)
+    self.syn.reset_state(batch_size)
     self.output.reset_state(batch_size)
-    if self.stp is not None: self.stp.reset_state(batch_size)
+    if self.stp is not None:
+      self.stp.reset_state(batch_size)
 
   def update(self, pre_spike=None):
-    t, dt = share['t'], share['dt']
-
     # delays
     if pre_spike is None:
       pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
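
The dispatch above replaces the old `comp_method` branching in `update()`: all-to-all and one-to-one connections get dedicated operators, and everything else becomes either a dense masked matmul or a CSR product. The event-driven kernel (`EventCSRLinear`) is used only when `stp is None`, because with STP the spikes have already been rescaled to floats by `self.stp(pre_spike)` (see the `update()` hunk below). A hedged sketch of building each fallback operator by hand (constructor calls mirror this diff; the public aliases are assumptions):

```python
# Building the three fallback communication operators directly.
import brainpy as bp

conn = bp.connect.FixedProb(0.1)(100, 100)  # connector with sizes bound

dense = bp.dnn.MaskedLinear(conn, 0.5)      # comp_method='dense'
event = bp.dnn.EventCSRLinear(conn, 0.5)    # comp_method='sparse', stp is None
plain = bp.dnn.CSRLinear(conn, 0.5)         # comp_method='sparse', with STP
```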
@@ -332,52 +345,13 @@ def update(self, pre_spike=None):
     self.output.update()
     if self.stp is not None:
       self.stp.update(pre_spike)
+      pre_spike = self.stp(pre_spike)
 
     # post values
-    if isinstance(self.conn, All2All):
-      syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-      if self.stp is not None: syn_value = self.stp(syn_value)
-      post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
-    elif isinstance(self.conn, One2One):
-      syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-      if self.stp is not None: syn_value = self.stp(syn_value)
-      post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
-    else:
-      if self.comp_method == 'sparse':
-        f = lambda s: bm.event.csrmv(self.g_max,
-                                     self.conn_mask[0],
-                                     self.conn_mask[1],
-                                     s,
-                                     shape=(self.pre.num, self.post.num),
-                                     transpose=True)
-        if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
-        post_vs = f(pre_spike)
-        # if not isinstance(self.stp, _NullSynSTP):
-        #   raise NotImplementedError()
-      else:
-        syn_value = bm.asarray(pre_spike, dtype=bm.float_)
-        if self.stp is not None:
-          syn_value = self.stp(syn_value)
-        post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
-    # updates
-    self.g.value = self.integral(self.g.value, t, dt) + post_vs
+    g = self.syn(self.comm(pre_spike))
 
     # output
-    return self.output(self.g)
-
-
-class _DelayedDualExp(_DelayedSyn):
-  not_desc_params = ('master', 'mode')
-
-  def __init__(self, size, keep_size, mode, tau_decay, tau_rise, method, master, stp=None):
-    syn = synapses.DualExpon(size,
-                             keep_size,
-                             mode=mode,
-                             tau_decay=tau_decay,
-                             tau_rise=tau_rise,
-                             method=method)
-    stp = _init_stp(stp, master)
-    super().__init__(syn, stp)
+    return self.output(g)
 
 
 class DualExponential(_TwoEndConnAlignPre):
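
The deleted integrator `odeint(lambda g, t: -g / self.tau)` together with `self.g.value = self.integral(...) + post_vs` solved dg/dt = -g/tau and then injected the projected spikes; `synapses.Expon` now owns exactly that state. For reference, a plain Euler step of the same dynamics (standalone sketch, parameter values illustrative):

```python
# One explicit Euler step of the dynamics the removed code integrated:
# dg/dt = -g / tau, followed by adding the projected presynaptic input.
def expon_step(g, post_current, tau=8.0, dt=0.1):
  """Decay the conductance toward zero, then add the new synaptic drive."""
  return g + dt * (-g / tau) + post_current
```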
@@ -507,14 +481,12 @@ def __init__(
       raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
                        f'But we got {self.tau_decay}')
 
-    syn = _DelayedDualExp.desc(pre.size,
-                               pre.keep_size,
-                               mode=mode,
-                               tau_decay=tau_decay,
-                               tau_rise=tau_rise,
-                               method=method,
-                               stp=stp,
-                               master=self)
+    syn = synapses.DualExpon(pre.size,
+                             pre.keep_size,
+                             mode=mode,
+                             tau_decay=tau_decay,
+                             tau_rise=tau_rise,
+                             method=method, )
 
     super().__init__(pre=pre,
                      post=post,
@@ -530,7 +502,6 @@ def __init__(
 
     self.check_post_attrs('input')
     # copy the references
-    syn = self.post.before_updates[self.proj._syn_id].syn.syn
     self.g = syn.g
     self.h = syn.h
 
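
With the `before_updates` indirection gone, `self.g` and `self.h` alias the locally built `synapses.DualExpon` state directly. Those two variables follow standard dual-exponential kinetics; a hedged Euler sketch (parameter values illustrative, not from this diff):

```python
# Euler sketch of the standard dual-exponential kinetics behind DualExpon:
# h carries the fast rise, g the slow decay of the conductance.
def dual_expon_step(g, h, spike, tau_decay=10.0, tau_rise=1.0, dt=0.1):
  h = h + dt * (-h / tau_rise) + spike
  g = g + dt * (-g / tau_decay + h)
  return g, h
```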
@@ -652,21 +623,6 @@ def __init__(
                      stop_spike_gradient=stop_spike_gradient)
 
 
-class _DelayedNMDA(_DelayedSyn):
-  not_desc_params = ('master', 'stp', 'mode')
-
-  def __init__(self, size, keep_size, mode, a, tau_decay, tau_rise, method, master, stp=None):
-    syn = synapses.NMDA(size,
-                        keep_size,
-                        mode=mode,
-                        a=a,
-                        tau_decay=tau_decay,
-                        tau_rise=tau_rise,
-                        method=method)
-    stp = _init_stp(stp, master)
-    super().__init__(syn, stp)
-
-
 class NMDA(_TwoEndConnAlignPre):
   r"""NMDA synapse model.
 
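
`_DelayedNMDA` disappears for the same reason as `_DelayedDualExp`: the `NMDA` class below now instantiates `synapses.NMDA` directly (next hunk). Its `(g, x)` state pair follows the usual saturating NMDA kinetics; an Euler sketch using the parameter names from this diff (values illustrative):

```python
# Euler sketch of saturating NMDA kinetics: x is the fast rise variable,
# g grows at rate a * x * (1 - g), saturating at 1, and decays with tau_decay.
def nmda_step(g, x, spike, a=0.5, tau_decay=100.0, tau_rise=2.0, dt=0.1):
  x = x + dt * (-x / tau_rise) + spike
  g = g + dt * (-g / tau_decay + a * x * (1.0 - g))
  return g, x
```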
@@ -825,15 +781,13 @@ def __init__(
     self.comp_method = comp_method
     self.stop_spike_gradient = stop_spike_gradient
 
-    syn = _DelayedNMDA.desc(pre.size,
-                            pre.keep_size,
-                            mode=mode,
-                            a=a,
-                            tau_decay=tau_decay,
-                            tau_rise=tau_rise,
-                            method=method,
-                            stp=stp,
-                            master=self)
+    syn = synapses.NMDA(pre.size,
+                        pre.keep_size,
+                        mode=mode,
+                        a=a,
+                        tau_decay=tau_decay,
+                        tau_rise=tau_rise,
+                        method=method, )
 
     super().__init__(pre=pre,
                      post=post,
@@ -848,7 +802,6 @@ def __init__(
                      mode=mode)
 
     # copy the references
-    syn = self.post.before_updates[self.proj._syn_id].syn.syn
     self.g = syn.g
     self.x = syn.x
 
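
As a quick smoke test of the refactored classes, something along these lines should still construct (the neuron model and connector are illustrative choices, not taken from this diff):

```python
# Hedged usage sketch: build the refactored Exponential synapse end to end.
import brainpy as bp

pre = bp.neurons.LIF(100)
post = bp.neurons.LIF(100)
syn = bp.synapses.Exponential(pre, post, bp.connect.FixedProb(0.1),
                              g_max=0.5, tau=8.0, comp_method='sparse')
```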