Skip to content

Commit 624ce46

Browse files
author
Bodo Rueckauer
committed
2 parents 82c6ac7 + 99a6600 commit 624ce46

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

examples/mnist_keras_brian2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
'duration': 50, # Number of time steps to run each sample.
124124
'num_to_test': 5, # How many test samples to run.
125125
'batch_size': 1, # Batch size for simulation.
126+
'dt': 0.1 # Time interval for the differential equations to be solved over.
126127
}
127128

128129
config['input'] = {

snntoolbox/simulation/target_simulators/brian2_target_sim.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ def __init__(self, config, queue=None):
6161
self.layers = []
6262
self.connections = [] # Final container for all layers.
6363
self.threshold = 'v >= v_thresh'
64-
self.v_reset = 'v = v_reset'
65-
self.eqs = 'v : 1'
64+
if 'subtraction' in config.get('cell', 'reset'):
65+
self.v_reset = 'v = v - v_thresh'
66+
else:
67+
self.v_reset = 'v = v_reset'
68+
self.eqs = ''' dv/dt = bias : 1
69+
bias : hertz'''
6670
self.spikemonitors = []
6771
self.statemonitors = []
6872
self.snn = None
@@ -78,9 +82,15 @@ def is_parallelizable(self):
7882

7983
def add_input_layer(self, input_shape):
8084

81-
self.layers.append(self.sim.PoissonGroup(
82-
np.prod(input_shape[1:]), rates=0 * self.sim.Hz,
83-
dt=self._dt * self.sim.ms))
85+
if self._poisson_input:
86+
self.layers.append(self.sim.PoissonGroup(
87+
np.prod(input_shape[1:]), rates=0*self.sim.Hz,
88+
dt=self._dt*self.sim.ms))
89+
else:
90+
self.layers.append(self.sim.NeuronGroup(
91+
np.prod(input_shape[1:]), model=self.eqs, method='euler',
92+
reset=self.v_reset, threshold=self.threshold,
93+
dt=self._dt * self.sim.ms))
8494
self.layers[0].add_attribute('label')
8595
self.layers[0].label = 'InputLayer'
8696
self.spikemonitors.append(self.sim.SpikeMonitor(self.layers[0]))
@@ -98,7 +108,7 @@ def add_layer(self, layer):
98108
return
99109

100110
self.layers.append(self.sim.NeuronGroup(
101-
np.prod(layer.output_shape[1:]), model=self.eqs, method='linear',
111+
np.prod(layer.output_shape[1:]), model=self.eqs, method='euler',
102112
reset=self.v_reset, threshold=self.threshold,
103113
dt=self._dt * self.sim.ms))
104114
self.connections.append(self.sim.Synapses(
@@ -123,10 +133,11 @@ def build_dense(self, layer, weights=None):
123133
if weights is None:
124134
weights = _weights
125135

126-
set_biases(biases)
136+
self.set_biases(biases)
127137

128138
delay = self.config.getfloat('cell', 'delay')
129139
connections = []
140+
130141
if len(self.flatten_shapes) == 1:
131142
print("Swapping data_format of Flatten layer.")
132143
flatten_name, shape = self.flatten_shapes.pop()
@@ -170,7 +181,7 @@ def build_convolution(self, layer, weights=None):
170181
conns, biases = build_convolution(layer, delay, transpose_kernel)
171182
connections = np.array(conns)
172183

173-
set_biases(biases)
184+
self.set_biases(biases)
174185

175186
print("Connecting layer...")
176187

@@ -202,7 +213,7 @@ def compile(self):
202213

203214
# Set input layer
204215
for obj in self.snn.objects:
205-
if 'poissongroup' in obj.name and 'thresholder' not in obj.name:
216+
if hasattr(obj, 'label') and obj.label == 'InputLayer':
206217
self._input_layer = obj
207218
assert self._input_layer, "No input layer found."
208219

@@ -215,11 +226,7 @@ def simulate(self, **kwargs):
215226
# TODO: Implement by using brian2.SpikeGeneratorGroup.
216227
raise NotImplementedError
217228
else:
218-
try:
219-
# TODO: Implement constant input by using brian2.TimedArray.
220-
self._input_layer.current = kwargs[str('x_b_l')].flatten()
221-
except AttributeError:
222-
raise NotImplementedError
229+
self._input_layer.bias = kwargs[str('x_b_l')].flatten() / self.sim.ms
223230

224231
self.snn.run(self._duration * self.sim.ms, namespace=self._cell_params,
225232
report='stdout', report_period=10 * self.sim.ms)
@@ -377,15 +384,9 @@ def get_vmem(self, **kwargs):
377384
def set_spiketrain_stats_input(self):
378385
AbstractSNN.set_spiketrain_stats_input(self)
379386

380-
381-
def set_biases(biases):
382-
"""Set biases.
383-
384-
Notes
385-
-----
386-
387-
This has not been tested yet.
388-
"""
389-
390-
if any(biases): # TODO: Implement biases.
391-
warnings.warn("Biases not implemented.", RuntimeWarning)
387+
def set_biases(self, biases):
388+
"""Set biases.
389+
"""
390+
if any(biases):
391+
assert self.layers[-1].bias.shape == biases.shape, "Shape of biases and network do not match."
392+
self.layers[-1].bias = biases * 1000 * self.sim.Hz

0 commit comments

Comments
 (0)