@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from typing import Union, Callable
+from typing import Union, Callable, Optional
 
 import brainpy.math as bm
 from brainpy.dyn.base import NeuGroup
@@ -204,9 +204,9 @@ def __init__(
       V_th: Union[float, Tensor, Initializer, Callable] = 20.,
       C: Union[float, Tensor, Initializer, Callable] = 1.0,
       V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.),
-      m_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.5),
-      h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6),
-      n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32),
+      m_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
+      h_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
+      n_initializer: Optional[Union[Initializer, Callable, Tensor]] = None,
       noise: Union[float, Tensor, Initializer, Callable] = None,
       method: str = 'exp_auto',
       name: str = None,
@@ -233,20 +233,29 @@ def __init__(
     self.noise = init_noise(noise, self.varshape, num_vars=4)
 
     # initializers
-    check_initializer(m_initializer, 'm_initializer', allow_none=False)
-    check_initializer(h_initializer, 'h_initializer', allow_none=False)
-    check_initializer(n_initializer, 'n_initializer', allow_none=False)
+    check_initializer(m_initializer, 'm_initializer', allow_none=True)
+    check_initializer(h_initializer, 'h_initializer', allow_none=True)
+    check_initializer(n_initializer, 'n_initializer', allow_none=True)
     check_initializer(V_initializer, 'V_initializer', allow_none=False)
     self._m_initializer = m_initializer
     self._h_initializer = h_initializer
     self._n_initializer = n_initializer
     self._V_initializer = V_initializer
 
     # variables
-    self.m = variable(self._m_initializer, mode, self.varshape)
-    self.h = variable(self._h_initializer, mode, self.varshape)
-    self.n = variable(self._n_initializer, mode, self.varshape)
     self.V = variable(self._V_initializer, mode, self.varshape)
+    if self._m_initializer is None:
+      self.m = bm.Variable(self.m_inf(self.V.value))
+    else:
+      self.m = variable(self._m_initializer, mode, self.varshape)
+    if self._h_initializer is None:
+      self.h = bm.Variable(self.h_inf(self.V.value))
+    else:
+      self.h = variable(self._h_initializer, mode, self.varshape)
+    if self._n_initializer is None:
+      self.n = bm.Variable(self.n_inf(self.V.value))
+    else:
+      self.n = variable(self._n_initializer, mode, self.varshape)
     self.input = variable(bm.zeros, mode, self.varshape)
     self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
@@ -256,32 +265,41 @@ def __init__(
     else:
       self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
 
+  # m channel
+  m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
+  m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
+  m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
+  dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
+
+  # h channel
+  h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
+  h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
+  h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
+  dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
+
+  # n channel
+  n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
+  n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
+  n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
+  dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
+
   def reset_state(self, batch_size=None):
-    self.m.value = variable(self._m_initializer, batch_size, self.varshape)
-    self.h.value = variable(self._h_initializer, batch_size, self.varshape)
-    self.n.value = variable(self._n_initializer, batch_size, self.varshape)
     self.V.value = variable(self._V_initializer, batch_size, self.varshape)
+    if self._m_initializer is None:
+      self.m.value = self.m_inf(self.V.value)
+    else:
+      self.m.value = variable(self._m_initializer, batch_size, self.varshape)
+    if self._h_initializer is None:
+      self.h.value = self.h_inf(self.V.value)
+    else:
+      self.h.value = variable(self._h_initializer, batch_size, self.varshape)
+    if self._n_initializer is None:
+      self.n.value = self.n_inf(self.V.value)
+    else:
+      self.n.value = variable(self._n_initializer, batch_size, self.varshape)
     self.input.value = variable(bm.zeros, batch_size, self.varshape)
     self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
 
-  def dm(self, m, t, V):
-    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
-    beta = 4.0 * bm.exp(-(V + 65) / 18)
-    dmdt = alpha * (1 - m) - beta * m
-    return dmdt
-
-  def dh(self, h, t, V):
-    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
-    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
-    dhdt = alpha * (1 - h) - beta * h
-    return dhdt
-
-  def dn(self, n, t, V):
-    alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
-    beta = 0.125 * bm.exp(-(V + 65) / 80)
-    dndt = alpha * (1 - n) - beta * n
-    return dndt
-
   def dV(self, V, t, m, h, n, I_ext):
     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
     I_K = (self.gK * n ** 4.0) * (V - self.EK)
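
For context, a minimal sketch of what the new `None` defaults do. When `m_initializer`, `h_initializer`, or `n_initializer` is left as `None`, the corresponding gating variable is set to its voltage-dependent steady state `x_inf(V0)` at the initial membrane potential, rather than the old fixed constants `OneInit(0.5/0.6/0.32)`. The snippet below reproduces that computation with plain NumPy instead of `brainpy.math` so it runs standalone; the value `V0 = -65.0` is an illustrative choice inside the default `Uniform(-70, -60.)` range, not what the class actually draws.

```python
import numpy as np

# Rate functions copied from the x_alpha / x_beta lambdas in the patch above.
def m_alpha(V): return 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
def m_beta(V):  return 4.0 * np.exp(-(V + 65) / 18)
def h_alpha(V): return 0.07 * np.exp(-(V + 65) / 20.)
def h_beta(V):  return 1 / (1 + np.exp(-(V + 35) / 10))
def n_alpha(V): return 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
def n_beta(V):  return 0.125 * np.exp(-(V + 65) / 80)

def x_inf(alpha, beta, V):
  # Steady state of dx/dt = alpha(V) * (1 - x) - beta(V) * x,
  # i.e. what self.m_inf / self.h_inf / self.n_inf compute in the patch.
  return alpha(V) / (alpha(V) + beta(V))

V0 = -65.0  # example initial potential (hypothetical; the real V is sampled)
print(x_inf(m_alpha, m_beta, V0),  # ~0.05, noticeably different from the old OneInit(0.5)
      x_inf(h_alpha, h_beta, V0),  # ~0.60
      x_inf(n_alpha, n_beta, V0))  # ~0.32
```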