Skip to content

Commit f361291

Browse files
author
Bodo Rueckauer
committed
Updated 'temporal_pattern' code to work with tensorflow 2.
1 parent 0b0b8a2 commit f361291

File tree

3 files changed

+50
-16
lines changed

3 files changed

+50
-16
lines changed

snntoolbox/simulation/backends/inisim/temporal_pattern.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def init_neurons(self, input_shape):
7575
self.spikerates = k.zeros(output_shape)
7676

7777
def update_spikevars(self, x):
78-
return [tf.assign(self.spikerates, x)]
78+
return [self.spikerates.assign(x)]
7979

8080

8181
def spike_call(call):
@@ -123,7 +123,7 @@ def to_binary(x, num_bits):
123123
distributed across the first dimension of ``binary_array``.
124124
"""
125125

126-
shape = k.shape(x)
126+
shape = k.int_shape(x)
127127

128128
binary_array = k.zeros([num_bits] + list(shape[1:]), k.floatx())
129129

@@ -177,7 +177,7 @@ def iterate_powers(act_value, idx_p, idx_l, idx_m, idx_n):
177177
p = powers[idx_p]
178178
c = k.greater_equal(act_value, p)
179179
b = tf.cond(c, lambda: 1., lambda: 0.)
180-
a = tf.assign(binary_array[idx_p, idx_l, idx_m, idx_n], b)
180+
a = binary_array[idx_p, idx_l, idx_m, idx_n].assign(b)
181181
new_act_value = tf.cond(c, lambda: act_value - p,
182182
lambda: act_value)
183183
with tf.control_dependencies([a]):
@@ -210,7 +210,7 @@ def iterate_powers(act_value, idx_p, idx_l):
210210
p = powers[idx_p]
211211
c = k.greater_equal(act_value, p)
212212
b = tf.cond(c, lambda: 1., lambda: 0.)
213-
a = tf.assign(binary_array[idx_p, idx_l], b)
213+
a = binary_array[idx_p, idx_l].assign(b)
214214
new_act_value = tf.cond(c, lambda: act_value - p,
215215
lambda: act_value)
216216
with tf.control_dependencies([a]):
@@ -249,20 +249,20 @@ def to_binary_numpy(x, num_bits):
249249
powers = [2**-i for i in range(num_bits)]
250250

251251
if len(x.shape) > 2:
252-
for l in range(x.shape[1]):
252+
for j in range(x.shape[1]):
253253
for m in range(x.shape[2]):
254254
for n in range(x.shape[3]):
255-
f = x[0, l, m, n]
255+
f = x[0, j, m, n]
256256
for i in range(num_bits):
257257
if f >= powers[i]:
258-
binary_array[i, l, m, n] = 1
258+
binary_array[i, j, m, n] = 1
259259
f -= powers[i]
260260
else:
261-
for l in range(x.shape[1]):
262-
f = x[0, l]
261+
for j in range(x.shape[1]):
262+
f = x[0, j]
263263
for i in range(num_bits):
264264
if f >= powers[i]:
265-
binary_array[i, l] = 1
265+
binary_array[i, j] = 1
266266
f -= powers[i]
267267
return binary_array
268268

snntoolbox/simulation/target_simulators/INI_temporal_mean_rate_target_sim.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,23 @@ def build_pooling(self, layer):
103103

104104
def compile(self):
105105

106+
self.snn = keras.models.Model(
107+
self._input_images,
108+
self._spiking_layers[self.parsed_model.layers[-1].name])
109+
self.snn.compile('sgd', 'categorical_crossentropy', ['accuracy'])
110+
111+
# Tensorflow 2 lists all variables as weights, including our state
112+
# variables (membrane potential etc). So a simple
113+
# snn.set_weights(parsed_model.get_weights()) does not work any more.
114+
# Need to extract the actual weights here.
115+
106116
def remove_name_counter(name_in):
107117
splits = str(name_in).split('_')
108118
name_out = splits[0] + '_' + splits[1]
109119
if len(splits) == 3:
110-
name_out += re.sub('\d+/', '/', splits[2])
120+
name_out += re.sub(r'\d+/', '/', splits[2])
111121
return name_out
112122

113-
self.snn = keras.models.Model(
114-
self._input_images,
115-
self._spiking_layers[self.parsed_model.layers[-1].name])
116-
self.snn.compile('sgd', 'categorical_crossentropy', ['accuracy'])
117123
parameter_map = {remove_name_counter(p.name): v for p, v in
118124
zip(self.parsed_model.weights,
119125
self.parsed_model.get_weights())}
@@ -125,6 +131,7 @@ def remove_name_counter(name_in):
125131
count += 1
126132
assert count == len(parameter_map), "Not all weights have been " \
127133
"transferred from ANN to SNN."
134+
128135
for layer in self.snn.layers:
129136
if hasattr(layer, 'bias'):
130137
# Adjust biases to time resolution of simulator.

snntoolbox/simulation/target_simulators/INI_temporal_pattern_target_sim.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from __future__ import division, absolute_import
88
from __future__ import print_function, unicode_literals
99

10+
import re
11+
1012
import keras
1113
import numpy as np
1214
from future import standard_library
@@ -45,11 +47,36 @@ def __init__(self, config, queue=None):
4547
self.num_bits = self.config.getint('conversion', 'num_bits')
4648

4749
def compile(self):
50+
4851
self.snn = keras.models.Model(
4952
self._input_images,
5053
self._spiking_layers[self.parsed_model.layers[-1].name])
5154
self.snn.compile('sgd', 'categorical_crossentropy', ['accuracy'])
52-
self.snn.set_weights(self.parsed_model.get_weights())
55+
56+
# Tensorflow 2 lists all variables as weights, including our state
57+
# variables (membrane potential etc). So a simple
58+
# snn.set_weights(parsed_model.get_weights()) does not work any more.
59+
# Need to extract the actual weights here:
60+
61+
def remove_name_counter(name_in):
62+
splits = str(name_in).split('_')
63+
name_out = splits[0] + '_' + splits[1]
64+
if len(splits) == 3:
65+
name_out += re.sub(r'\d+/', '/', splits[2])
66+
return name_out
67+
68+
parameter_map = {remove_name_counter(p.name): v for p, v in
69+
zip(self.parsed_model.weights,
70+
self.parsed_model.get_weights())}
71+
count = 0
72+
for p in self.snn.weights:
73+
name = remove_name_counter(p.name)
74+
if name in parameter_map:
75+
keras.backend.set_value(p, parameter_map[name])
76+
count += 1
77+
assert count == len(parameter_map), "Not all weights have been " \
78+
"transferred from ANN to SNN."
79+
5380
for layer in self.snn.layers:
5481
if hasattr(layer, 'bias'):
5582
# Adjust biases to time resolution of simulator.

0 commit comments

Comments
 (0)