from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
-from brainpy.modes import Mode, BatchingMode, TrainingMode, normal
+from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Tensor

@@ -219,6 +219,7 @@ def __init__(
                             keep_size=keep_size,
                             name=name,
                             mode=mode)
+    check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

    # parameters
    self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -247,8 +248,7 @@ def __init__(
    self.n = variable(self._n_initializer, mode, self.varshape)
    self.V = variable(self._V_initializer, mode, self.varshape)
    self.input = variable(bm.zeros, mode, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
+    self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

    # integral
    if self.noise is None:
@@ -262,8 +262,7 @@ def reset_state(self, batch_size=None):
    self.n.value = variable(self._n_initializer, batch_size, self.varshape)
    self.V.value = variable(self._V_initializer, batch_size, self.varshape)
    self.input.value = variable(bm.zeros, batch_size, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
+    self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

  def dm(self, m, t, V):
    alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
@@ -413,6 +412,7 @@ def __init__(
                             keep_size=keep_size,
                             name=name,
                             mode=mode)
+    check(self.mode, (BatchingMode, NormalMode), self.__class__)

    # params
    self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
@@ -440,8 +440,7 @@ def __init__(
    self.W = variable(self._W_initializer, mode, self.varshape)
    self.V = variable(self._V_initializer, mode, self.varshape)
    self.input = variable(bm.zeros, mode, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
+    self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

    # integral
    if self.noise is None:
@@ -453,8 +452,7 @@ def reset_state(self, batch_size=None):
    self.W.value = variable(self._W_initializer, batch_size, self.varshape)
    self.V.value = variable(self._V_initializer, batch_size, self.varshape)
    self.input.value = variable(bm.zeros, batch_size, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
+    self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

  def dV(self, V, t, W, I_ext):
    M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
@@ -672,6 +670,7 @@ def __init__(
                             keep_size=keep_size,
                             name=name,
                             mode=mode)
+    check(self.mode, (NormalMode, BatchingMode), self.__class__)

    # conductance parameters
    self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
@@ -980,6 +979,7 @@ def __init__(
  ):
    # initialization
    super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
+    check(self.mode, (BatchingMode, NormalMode), self.__class__)

    # parameters
    self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -1006,8 +1006,7 @@ def __init__(
    self.n = variable(self._n_initializer, mode, self.varshape)
    self.V = variable(self._V_initializer, mode, self.varshape)
    self.input = variable(bm.zeros, mode, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
+    self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)

    # integral
    if self.noise is None:
@@ -1020,8 +1019,7 @@ def reset_state(self, batch_size=None):
    self.n.value = variable(self._n_initializer, batch_size, self.varshape)
    self.V.value = variable(self._V_initializer, batch_size, self.varshape)
    self.input.value = variable(bm.zeros, batch_size, self.varshape)
-    sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
-    self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
+    self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)

  def m_inf(self, V):
    alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
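
For reference, a minimal usage sketch of the behaviour these hunks introduce: the conductance-based models keep a boolean spike variable in every mode and accept only normal or batching modes. This is an illustrative sketch, not part of the commit; it assumes the models stay exposed as brainpy.dyn.neurons.HH, that BatchingMode/TrainingMode can be instantiated directly from brainpy.modes, and that the exact exception raised by check() depends on its implementation.

import brainpy as bp
from brainpy.modes import normal, BatchingMode, TrainingMode

# Normal and batching modes still construct the model; the spike variable
# is now always boolean instead of switching dtype under TrainingMode.
hh = bp.dyn.neurons.HH(size=4, mode=normal)
print(hh.spike.dtype)  # expected: bool

hh_batched = bp.dyn.neurons.HH(size=4, mode=BatchingMode())
print(hh_batched.spike.dtype)  # expected: bool

# TrainingMode is no longer accepted by these models; the check(...) call
# added in __init__ should reject it.
try:
  bp.dyn.neurons.HH(size=4, mode=TrainingMode())
except Exception as err:
  print('TrainingMode rejected:', err)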