
Commit 7d88dac: upgrade examples
1 parent dcb41b1

7 files changed (+49 / -67 lines)


examples/ANN_models/mnist-cnn.py

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import brainpy as bp
 import brainpy.math as bm
+import brainpy_datasets as bd
 
 
 class FeedForwardModel(bp.DynamicalSystem):
@@ -25,7 +26,7 @@ def update(self, s, x):
 
 
 # train dataset
-train_dataset = bp.datasets.MNIST(root='./data', train=True, download=True)
+train_dataset = bd.vision.MNIST(root='./data', split='train', download=True)
 x_train = bm.array(train_dataset.data, dtype=bm.dftype())
 x_train = x_train.reshape(x_train.shape + (1,)) / 255
 y_train = bm.array(train_dataset.targets, dtype=bm.ditype())
@@ -41,7 +42,7 @@ def update(self, s, x):
 trainer.fit([x_train, y_train], num_epoch=2, batch_size=64)
 
 # test dataset
-test_dataset = bp.datasets.MNIST(root='./data', train=False, download=True)
+test_dataset = bd.vision.MNIST(root='./data', split='test', download=True)
 x_test = bm.array(test_dataset.data, dtype=bm.dftype())
 x_test = x_test.reshape(x_test.shape + (1,)) / 255
 y_test = bm.array(test_dataset.targets, dtype=bm.ditype())
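
The change above tracks the move of the dataset classes from the core package into the separate brainpy_datasets project, with the boolean train flag replaced by a split string. A minimal sketch of the new loading pattern, assuming brainpy_datasets is installed and ./data is writable:

import brainpy.math as bm
import brainpy_datasets as bd

# vision datasets now live under bd.vision and select a subset via `split`
train_set = bd.vision.MNIST(root='./data', split='train', download=True)
test_set = bd.vision.MNIST(root='./data', split='test', download=True)

# the raw arrays are still exposed through .data and .targets
x_train = bm.array(train_set.data, dtype=bm.dftype()) / 255.
y_train = bm.array(train_set.targets, dtype=bm.ditype())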

examples/training/Song_2016_EI_RNN.py

Lines changed: 9 additions & 9 deletions
@@ -154,7 +154,7 @@ def cell(self, x, h):
   def readout(self, h):
     return h @ self.w_ro + self.b_ro
 
-  def make_update(self, h: bm.JaxArray, o: bm.JaxArray):
+  def make_update(self, h: bm.Array, o: bm.Array):
     def f(x):
       h.value = self.cell(x, h.value)
       o.value = self.readout(h.value[:, :self.e_size])
@@ -164,7 +164,7 @@ def f(x):
 
   def predict(self, xs):
     self.h[:] = 0.
-    return bm.for_loop(self.make_update(self.h, self.o), self.vars(), xs)
+    return bm.for_loop(self.make_update(self.h, self.o), xs, dyn_vars=self.vars())
 
   def loss(self, xs, ys):
     hs, os = self.predict(xs)
@@ -191,19 +191,19 @@ def loss(self, xs, ys):
 
 # %%
 # gradient function
-grad_f = bm.grad(net.loss,
-                 dyn_vars=net.vars(),
-                 grad_vars=net.train_vars().unique(),
-                 return_value=True)
+grad = bm.grad(net.loss,
+               dyn_vars=net.vars().unique(),
+               grad_vars=net.train_vars().unique(),
+               return_value=True)
 
 
 # %%
 @bm.jit
-@bm.function(nodes=(net, opt))
+@bm.to_object(child_objs=(grad, opt))  # add nodes and vars used
 def train(xs, ys):
-  grads, loss = grad_f(xs, ys)
+  grads, l = grad(xs, ys)
   opt.update(grads)
-  return loss
+  return l
 
 
 # %% [markdown]
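
Two API shifts run through these hunks: bm.for_loop now takes the loop operands positionally after the step function, with the tracked variables moved to the dyn_vars keyword, and the bm.function(nodes=...) decorator is superseded by bm.to_object(child_objs=...). A minimal sketch of the updated loop call, assuming the net instance and time-major inputs xs from this script:

import brainpy.math as bm

# operands come second; dyn_vars is now an explicit keyword argument
hs, os = bm.for_loop(net.make_update(net.h, net.o), xs, dyn_vars=net.vars())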

examples/training/SurrogateGrad_lif-ANN-style.py

Lines changed: 10 additions & 17 deletions
@@ -10,8 +10,8 @@
 
 import time
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib.gridspec import GridSpec
 
 import brainpy as bp
@@ -94,41 +94,34 @@ def print_classification_accuracy(output, target):
 print_classification_accuracy(out, y_data)
 
 
+@bm.to_object(child_objs=net, dyn_vars=rng)  # add nodes and vars used in this function
 def loss():
   key = rng.split_key()
-  X = rng.permutation(x_data, key=key)
-  Y = rng.permutation(y_data, key=key)
+  X = bm.random.permutation(x_data, key=key)
+  Y = bm.random.permutation(y_data, key=key)
   looper = bp.dyn.DSRunner(net, numpy_mon_after_run=False, progress_bar=False)
   predictions = looper.run(inputs=X, inputs_are_batching=True, reset_state=True)
   predictions = bm.max(predictions, axis=1)
   return bp.losses.cross_entropy_loss(predictions, Y)
 
 
-f_grad = bm.grad(loss,
-                 grad_vars=net.train_vars().unique(),
-                 dyn_vars=net.vars().unique() + {'rng': rng},
-                 return_value=True)
-f_opt = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique())
+grad = bm.grad(loss, grad_vars=loss.train_vars().unique(), return_value=True)
+optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique())
 
 
+@bm.to_object(child_objs=(grad, optimizer))  # add nodes and vars used in this function
 def train(_):
-  grads, l = f_grad()
-  f_opt.update(grads)
+  grads, l = grad()
+  optimizer.update(grads)
   return l
 
 
-f_train = bm.make_loop(
-  train,
-  dyn_vars=f_opt.vars() + net.vars() + {'rng': rng},
-  has_return=True
-)
-
 # train the network
 net.reset_state(num_sample)
 train_losses = []
 for i in range(0, 3000, 100):
   t0 = time.time()
-  _, ls = f_train(bm.arange(i, i + 100, 1))
+  ls = bm.for_loop(train, operands=bm.arange(i, i + 100, 1))
   print(f'Train {i + 100} epoch, loss = {bm.mean(ls):.4f}, used time {time.time() - t0:.4f} s')
   train_losses.append(ls)
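
The bm.make_loop(..., has_return=True) wrapper disappears entirely: bm.for_loop runs the decorated train step directly, and because @bm.to_object already registers the variables that train touches, no dyn_vars bookkeeping is needed at the call site. A minimal sketch reusing the train function above, with an illustrative index range:

import brainpy.math as bm

# one loop iteration per operand; the per-step losses come back stacked
losses = bm.for_loop(train, operands=bm.arange(0, 100, 1))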

examples/training/SurrogateGrad_lif.py

Lines changed: 6 additions & 9 deletions
@@ -90,12 +90,12 @@ def print_classification_accuracy(output, target):
 # Before training
 runner = bp.dyn.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V})
 out = runner.run(inputs=x_data, inputs_are_batching=True, reset_state=True)
-# plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
-# plot_voltage_traces(out)
+plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
+plot_voltage_traces(out)
 print_classification_accuracy(out, y_data)
 
 
-@bm.function(nodes=net, dyn_vars=rng)  # add nodes and vars used in this function
+@bm.to_object(child_objs=net, dyn_vars=rng)  # add nodes and vars used here
 def loss():
   key = rng.split_key()
   X = bm.random.permutation(x_data, key=key)
@@ -106,14 +106,11 @@ def loss():
   return bp.losses.cross_entropy_loss(predictions, Y)
 
 
-grad = bm.grad(loss,
-               grad_vars=loss.train_vars().unique(),
-               dyn_vars=loss.vars().unique(),
-               return_value=True)
+grad = bm.grad(loss, grad_vars=loss.train_vars().unique(), return_value=True)
 optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique())
 
 
-@bm.function(nodes=(grad, optimizer))  # add nodes and vars used in this function
+@bm.to_object(child_objs=(grad, optimizer))  # add nodes and vars used here
 def train(_):
   grads, l = grad()
   optimizer.update(grads)
@@ -125,7 +122,7 @@ def train(_):
 train_losses = []
 for i in range(0, 3000, 100):
   t0 = time.time()
-  ls = bm.for_loop(train, dyn_vars=train.vars().unique(), operands=bm.arange(i, i + 100, 1))
+  ls = bm.for_loop(train, operands=bm.arange(i, i + 100, 1))
   print(f'Train {i + 100} epoch, loss = {bm.mean(ls):.4f}, used time {time.time() - t0:.4f} s')
   train_losses.append(ls)
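
Once loss is wrapped with @bm.to_object, it owns its variable collections, so bm.grad can drop the explicit dyn_vars argument and pull the trainable weights from the function object itself. A minimal sketch, assuming the decorated loss defined above:

import brainpy.math as bm

# grad_vars comes from the wrapped function; return_value=True makes each
# call return (gradients, loss) instead of gradients alone
grad = bm.grad(loss, grad_vars=loss.train_vars().unique(), return_value=True)
grads, l = grad()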

examples/training/SurrogateGrad_lif_fashion_mnist.py

Lines changed: 5 additions & 12 deletions
@@ -14,6 +14,7 @@
 
 import brainpy as bp
 import brainpy.math as bm
+import brainpy_datasets as bd
 
 
 class SNN(bp.dyn.Network):
@@ -167,7 +168,7 @@ def loss_fun(predicts, targets):
   )
   trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs),
              num_epoch=nb_epochs)
-  return trainer.train_losses
+  return trainer.get_hist_metric('fit')
 
 
 def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28):
@@ -198,16 +199,8 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100,
 
 # load the dataset
 root = r"D:\data\fashion-mnist"
-train_dataset = bp.datasets.FashionMNIST(root,
-                                         train=True,
-                                         transform=None,
-                                         target_transform=None,
-                                         download=True)
-test_dataset = bp.datasets.FashionMNIST(root,
-                                        train=False,
-                                        transform=None,
-                                        target_transform=None,
-                                        download=True)
+train_dataset = bd.vision.FashionMNIST(root, split='train', download=True)
+test_dataset = bd.vision.FashionMNIST(root, split='test', download=True)
 
 # Standardize data
 x_train = np.array(train_dataset.data, dtype=bm.dftype())
@@ -237,7 +230,7 @@ def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100,
 
 nb_plt = 4
 gs = GridSpec(1, nb_plt)
-fig = plt.figure(figsize=(7, 3), dpi=150)
+plt.figure(figsize=(7, 3), dpi=150)
 for i in range(nb_plt):
   plt.subplot(gs[i])
   plt.imshow(bm.as_numpy(spikes[i]).T, cmap=plt.cm.gray_r, origin="lower")
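
trainer.train_losses is replaced by the get_hist_metric accessor on the BPTT trainer. A minimal sketch of reading and plotting the training-loss history, assuming a trainer built as in this script's training function:

import matplotlib.pyplot as plt
import brainpy.math as bm

# 'fit' selects the training-phase history that train_losses used to hold
hist = trainer.get_hist_metric('fit')
plt.plot(bm.as_numpy(hist))
plt.show()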

examples/training/Sussillo_Abbott_2009_FORCE_Learning.py

Lines changed: 10 additions & 10 deletions
@@ -5,7 +5,9 @@
 # %% [markdown]
 # Implementation of the paper:
 #
-# - Sussillo, David, and Larry F. Abbott. "Generating coherent patterns of activity from chaotic neural networks." Neuron 63, no. 4 (2009): 544-557.
+# - Sussillo, David, and Larry F. Abbott. "Generating coherent patterns
+#   of activity from chaotic neural networks."
+#   Neuron 63, no. 4 (2009): 544-557.
 
 # %%
 import brainpy as bp
@@ -46,7 +48,7 @@ def __init__(self, num_input, num_hidden, num_output,
     self.w_rr = g * bm.random.normal(size=(num_hidden, num_hidden)) / bm.sqrt(num_hidden)
     self.w_or = bm.random.normal(size=(num_output, num_hidden))
     w_ro = bm.random.normal(size=(num_hidden, num_output)) / bm.sqrt(num_hidden)
-    self.w_ro = bm.Variable(w_ro)
+    self.w_ro = bm.Variable(w_ro)  # dynamically change this weight
 
     # variables
     self.h = bm.Variable(bm.random.normal(size=num_hidden) * 0.5)  # hidden
@@ -62,6 +64,7 @@ def update(self, x):
     self.h += self.dt / self.tau * dhdt
     self.r.value = bm.tanh(self.h)
     self.o.value = bm.dot(self.r, self.w_ro)
+    return self.r.value, self.o.value
 
   def rls(self, target):
     # update the inverse correlation matrix
@@ -75,17 +78,14 @@ def rls(self, target):
     self.w_ro += dw
 
   def simulate(self, xs):
-    f = bm.make_loop(self.update, dyn_vars=self.vars(), out_vars=[self.r, self.o])
-    return f(xs)
+    return bm.for_loop(self.update, dyn_vars=self.vars(), operands=xs)
 
   def train(self, xs, targets):
-    def _f(x):
-      input, target = x
-      self.update(input)
+    def _f(x, target):
+      r, o = self.update(x)
       self.rls(target)
-
-    f = bm.make_loop(_f, dyn_vars=self.vars(), out_vars=[self.r, self.o])
-    return f([xs, targets])
+      return r, o
+    return bm.for_loop(_f, dyn_vars=self.vars(), operands=[xs, targets])
 
 
 # %%
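
Returning (self.r.value, self.o.value) from update is what lets bm.for_loop stack the per-step rates and outputs, taking over the job of make_loop's out_vars argument. A minimal sketch, assuming a constructed model and time-major arrays xs and targets as in this script:

# free run: for_loop stacks each step's (r, o) along a leading time axis
rs, os = model.simulate(xs)

# FORCE training: the same loop, but rls() adjusts w_ro after every step
rs, os = model.train(xs, targets)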

examples/training/integrator_rnn.py

Lines changed: 6 additions & 8 deletions
@@ -12,9 +12,7 @@
 num_batch = 128
 
 
-@partial(bm.jit,
-         dyn_vars=bp.TensorCollector({'a': bm.random.DEFAULT}),
-         static_argnames=['batch_size'])
+@partial(bm.jit, dyn_vars=bm.random.DEFAULT, static_argnames=['batch_size'])
 def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10):
   # Create the white noise input
   sample = bm.random.normal(size=(batch_size, 1, 1))
@@ -31,14 +29,15 @@ def train_data():
   yield build_inputs_and_targets(batch_size=num_batch)
 
 
-class RNN(bp.dyn.DynamicalSystem):
+class RNN(bp.DynamicalSystem):
   def __init__(self, num_in, num_hidden):
     super(RNN, self).__init__()
-    self.rnn = bp.layers.VanillaRNN(num_in, num_hidden, train_state=True)
+    self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
     self.out = bp.layers.Dense(num_hidden, 1)
 
   def update(self, sha, x):
-    return self.out(sha, self.rnn(sha, x))
+    return self.out(sha,
+                    self.rnn(sha, x))
 
 
 model = RNN(1, 100)
@@ -58,11 +57,10 @@ def loss(predictions, targets, l2_reg=2e-4):
 # create a trainer
 trainer = bp.train.BPTT(model, loss_fun=loss, optimizer=opt)
 trainer.fit(train_data,
-            batch_size=num_batch,
             num_epoch=30,
             num_report=200)
 
-plt.plot(bm.as_numpy(trainer.train_losses))
+plt.plot(bm.as_numpy(trainer.get_hist_metric()))
 plt.show()
 
 model.reset_state(1)
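
Because train_data already yields fully batched (input, target) pairs, the batch_size argument drops out of BPTT.fit, and the loss history moves behind get_hist_metric. A minimal sketch of the trainer wiring, assuming the model, loss, and opt defined in this script:

import brainpy as bp

trainer = bp.train.BPTT(model, loss_fun=loss, optimizer=opt)
trainer.fit(train_data,  # the generator yields pre-batched (x, y) pairs
            num_epoch=30,
            num_report=200)
losses = trainer.get_hist_metric()  # the same 'fit' history plotted above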

0 commit comments

Comments
 (0)