 
 import brainpy as bp
 import brainpy.math as bm
+from brainpy.tools import DotDict
 
 bm.set_environment(mode=bm.training_mode, dt=1.)
 
@@ -82,7 +83,8 @@ def reset_state(self, batch_size):
   def update(self, s, x):
     self.V.value += x
     spike = self.spike_fun(self.V - self.v_threshold)
-    s = lax.stop_gradient(spike)
+    # s = lax.stop_gradient(spike)
+    s = spike
     if self.reset_mode == 'hard':
       one = lax.convert_element_type(1., bm.float_)
       self.V.value = self.v_reset * s + (one - s) * self.V
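
With `lax.stop_gradient` dropped, the spike that drives the reset is the same differentiable surrogate spike that is passed downstream, so gradients now also flow through the reset of the membrane potential instead of being cut at that point. For reference, a minimal plain-JAX sketch of an arctan-style surrogate spike (illustrative only; the script itself uses `bm.surrogate.arctan`, and the names and `ALPHA` value below are assumptions):

```python
import jax
import jax.numpy as jnp

ALPHA = 2.0  # surrogate sharpness; an illustrative value

@jax.custom_vjp
def heaviside_spike(v):
  # Forward pass: a hard Heaviside step on the (shifted) membrane potential.
  return (v >= 0.).astype(v.dtype)

def _fwd(v):
  return heaviside_spike(v), v

def _bwd(v, g):
  # Backward pass: derivative of the arctan surrogate,
  # d/dv [1/pi * arctan(pi/2 * ALPHA * v) + 1/2]
  #   = ALPHA / 2 / (1 + (pi/2 * ALPHA * v) ** 2).
  return (g * ALPHA / 2. / (1. + (jnp.pi / 2. * ALPHA * v) ** 2),)

heaviside_spike.defvjp(_fwd, _bwd)
```
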
@@ -97,24 +99,24 @@ def __init__(self, n_time: int, n_channel: int):
     self.n_time = n_time
 
     self.block1 = bp.Sequential(
-      bp.layers.Conv2D(1, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None),
+      bp.layers.Conv2D(1, n_channel, kernel_size=3, padding=(1, 1)),
       bp.layers.BatchNorm2D(n_channel, momentum=0.9),
       IFNode((28, 28, n_channel), spike_fun=bm.surrogate.arctan)
     )
     self.block2 = bp.Sequential(
       bp.layers.MaxPool([2, 2], 2, channel_axis=-1),  # 14 * 14
-      bp.layers.Conv2D(n_channel, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None),
+      bp.layers.Conv2D(n_channel, n_channel, kernel_size=3, padding=(1, 1)),
       bp.layers.BatchNorm2D(n_channel, momentum=0.9),
       IFNode((14, 14, n_channel), spike_fun=bm.surrogate.arctan),
     )
     self.block3 = bp.Sequential(
       bp.layers.MaxPool([2, 2], 2, channel_axis=-1),  # 7 * 7
       bp.layers.Flatten(),
-      bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, b_initializer=None),
+      bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4),
       IFNode((4 * 4 * n_channel,), spike_fun=bm.surrogate.arctan),
     )
     self.block4 = bp.Sequential(
-      bp.layers.Dense(n_channel * 4 * 4, 10, b_initializer=None),
+      bp.layers.Dense(n_channel * 4 * 4, 10),
       IFNode((10,), spike_fun=bm.surrogate.arctan),
     )
 
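
Dropping the explicit `b_initializer=None` lets these Conv2D and Dense layers fall back to their default bias initializer; the layer sizes themselves are unchanged. A quick arithmetic check of the Dense input width (illustrative, works for any `n_channel`):

```python
# After the two 2x2 max-poolings, a 28x28 Fashion-MNIST image becomes 7x7,
# so block3's Dense layer sees n_channel * 7 * 7 flattened features.
n_channel = 128          # illustrative value
h = w = 28
h, w = h // 2, w // 2    # block2 MaxPool -> 14 x 14
h, w = h // 2, w // 2    # block3 MaxPool -> 7 x 7
assert h * w * n_channel == n_channel * 7 * 7
```
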
@@ -138,8 +140,6 @@ def main():
   parser.add_argument('-data-dir', default='./data', type=str, help='root dir of Fashion-MNIST dataset')
   parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoint')
   parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
-  parser.add_argument('-save-es', default=None,
-                      help='filepath for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}')
   args = parser.parse_args()
   print(args)
 
@@ -163,20 +163,20 @@ def main():
   def inference_fun(X, fit=True):
     net.reset_state(X.shape[0])
     return bm.for_loop(lambda sha: net(sha.update(dt=bm.dt, fit=fit), X),
-                       bp.tools.DotDict(t=bm.arange(args.n_time, dtype=bm.float_),
-                                        i=bm.arange(args.n_time, dtype=bm.int_)),
-                       dyn_vars=net.vars().unique())
+                       DotDict(t=bm.arange(args.n_time, dtype=bm.float_),
+                               i=bm.arange(args.n_time, dtype=bm.int_)),
+                       child_objs=net)
 
   # loss function
   @bm.to_object(child_objs=net)
   def loss_fun(X, Y, fit=True):
-    fr = bm.mean(inference_fun(X, fit), axis=0)
+    fr = bm.max(inference_fun(X, fit), axis=0)
     ys_onehot = bm.one_hot(Y, 10, dtype=bm.float_)
     l = bp.losses.mean_squared_error(fr, ys_onehot)
     n = bm.sum(fr.argmax(1) == Y)
     return l, n
 
-  predict_loss_fun = bm.jit(partial(loss_fun, fit=True), dyn_vars=loss_fun.vars().unique())
+  predict_loss_fun = bm.jit(partial(loss_fun, fit=True), child_objs=loss_fun)
 
   grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True)
 
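
Two changes in this hunk: object registration for `bm.for_loop` and `bm.jit` now goes through `child_objs=` instead of collecting variables with `...vars().unique()` into `dyn_vars=`, and the readout switches from the time-averaged firing rate (`bm.mean` over the time axis) to the per-neuron maximum over time (`bm.max`), i.e. whether each of the 10 output neurons fired at all during the `n_time` steps. A toy illustration of that readout difference, with made-up spike values:

```python
import brainpy.math as bm

# Output spikes with shape (n_time, batch, n_class): 4 steps, 1 sample, 3 classes.
spikes = bm.asarray([[[0., 1., 0.]],
                     [[0., 1., 0.]],
                     [[0., 0., 0.]],
                     [[0., 1., 1.]]])
rate_readout = bm.mean(spikes, axis=0)  # old readout: [[0., 0.75, 0.25]], graded rates
max_readout = bm.max(spikes, axis=0)    # new readout: [[0., 1., 1.]], fired at least once
```
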
@@ -242,7 +242,7 @@ def train_fun(X, Y):
       'train_acc': train_acc,
       'test_acc': test_acc,
     }
-    bp.checkpoints.save(out_dir, states, epoch_i)
+    # bp.checkpoints.save(out_dir, states, epoch_i)
 
   # inference
   state_dict = bp.checkpoints.load(out_dir)
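
One caveat: with the save call commented out, the inference step still calls `bp.checkpoints.load(out_dir)`, which only works if a checkpoint from an earlier run is already present in `out_dir`. A hypothetical guard (not part of the commit):

```python
import os

# Only restore when out_dir already holds files from a previous run.
if os.path.isdir(out_dir) and os.listdir(out_dir):
  state_dict = bp.checkpoints.load(out_dir)
```
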