@@ -61,8 +61,12 @@ def __init__(self, config, queue=None):
6161 self .layers = []
6262 self .connections = [] # Final container for all layers.
6363 self .threshold = 'v >= v_thresh'
64- self .v_reset = 'v = v_reset'
65- self .eqs = 'v : 1'
64+ if 'subtraction' in config .get ('cell' , 'reset' ):
65+ self .v_reset = 'v = v - v_thresh'
66+ else :
67+ self .v_reset = 'v = v_reset'
68+ self .eqs = ''' dv/dt = bias : 1
69+ bias : hertz'''
6670 self .spikemonitors = []
6771 self .statemonitors = []
6872 self .snn = None
@@ -78,9 +82,15 @@ def is_parallelizable(self):
7882
7983 def add_input_layer (self , input_shape ):
8084
81- self .layers .append (self .sim .PoissonGroup (
82- np .prod (input_shape [1 :]), rates = 0 * self .sim .Hz ,
83- dt = self ._dt * self .sim .ms ))
85+ if self ._poisson_input :
86+ self .layers .append (self .sim .PoissonGroup (
87+ np .prod (input_shape [1 :]), rates = 0 * self .sim .Hz ,
88+ dt = self ._dt * self .sim .ms ))
89+ else :
90+ self .layers .append (self .sim .NeuronGroup (
91+ np .prod (input_shape [1 :]), model = self .eqs , method = 'euler' ,
92+ reset = self .v_reset , threshold = self .threshold ,
93+ dt = self ._dt * self .sim .ms ))
8494 self .layers [0 ].add_attribute ('label' )
8595 self .layers [0 ].label = 'InputLayer'
8696 self .spikemonitors .append (self .sim .SpikeMonitor (self .layers [0 ]))
@@ -98,7 +108,7 @@ def add_layer(self, layer):
98108 return
99109
100110 self .layers .append (self .sim .NeuronGroup (
101- np .prod (layer .output_shape [1 :]), model = self .eqs , method = 'linear ' ,
111+ np .prod (layer .output_shape [1 :]), model = self .eqs , method = 'euler ' ,
102112 reset = self .v_reset , threshold = self .threshold ,
103113 dt = self ._dt * self .sim .ms ))
104114 self .connections .append (self .sim .Synapses (
@@ -123,10 +133,11 @@ def build_dense(self, layer, weights=None):
123133 if weights is None :
124134 weights = _weights
125135
126- set_biases (biases )
136+ self . set_biases (biases )
127137
128138 delay = self .config .getfloat ('cell' , 'delay' )
129139 connections = []
140+
130141 if len (self .flatten_shapes ) == 1 :
131142 print ("Swapping data_format of Flatten layer." )
132143 flatten_name , shape = self .flatten_shapes .pop ()
@@ -170,7 +181,7 @@ def build_convolution(self, layer, weights=None):
170181 conns , biases = build_convolution (layer , delay , transpose_kernel )
171182 connections = np .array (conns )
172183
173- set_biases (biases )
184+ self . set_biases (biases )
174185
175186 print ("Connecting layer..." )
176187
@@ -202,7 +213,7 @@ def compile(self):
202213
203214 # Set input layer
204215 for obj in self .snn .objects :
205- if 'poissongroup' in obj .name and 'thresholder' not in obj . name :
216+ if hasattr ( obj , 'label' ) and obj .label == 'InputLayer' :
206217 self ._input_layer = obj
207218 assert self ._input_layer , "No input layer found."
208219
@@ -215,11 +226,7 @@ def simulate(self, **kwargs):
215226 # TODO: Implement by using brian2.SpikeGeneratorGroup.
216227 raise NotImplementedError
217228 else :
218- try :
219- # TODO: Implement constant input by using brian2.TimedArray.
220- self ._input_layer .current = kwargs [str ('x_b_l' )].flatten ()
221- except AttributeError :
222- raise NotImplementedError
229+ self ._input_layer .bias = kwargs [str ('x_b_l' )].flatten () / self .sim .ms
223230
224231 self .snn .run (self ._duration * self .sim .ms , namespace = self ._cell_params ,
225232 report = 'stdout' , report_period = 10 * self .sim .ms )
@@ -377,15 +384,9 @@ def get_vmem(self, **kwargs):
377384 def set_spiketrain_stats_input (self ):
378385 AbstractSNN .set_spiketrain_stats_input (self )
379386
380-
381- def set_biases (biases ):
382- """Set biases.
383-
384- Notes
385- -----
386-
387- This has not been tested yet.
388- """
389-
390- if any (biases ): # TODO: Implement biases.
391- warnings .warn ("Biases not implemented." , RuntimeWarning )
387+ def set_biases (self , biases ):
388+ """Set biases.
389+ """
390+ if any (biases ):
391+ assert self .layers [- 1 ].bias .shape == biases .shape , "Shape of biases and network do not match."
392+ self .layers [- 1 ].bias = biases * 1000 * self .sim .Hz
0 commit comments