55import brainpy .math as bm
66from brainpy .dyn .base import NeuGroup
77from brainpy .initialize import ZeroInit , OneInit , Initializer , init_param
8- from brainpy .integrators .joint_eq import JointEq
9- from brainpy .integrators .ode import odeint
8+ from brainpy .integrators import sdeint , odeint , JointEq
109from brainpy .tools .checking import check_initializer
1110from brainpy .types import Shape , Tensor
1211
@@ -46,31 +45,34 @@ class LIF(NeuGroup):
4645
4746 - `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_
4847
49- **Model Parameters**
50-
51- ============= ============== ======== =========================================
52- **Parameter** **Init Value** **Unit** **Explanation**
53- ------------- -------------- -------- -----------------------------------------
54- V_rest 0 mV Resting membrane potential.
55- V_reset -5 mV Reset potential after spike.
56- V_th 20 mV Threshold potential of spike.
57- tau 10 ms Membrane time constant. Compute by R * C.
58- tau_ref 5 ms Refractory period length.(ms)
59- ============= ============== ======== =========================================
6048
61- **Neuron Variables**
62-
63- ================== ================= =========================================================
64- **Variables name** **Initial Value** **Explanation**
65- ------------------ ----------------- ---------------------------------------------------------
66- V 0 Membrane potential.
67- input 0 External and synaptic input current.
68- spike False Flag to mark whether the neuron is spiking.
69- refractory False Flag to mark whether the neuron is in refractory period.
70- t_last_spike -1e7 Last spike time stamp.
71- ================== ================= =========================================================
72-
73- **References**
49+ Parameters
50+ ----------
51+ size: sequence of int, int
52+ The size of the neuron group.
53+ V_rest: float, JaxArray, ndarray, Initializer, callable
54+ Resting membrane potential.
55+ V_reset: float, JaxArray, ndarray, Initializer, callable
56+ Reset potential after spike.
57+ V_th: float, JaxArray, ndarray, Initializer, callable
58+ Threshold potential of spike.
59+ tau: float, JaxArray, ndarray, Initializer, callable
60+ Membrane time constant.
61+ tau_ref: float, JaxArray, ndarray, Initializer, callable
62+ Refractory period length.(ms)
63+ V_initializer: JaxArray, ndarray, Initializer, callable
64+ The initializer of membrane potential.
65+ noise: JaxArray, ndarray, Initializer, callable
66+ The noise added onto the membrane potential
67+ noise_type: str
68+ The type of the provided noise. Can be `value` or `func`.
69+ method: str
70+ The numerical integration method.
71+ name: str
72+ The group name.
73+
74+ References
75+ ----------
7476
7577 .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model
7678 neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304.
@@ -85,44 +87,57 @@ def __init__(
8587 tau : Union [float , Tensor , Initializer , Callable ] = 10. ,
8688 tau_ref : Union [float , Tensor , Initializer , Callable ] = 1. ,
8789 V_initializer : Union [Initializer , Callable , Tensor ] = ZeroInit (),
90+ noise : Union [float , Tensor , Initializer , Callable ] = None ,
91+ noise_type : str = 'value' ,
92+ keep_size : bool = False ,
8893 method : str = 'exp_auto' ,
8994 name : str = None
9095 ):
9196 # initialization
9297 super (LIF , self ).__init__ (size = size , name = name )
9398
9499 # parameters
95- self .V_rest = init_param (V_rest , self .num , allow_none = False )
96- self .V_reset = init_param (V_reset , self .num , allow_none = False )
97- self .V_th = init_param (V_th , self .num , allow_none = False )
98- self .tau = init_param (tau , self .num , allow_none = False )
99- self .tau_ref = init_param (tau_ref , self .num , allow_none = False )
100+ self .keep_size = keep_size
101+ self .noise_type = noise_type
102+ if noise_type not in ['func' , 'value' ]:
103+ raise ValueError (f'noise_type only supports `func` and `value`, but we got { noise_type } ' )
104+ size = self .size if keep_size else self .num
105+ self .V_rest = init_param (V_rest , size , allow_none = False )
106+ self .V_reset = init_param (V_reset , size , allow_none = False )
107+ self .V_th = init_param (V_th , size , allow_none = False )
108+ self .tau = init_param (tau , size , allow_none = False )
109+ self .tau_ref = init_param (tau_ref , size , allow_none = False )
110+ if noise_type == 'func' :
111+ self .noise = noise
112+ else :
113+ self .noise = init_param (noise , size , allow_none = True )
100114
101115 # initializers
102116 check_initializer (V_initializer , 'V_initializer' )
103117 self ._V_initializer = V_initializer
104118
105119 # variables
106- self .V = bm .Variable (init_param (V_initializer , ( self . num ,) ))
107- self .input = bm .Variable (bm .zeros (self . num ))
108- self .spike = bm .Variable (bm .zeros (self . num , dtype = bool ))
109- self .t_last_spike = bm .Variable (bm .ones (self . num ) * - 1e7 )
110- self .refractory = bm .Variable (bm .zeros (self . num , dtype = bool ))
120+ self .V = bm .Variable (init_param (V_initializer , size ))
121+ self .input = bm .Variable (bm .zeros (size ))
122+ self .spike = bm .Variable (bm .zeros (size , dtype = bool ))
123+ self .t_last_spike = bm .Variable (bm .ones (size ) * - 1e7 )
124+ self .refractory = bm .Variable (bm .zeros (size , dtype = bool ))
111125
112126 # integral
113- self .integral = odeint (method = method , f = self .derivative )
127+ f = lambda V , t , I_ext : (- V + self .V_rest + I_ext ) / self .tau
128+ if self .noise is not None :
129+ g = noise if (noise_type == 'func' ) else (lambda V , t , I_ext : self .noise / bm .sqrt (self .tau ))
130+ self .integral = sdeint (method = method , f = f , g = g )
131+ else :
132+ self .integral = odeint (method = method , f = f )
114133
115134 def reset (self ):
116- self .V .value = init_param (self ._V_initializer , ( self .num ,) )
135+ self .V .value = init_param (self ._V_initializer , self .size if self . keep_size else self . num )
117136 self .input [:] = 0
118137 self .spike [:] = False
119138 self .t_last_spike [:] = - 1e7
120139 self .refractory [:] = False
121140
122- def derivative (self , V , t , I_ext ):
123- dvdt = (- V + self .V_rest + I_ext ) / self .tau
124- return dvdt
125-
126141 def update (self , t , dt ):
127142 refractory = (t - self .t_last_spike ) <= self .tau_ref
128143 V = self .integral (self .V , t , self .input , dt = dt )
0 commit comments