-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathneuron.py
More file actions
283 lines (218 loc) · 10.6 KB
/
neuron.py
File metadata and controls
283 lines (218 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpikingActivation(torch.autograd.Function):
    """
    Step-function spike activation with a multi-Gaussian surrogate gradient.

    The forward pass is a hard threshold at zero. The backward pass replaces the derivative of
    that step with a combination of three Gaussians evaluated at the forward input: one centred
    at 0 with width `lens`, minus two centred at +/-`lens` with width `scale * lens`.
    (Based on https://github.com/byin-cwi/sFPTT)
    """

    @staticmethod
    def forward(ctx, input, scale=6.0, height=0.15, lens=0.3, gamma=0.5):
        """
        Emit a spike wherever the input is strictly positive.

        Args:
            ctx: Autograd context; the input and the surrogate hyper-parameters are stored on it.
            input (torch.Tensor): Value compared against zero (callers in this file pass
                membrane potential minus threshold).
            scale (float, optional): Multiplier applied to `lens` to get the width of the two side Gaussians.
            height (float, optional): Weight of the two side Gaussians in the surrogate.
            lens (float, optional): Width of the central Gaussian and offset of the side Gaussians.
            gamma (float, optional): Overall multiplier applied to the surrogate gradient.
        Returns:
            torch.Tensor: float32 tensor holding 1 where `input > 0` and 0 elsewhere.
        """
        # Everything the surrogate needs is stashed on the context for backward().
        ctx.save_for_backward(input)
        ctx.scale, ctx.height, ctx.lens, ctx.gamma = scale, height, lens, gamma
        spikes = input > 0
        return spikes.float()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Propagate gradients through the spike using the Gaussian surrogate.

        Args:
            ctx: Autograd context populated by `forward`.
            grad_output (torch.Tensor): Gradient of the loss with respect to the spike output.
        Returns:
            tuple: Gradient with respect to `input`, followed by four `None` entries for the
            non-tensor hyper-parameters `scale`, `height`, `lens` and `gamma`.
        """
        (saved_input,) = ctx.saved_tensors
        upstream = grad_output.clone()

        side_sigma = ctx.scale * ctx.lens
        centre = SpikingActivation.gaussian(saved_input, mu=0., sigma=ctx.lens)
        right = SpikingActivation.gaussian(saved_input, mu=ctx.lens, sigma=side_sigma)
        left = SpikingActivation.gaussian(saved_input, mu=-ctx.lens, sigma=side_sigma)

        # Central bump weighted by (1 + height), minus the two side bumps weighted by height.
        surrogate = centre * (1. + ctx.height) - right * ctx.height - left * ctx.height

        grad_wrt_input = upstream * surrogate.float() * ctx.gamma
        return grad_wrt_input, None, None, None, None

    @staticmethod
    def gaussian(x, mu=0., sigma=.5):
        """
        Evaluate the normal probability density at `x`.

        Args:
            x (torch.Tensor): Points at which the density is evaluated.
            mu (float, optional): Mean of the distribution. Default is 0.
            sigma (float, optional): Standard deviation of the distribution. Default is 0.5.
        Returns:
            torch.Tensor: Density values, element-wise for `x`.
        """
        exponent = -((x - mu) ** 2) / (2 * sigma ** 2)
        sqrt_two_pi = torch.sqrt(2 * torch.tensor(np.pi))
        return torch.exp(exponent) / sqrt_two_pi / sigma
class AdaptiveSpikingNeuron(nn.Module):
    """
    AdaptiveSpikingNeuron is a PyTorch module that models an Adaptive Leaky Integrate and Fire neuron.

    The membrane gate `tau_m` and the threshold-adaptation gate `tau_adp` are computed at every
    step by a learned layer (Conv2d when `kernel_size` is given, Linear when `linear_features`
    is given) followed by batch normalisation and a sigmoid. The neuron keeps three state
    tensors between calls (`spk`, `mem`, `b`), registered as non-persistent buffers and lazily
    resized to the shape of the first input.

    Args:
        channels (int): Number of input channels for convolutional layers.
        linear_features (int): Number of input features for linear layers.
        kernel_size (int): Size of the convolutional kernel.
        concatenate (bool): Whether to concatenate inputs with previous states
            (otherwise they are added element-wise) before the gating layers.
        disable_reset (bool): Whether to disable the spiking/reset mechanism; when True the
            module only integrates and `forward` returns the membrane potential.
        b_j0 (float): Initial threshold value.
        beta (float): Scaling factor for the adaptive threshold.
    """

    def __init__(self, channels=None, linear_features=None, kernel_size=None, concatenate=False,
                 disable_reset=False, b_j0=0.5, beta=1.8):
        super().__init__()
        self.channels = channels
        self.linear_features = linear_features
        self.kernel_size = kernel_size
        self.concatenate = concatenate
        self.disable_reset = disable_reset
        self.b_j0 = b_j0
        self.beta = beta

        # Exactly one of `linear_features` / `kernel_size` selects the layer type.
        assert (linear_features is None) != (kernel_size is None), \
            "Specify either linear features or specify kernel_size, not both!"

        # Concatenating the state with the input doubles the gating layer's input width.
        width = 2 if concatenate else 1

        if linear_features is None:
            self.layer_tau_m = nn.Conv2d(channels * width, channels, kernel_size=kernel_size,
                                         stride=1, padding=kernel_size // 2)
            self.batch_norm_tau_m = nn.BatchNorm2d(channels)
            if not self.disable_reset:
                self.layer_tau_adp = nn.Conv2d(channels * width, channels, kernel_size=kernel_size,
                                               stride=1, padding=kernel_size // 2)
                self.batch_norm_tau_adp = nn.BatchNorm2d(channels)
        else:
            self.layer_tau_m = nn.Linear(linear_features * width, linear_features)
            self.batch_norm_tau_m = nn.BatchNorm1d(linear_features)
            if not self.disable_reset:
                self.layer_tau_adp = nn.Linear(linear_features * width, linear_features)
                self.batch_norm_tau_adp = nn.BatchNorm1d(linear_features)

        self.activation_tau_m = nn.Sigmoid()
        if not self.disable_reset:
            self.activation_tau_adp = nn.Sigmoid()

        nn.init.xavier_uniform_(self.layer_tau_m.weight)
        if not self.disable_reset:
            nn.init.xavier_uniform_(self.layer_tau_adp.weight)

        # NOTE(review): biases are zeroed only for the convolutional variant; the linear
        # variant keeps PyTorch's default bias initialisation. The nesting of this section was
        # reconstructed from a copy of the file that had lost its indentation -- confirm.
        if linear_features is None:
            nn.init.constant_(self.layer_tau_m.bias, 0)
            if not self.disable_reset:
                nn.init.constant_(self.layer_tau_adp.bias, 0)

        # State starts empty and is sized on the first forward call. The buffers are
        # non-persistent, so they follow `.to(device)` but are not written to the state dict.
        self.register_buffer("spk", torch.zeros(0), persistent=False)
        self.register_buffer("mem", torch.zeros(0), persistent=False)
        self.register_buffer("b", torch.zeros(0), persistent=False)

    def _gate_input(self, x, state):
        """Combine the input with a state tensor: concatenate on the channel/feature axis, or add."""
        if self.concatenate:
            return torch.cat((x, state), dim=(1 if self.linear_features is None else -1))
        return x + state

    def forward(self, x):
        """
        Perform the forward pass of the neuron model for one time step.

        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: The updated membrane potential when `disable_reset` is True,
            otherwise the spike tensor produced at this step.
        """
        # (Re)initialise any state whose shape does not match the input.
        if self.spk.shape != x.shape:
            self.spk = torch.zeros_like(x, device=self.spk.device)
        if self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x, device=self.mem.device)
        if self.b.shape != x.shape:
            self.b = torch.full_like(x, self.b_j0, device=self.b.device)

        tau_m = self.activation_tau_m(self.batch_norm_tau_m(self.layer_tau_m(self._gate_input(x, self.mem))))
        d_mem = -self.mem + x
        # Out-of-place on purpose: with `disable_reset` the previous `self.mem` is the tensor
        # that was returned to the caller, and an in-place `+=` would overwrite that output and
        # invalidate any autograd graph that saved it.
        self.mem = self.mem + d_mem * tau_m

        if self.disable_reset:
            return self.mem

        tau_adp = self.activation_tau_adp(self.batch_norm_tau_adp(self.layer_tau_adp(self._gate_input(x, self.b))))
        # The adaptation state decays with tau_adp and is driven by the previous step's spikes.
        self.b = tau_adp * self.b + (1 - tau_adp) * self.spk
        threshold = self.b_j0 + self.beta * self.b
        self.spk = SpikingActivation.apply(self.mem - threshold)
        # Reset: the membrane is zeroed wherever a spike was emitted.
        self.mem = (1 - self.spk) * self.mem
        return self.spk

    def reset_mem(self):
        """
        Resets the neuron state variables to their initial values.

        This method resets the following attributes:
            - `spk`: Sets the spike tensor to zeros, maintaining the same shape and device.
            - `mem`: Sets the membrane potential tensor to zeros, maintaining the same shape and device.
            - `b`: Resets the threshold adaptation state to its initial value `b_j0`, maintaining the same shape and device.
        """
        self.spk = torch.zeros_like(self.spk, device=self.spk.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        self.b = torch.full_like(self.b, self.b_j0, device=self.b.device)

    def detach_hidden(self):
        """
        Detach the hidden state variables from the current computation graph.

        This method detaches the spiking activity (spk), membrane potential (mem), and threshold
        adaptation state (b) in place, which is useful for truncated Backpropagation Through Time.
        Because the detach is in place, it also affects the tensor most recently returned by
        `forward`, which is the same object as `spk` (or `mem` when `disable_reset` is True).
        """
        self.spk.detach_()
        self.mem.detach_()
        self.b.detach_()
class ParametricSpikingNeuron(nn.Module):
    """
    Adaptive leaky integrate-and-fire neuron with scalar, learnable gates.

    The membrane gate and the threshold-adaptation gate are each a single learnable parameter
    (`tau_m`, `tau_adp`, both initialised to 0.5) passed through a sigmoid at every step. The
    neuron keeps three state tensors between calls (`spk`, `mem`, `b`), registered as
    non-persistent buffers and lazily resized to the shape of the first input.

    Args:
        disable_reset (bool): Whether to disable the spiking/reset mechanism; when True the
            module only integrates and `forward` returns the membrane potential.
        b_j0 (float): Initial threshold value.
        beta (float): Scaling factor for the adaptive threshold.
    """

    def __init__(self, disable_reset=False, b_j0=0.5, beta=1.8):
        super().__init__()
        self.disable_reset = disable_reset
        self.b_j0 = b_j0
        self.beta = beta

        # Pre-sigmoid gate values; the adaptation gate only exists when spiking is enabled.
        self.tau_m = nn.Parameter(torch.tensor(0.5), requires_grad=True)
        if not self.disable_reset:
            self.tau_adp = nn.Parameter(torch.tensor(0.5), requires_grad=True)

        # State starts empty and is sized on the first forward call. The buffers are
        # non-persistent, so they follow `.to(device)` but are not written to the state dict.
        self.register_buffer("spk", torch.zeros(0), persistent=False)
        self.register_buffer("mem", torch.zeros(0), persistent=False)
        self.register_buffer("b", torch.zeros(0), persistent=False)

    def forward(self, x):
        """
        Perform the forward pass of the neuron model for one time step.

        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: The updated membrane potential when `disable_reset` is True,
            otherwise the spike tensor produced at this step.
        """
        # (Re)initialise any state whose shape does not match the input.
        if self.spk.shape != x.shape:
            self.spk = torch.zeros_like(x, device=self.spk.device)
        if self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x, device=self.mem.device)
        if self.b.shape != x.shape:
            self.b = torch.full_like(x, self.b_j0, device=self.b.device)

        tau_m = torch.sigmoid(self.tau_m)
        d_mem = -self.mem + x
        # Out-of-place on purpose: with `disable_reset` the previous `self.mem` is the tensor
        # that was returned to the caller, and an in-place `+=` would overwrite that output and
        # invalidate any autograd graph that saved it.
        self.mem = self.mem + d_mem * tau_m

        if self.disable_reset:
            return self.mem

        tau_adp = torch.sigmoid(self.tau_adp)
        # The adaptation state decays with tau_adp and is driven by the previous step's spikes.
        self.b = tau_adp * self.b + (1 - tau_adp) * self.spk
        threshold = self.b_j0 + self.beta * self.b
        self.spk = SpikingActivation.apply(self.mem - threshold)
        # Reset: the membrane is zeroed wherever a spike was emitted.
        self.mem = (1 - self.spk) * self.mem
        return self.spk

    def reset_mem(self):
        """
        Resets the neuron state variables to their initial values.

        This method resets the following attributes:
            - `spk`: Sets the spike tensor to zeros, maintaining the same shape and device.
            - `mem`: Sets the membrane potential tensor to zeros, maintaining the same shape and device.
            - `b`: Resets the threshold adaptation state to its initial value `b_j0`, maintaining the same shape and device.
        """
        self.spk = torch.zeros_like(self.spk, device=self.spk.device)
        self.mem = torch.zeros_like(self.mem, device=self.mem.device)
        self.b = torch.full_like(self.b, self.b_j0, device=self.b.device)

    def detach_hidden(self):
        """
        Detach the hidden state variables from the current computation graph.

        This method detaches the spiking activity (spk), membrane potential (mem), and threshold
        adaptation state (b) in place, which is useful for truncated Backpropagation Through Time.
        Because the detach is in place, it also affects the tensor most recently returned by
        `forward`, which is the same object as `spk` (or `mem` when `disable_reset` is True).
        """
        self.spk.detach_()
        self.mem.detach_()
        self.b.detach_()