Skip to content

Commit 9c32ecc

Browse files
committed
fixed major recurrent_layer bug
1 parent 525ce60 commit 9c32ecc

File tree

2 files changed

+79
-105
lines changed

2 files changed

+79
-105
lines changed

nn4n/layer/recurrent_layer.py

Lines changed: 78 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -7,77 +7,36 @@
77

88

99
class RecurrentLayer(nn.Module):
    """
    Recurrent layer of the RNN. The layer is initialized by passing specs in layer_struct.

    Keys in layer_struct:
        - activation: activation function, default: "relu"
        - dt: time step, default: 10
        - tau: time constant, default: 100
        - init_state: initial state of the network. It defines the hidden state at t=0.
            - 'zero': all zeros
            - 'keep': keep the last state
            - 'learn': learn the initial state
        - in_struct: layer_struct for the input layer (required)
        - hid_struct: layer_struct for the hidden layer (required)

    Keyword Arguments:
        - preact_noise: noise added to pre-activation, default: 0
        - postact_noise: noise added to post-activation, default: 0
    """
    def __init__(self, layer_struct, **kwargs):
        super().__init__()
        # Euler-integration leak factor alpha = dt / tau; use the documented
        # defaults (dt=10, tau=100) instead of raising KeyError when absent.
        self.alpha = layer_struct.get('dt', 10) / layer_struct.get('tau', 100)
        # The hidden size is defined by the hidden layer's input dimension.
        self.hidden_size = layer_struct['hid_struct']['input_dim']
        self.hidden_state = torch.zeros(self.hidden_size)
        self.init_state = layer_struct['init_state']
        self.act = layer_struct.get('activation', 'relu')
        self.activation = get_activation(self.act)
        # NOTE: noise levels are read from **kwargs, not from layer_struct.
        self.preact_noise = kwargs.pop("preact_noise", 0)
        self.postact_noise = kwargs.pop("postact_noise", 0)
        self._set_hidden_state()
        # Sub-layers are fully specified by their own layer_structs.
        self.input_layer = LinearLayer(layer_struct=layer_struct['in_struct'])
        self.hidden_layer = HiddenLayer(layer_struct=layer_struct['hid_struct'])
8140

8241
# INITIALIZATION
8342
# ==================================================================================================
@@ -93,17 +52,59 @@ def _set_hidden_state(self):
9352

9453
# FORWARD
9554
# ==================================================================================================
55+
def to(self, device):
    """
    Move the network to the device (cpu/gpu).

    Returns self, matching the standard nn.Module.to() contract so that
    `layer = layer.to(device)` chaining keeps working.
    """
    super().to(device)
    self.input_layer.to(device)
    self.hidden_layer.to(device)
    # hidden_state is a plain tensor (not a parameter/buffer), so
    # nn.Module.to() does not move it automatically — move it explicitly.
    self.hidden_state = self.hidden_state.to(device)
    return self
61+
62+
def forward(self, x):
    """
    Run the recurrence over a full input sequence.

    Inputs:
        - x: input, shape: (n_timesteps, batch_size, input_dim)

    Returns:
        - states: shape: (n_timesteps, batch_size, hidden_size)
    """
    membrane = self._reset_state().to(x.device)
    firing = self.activation(membrane)
    # Step through time, collecting the firing rate at every timestep.
    states = []
    for step in range(x.size(0)):
        firing, membrane = self._recurrence(firing, membrane, x[step])
        states.append(firing)

    # When configured to keep state, remember the final firing rate so the
    # next forward pass starts from it.
    if self.init_state == 'keep':
        self.hidden_state = firing.detach().clone()  # TODO: haven't tested this yet

    return torch.stack(states, dim=0)
86+
9687
def _reset_state(self):
    """ Return the hidden state used at t=0, according to the init_state policy. """
    # 'learn' and 'keep' both reuse the stored hidden_state; anything else
    # starts from zeros.
    if self.init_state in ('learn', 'keep'):
        return self.hidden_state
    return torch.zeros(self.hidden_size)
10192

93+
def apply_plasticity(self):
    """ Apply plasticity masks to the weight gradients """
    for sub_layer in (self.input_layer, self.hidden_layer):
        sub_layer.apply_plasticity()
97+
10298
def enforce_constraints(self):
    """
    Enforce sparsity and excitatory/inhibitory constraints if applicable.
    This is by default automatically called after each forward pass,
    but can be called manually if needed
    """
    for sub_layer in (self.input_layer, self.hidden_layer):
        sub_layer.enforce_constraints()
105106

106-
def recurrence(self, fr_t, v_t, u_t):
107+
def _recurrence(self, fr_t, v_t, u_t):
107108
""" Recurrence function """
108109
# through input layer
109110
v_in_u_t = self.input_layer(u_t) # u_t @ W_in
@@ -126,55 +127,28 @@ def recurrence(self, fr_t, v_t, u_t):
126127
fr_t = fr_t + postact_epsilon
127128

128129
return fr_t, v_t
129-
130-
def forward(self, input):
131-
"""
132-
Propogate input through the network.
133-
@param input: shape=(seq_len, batch, input_dim), network input
134-
@return stacked_states: shape=(seq_len, batch, hidden_size), stack of hidden layer status
135-
"""
136-
v_t = self._reset_state().to(input.device)
137-
fr_t = self.activation(v_t)
138-
# update hidden state and append to stacked_states
139-
stacked_states = []
140-
for i in range(input.size(0)):
141-
fr_t, v_t = self.recurrence(fr_t, v_t, input[i])
142-
# append to stacked_states
143-
stacked_states.append(fr_t)
144-
145-
# if keeping the last state, save it to hidden_state
146-
if self.init_state == 'keep':
147-
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet
148-
149-
return torch.stack(stacked_states, dim=0)
150130
# ==================================================================================================
151131

152132
# HELPER FUNCTIONS
153133
# ==================================================================================================
154-
def to(self, device):
155-
"""
156-
Move the network to the device (cpu/gpu)
157-
"""
158-
super().to(device)
159-
self.input_layer.to(device)
160-
self.hidden_layer.to(device)
161-
self.hidden_state = self.hidden_state.to(device)
134+
def plot_layers(self, **kwargs):
    """ Plot the weights matrix and distribution of each layer """
    # NOTE(review): **kwargs is accepted but not forwarded to the sub-layers;
    # kept for interface compatibility.
    for sub_layer in (self.input_layer, self.hidden_layer):
        sub_layer.plot_layers()
162138

163139
def print_layers(self):
    """ Print the weights matrix and distribution of each layer """
    # Summary of the recurrence-level settings, printed between the two
    # sub-layer reports.
    recurrence_summary = {
        "init_hidden_min": self.hidden_state.min(),
        "init_hidden_max": self.hidden_state.max(),
        "preact_noise": self.preact_noise,
        "postact_noise": self.postact_noise,
        "activation": self.act,
        "alpha": self.alpha,
        "init_state": self.init_state,
        "init_state_learnable": self.hidden_state.requires_grad,
    }
    self.input_layer.print_layers()
    print_dict("Recurrence", recurrence_summary)
    self.hidden_layer.print_layers()
176-
177-
def plot_layers(self, **kwargs):
178-
self.input_layer.plot_layers()
179-
self.hidden_layer.plot_layers()
180-
# ==================================================================================================
154+
# ==================================================================================================

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name='nn4n',
9-
version='1.1.0',
9+
version='1.1.1',
1010
description='Neural Networks for Neuroscience Research',
1111
long_description=long_description,
1212
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)