
Commit db5716a

update examples
1 parent 78cbc1d commit db5716a

4 files changed (+21 −19 lines changed)

examples/dynamics_simulation/Wang_2002_decision_making_spiking.py

Lines changed: 2 additions & 2 deletions
@@ -267,11 +267,11 @@ def batching_run():
   runner = bp.DSRunner(
     net,
     monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'],
-    time_major=False
+    data_first_axis=False
   )
   runner.run(total_period)

-  coherence = coherence.to_numpy()
+  coherence = bm.as_numpy(coherence)
   fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4)
   for i in range(num_row):
     for j in range(num_col):
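Note: both edits track BrainPy API updates rather than behavioural changes. Isolated for reference (net, total_period and coherence are defined earlier in this example):

runner = bp.DSRunner(
  net,
  monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'],
  data_first_axis=False,  # replaces the removed time_major=False keyword
)
runner.run(total_period)
coherence = bm.as_numpy(coherence)  # replaces coherence.to_numpy()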

examples/training_snn_models/SurrogateGrad_lif_fashion_mnist.py

Lines changed: 2 additions & 1 deletion
@@ -163,7 +163,8 @@ def loss_fun(predicts, targets):
   return loss + l2_loss + l1_loss

 trainer = bp.train.BPTT(
-  model, loss_fun,
+  model,
+  loss_fun,
   optimizer=bp.optim.Adam(lr=lr),
   monitors={'r.spike': net.r.spike},
 )
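For orientation, a hedged sketch of how this trainer is typically driven after construction; x_train, y_train and n_epoch are placeholders not present in this diff, and the fit() keyword names are assumptions:

trainer.fit(train_data=(x_train, y_train), num_epoch=n_epoch)  # placeholder data, assumed signature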

examples/training_snn_models/fashion_mnist_conv_lif.py

Lines changed: 13 additions & 13 deletions
@@ -11,6 +11,7 @@

 import brainpy as bp
 import brainpy.math as bm
+from brainpy.tools import DotDict

 bm.set_environment(mode=bm.training_mode, dt=1.)

@@ -82,7 +83,8 @@ def reset_state(self, batch_size):
   def update(self, s, x):
     self.V.value += x
     spike = self.spike_fun(self.V - self.v_threshold)
-    s = lax.stop_gradient(spike)
+    # s = lax.stop_gradient(spike)
+    s = spike
     if self.reset_mode == 'hard':
       one = lax.convert_element_type(1., bm.float_)
       self.V.value = self.v_reset * s + (one - s) * self.V
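The dropped lax.stop_gradient was a "detach reset": the voltage reset V = v_reset*s + (1 - s)*V used a spike tensor cut off from the backward pass, so surrogate gradients only flowed through the returned spikes. With s = spike, the reset path is differentiated as well. A minimal sketch of the surrogate spike itself, reusing this file's bm.surrogate.arctan (values are illustrative):

import brainpy.math as bm

v = bm.random.randn(8)       # membrane potential minus threshold, example values
s = bm.surrogate.arctan(v)   # forward: step function (spike or not); backward: smooth arctan surrogate
# before this commit: s = lax.stop_gradient(s)  -> the reset term carries no gradient
# after this commit:  s is left attached        -> the reset term is differentiated too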
@@ -97,24 +99,24 @@ def __init__(self, n_time: int, n_channel: int):
     self.n_time = n_time

     self.block1 = bp.Sequential(
-      bp.layers.Conv2D(1, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None),
+      bp.layers.Conv2D(1, n_channel, kernel_size=3, padding=(1, 1), ),
       bp.layers.BatchNorm2D(n_channel, momentum=0.9),
       IFNode((28, 28, n_channel), spike_fun=bm.surrogate.arctan)
     )
     self.block2 = bp.Sequential(
       bp.layers.MaxPool([2, 2], 2, channel_axis=-1),  # 14 * 14
-      bp.layers.Conv2D(n_channel, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None),
+      bp.layers.Conv2D(n_channel, n_channel, kernel_size=3, padding=(1, 1), ),
       bp.layers.BatchNorm2D(n_channel, momentum=0.9),
       IFNode((14, 14, n_channel), spike_fun=bm.surrogate.arctan),
     )
     self.block3 = bp.Sequential(
       bp.layers.MaxPool([2, 2], 2, channel_axis=-1),  # 7 * 7
       bp.layers.Flatten(),
-      bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, b_initializer=None),
+      bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4,),
       IFNode((4 * 4 * n_channel,), spike_fun=bm.surrogate.arctan),
     )
     self.block4 = bp.Sequential(
-      bp.layers.Dense(n_channel * 4 * 4, 10, b_initializer=None),
+      bp.layers.Dense(n_channel * 4 * 4, 10, ),
       IFNode((10,), spike_fun=bm.surrogate.arctan),
     )
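Dropping b_initializer=None presumably restores each layer's default bias term; the trailing ", )" is just what remains after deleting the keyword. Cleaned up, the new lines are equivalent to:

bp.layers.Conv2D(1, n_channel, kernel_size=3, padding=(1, 1))
bp.layers.Conv2D(n_channel, n_channel, kernel_size=3, padding=(1, 1))
bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4)
bp.layers.Dense(n_channel * 4 * 4, 10)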

@@ -138,8 +140,6 @@ def main():
   parser.add_argument('-data-dir', default='./data', type=str, help='root dir of Fashion-MNIST dataset')
   parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoint')
   parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
-  parser.add_argument('-save-es', default=None,
-                      help='filepath for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}')
   args = parser.parse_args()
   print(args)

@@ -163,20 +163,20 @@ def main():
   def inference_fun(X, fit=True):
     net.reset_state(X.shape[0])
     return bm.for_loop(lambda sha: net(sha.update(dt=bm.dt, fit=fit), X),
-                       bp.tools.DotDict(t=bm.arange(args.n_time, dtype=bm.float_),
-                                        i=bm.arange(args.n_time, dtype=bm.int_)),
-                       dyn_vars=net.vars().unique())
+                       DotDict(t=bm.arange(args.n_time, dtype=bm.float_),
+                               i=bm.arange(args.n_time, dtype=bm.int_)),
+                       child_objs=net)

   # loss function
   @bm.to_object(child_objs=net)
   def loss_fun(X, Y, fit=True):
-    fr = bm.mean(inference_fun(X, fit), axis=0)
+    fr = bm.max(inference_fun(X, fit), axis=0)
     ys_onehot = bm.one_hot(Y, 10, dtype=bm.float_)
     l = bp.losses.mean_squared_error(fr, ys_onehot)
     n = bm.sum(fr.argmax(1) == Y)
     return l, n

-  predict_loss_fun = bm.jit(partial(loss_fun, fit=True), dyn_vars=loss_fun.vars().unique())
+  predict_loss_fun = bm.jit(partial(loss_fun, fit=True), child_objs=loss_fun)

   grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True)
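The recurring change in this hunk is the migration from listing traced variables by hand to registering the owning object: bm.for_loop and bm.jit now take child_objs and collect the variables themselves, replacing dyn_vars=...vars().unique(). Side by side, with this file's names:

# old: enumerate every Variable the compiled function touches
predict_loss_fun = bm.jit(partial(loss_fun, fit=True), dyn_vars=loss_fun.vars().unique())
# new: hand over the BrainPy object; its variables are gathered automatically
predict_loss_fun = bm.jit(partial(loss_fun, fit=True), child_objs=loss_fun)

Separately, the readout now scores each class by its maximum response over time (bm.max(..., axis=0)) rather than the mean firing rate.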

@@ -242,7 +242,7 @@ def train_fun(X, Y):
         'train_acc': train_acc,
         'test_acc': test_acc,
       }
-      bp.checkpoints.save(out_dir, states, epoch_i)
+      # bp.checkpoints.save(out_dir, states, epoch_i)

   # inference
   state_dict = bp.checkpoints.load(out_dir)

examples/training_snn_models/mnist_lif_readout.py

Lines changed: 4 additions & 3 deletions
@@ -65,9 +65,10 @@ def update(self, p, x):

 @bm.to_object(child_objs=(net, encoder))
 def loss_fun(xs, ys):
-  net.reset_state(xs.shape[0])
+  net.reset_state(batch_size=xs.shape[0])
   xs = encoder(xs, num_step=args.T)
-  shared = bm.form_shared_args(num_step=xs.shape[0])
+  # shared arguments for looping over time
+  shared = bm.shared_args_over_time(num_step=args.T)
   outs = bm.for_loop(net, (shared, xs))
   out_fr = bm.mean(outs, axis=0)
   ys_onehot = bm.one_hot(ys, 10, dtype=bm.float_)
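Annotated restatement of the new loop (names are all from this file; the contents of the shared arguments are inferred from the hand-built DotDict in fashion_mnist_conv_lif.py):

shared = bm.shared_args_over_time(num_step=args.T)  # per-step shared args (e.g. step index i and time t) for args.T steps
outs = bm.for_loop(net, (shared, xs))               # run net once per step; outputs are stacked along the leading (time) axis
out_fr = bm.mean(outs, axis=0)                      # average over time -> firing rate per output neuron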
@@ -140,7 +141,7 @@ def train(xs, ys):
   state_dict = bp.checkpoints.load(out_dir)
   net.load_state_dict(state_dict['net'])

-  runner = bp.DSRunner(net, time_major=True)
+  runner = bp.DSRunner(net, data_first_axis='T')
   correct_num = 0
   for i in range(0, x_test.shape[0], 512):
     X = encoder(x_test[i: i + 512], num_step=args.T)
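data_first_axis='T' replaces the removed boolean time_major=True: it declares that the leading axis of the data fed to the runner is time, i.e. arrays shaped (num_step, batch, ...), which matches the time-major spike trains produced by encoder(..., num_step=args.T).

runner = bp.DSRunner(net, data_first_axis='T')  # 'T' = time-major input, the same meaning time_major=True had before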
