@@ -64,10 +64,10 @@ def __init__(
6464 self .num_times = len (times )
6565
6666 # data about times and indices
67- self .i = bm .Variable (jnp .zeros (1 , dtype = bm .int_ ))
68- self .times = bm .Variable (jnp .asarray (times , dtype = bm .float_ ))
69- self .indices = bm .Variable (jnp .asarray (indices , dtype = bm .int_ ))
70- self .spike = bm .Variable (jnp .zeros (self .num , dtype = bool ))
67+ self .i = bm .Variable (bm .zeros (1 , dtype = bm .int_ ))
68+ self .times = bm .Variable (bm .asarray (times , dtype = bm .float_ ))
69+ self .indices = bm .Variable (bm .asarray (indices , dtype = bm .int_ ))
70+ self .spike = bm .Variable (bm .zeros (self .num , dtype = bool ))
7171 if need_sort :
7272 sort_idx = bm .argsort (self .times )
7373 self .indices .value = self .indices [sort_idx ]
@@ -121,8 +121,8 @@ def __init__(
121121 self .freqs = init_param (freqs , self .num , allow_none = False )
122122 self .dt = bm .get_dt () / 1000.
123123 self .size = (size ,) if isinstance (size , int ) else tuple (size )
124- self .spike = bm .Variable (jnp .zeros (self .num , dtype = bool ))
125- self .t_last_spike = bm .Variable (jnp .ones (self .num ) * - 1e7 )
124+ self .spike = bm .Variable (bm .zeros (self .num , dtype = bool ))
125+ self .t_last_spike = bm .Variable (bm .ones (self .num ) * - 1e7 )
126126 self .rng = bm .random .RandomState (seed = seed )
127127
128128 def update (self , _t , _i ):
0 commit comments