77
88
class RecurrentLayer(nn.Module):
    """
    Recurrent layer of the RNN. The layer is initialized by passing specs in layer_struct.

    Required keywords in layer_struct:
        - activation: activation function, default: "relu"
        - dt: time step, default: 10
        - tau: time constant, default: 100
        - init_state: initial state of the network. It defines the hidden state at t=0.
            - 'zero': all zeros
            - 'keep': keep the last state of the previous forward pass
            - 'learn': learn the initial state
        - in_struct: layer_struct for the input (Linear) layer
        - hid_struct: layer_struct for the hidden (recurrent) layer

    Keyword Arguments (read from **kwargs of __init__, not from layer_struct):
        - preact_noise: noise added to pre-activation, default: 0
        - postact_noise: noise added to post-activation, default: 0
    """
26+ def __init__ (self , layer_struct , ** kwargs ):
4527 super ().__init__ ()
46-
47- self .hidden_size = hidden_size
48- self .preact_noise = preact_noise
49- self .postact_noise = postact_noise
50- self .alpha = kwargs .get ("dt" , 10 ) / kwargs .get ("tau" , 100 )
51- self .layer_distributions = layer_distributions
52- self .layer_biases = layer_biases
53- self .layer_masks = layer_masks
28+ self .alpha = layer_struct ['dt' ]/ layer_struct ['tau' ]
29+ self .hidden_size = layer_struct ['hid_struct' ]['input_dim' ]
5430 self .hidden_state = torch .zeros (self .hidden_size )
55- self .init_state = kwargs . get ( " init_state" , 'zero' )
56- self .act = kwargs . get ( " activation" , "relu" )
31+ self .init_state = layer_struct [ ' init_state' ]
32+ self .act = layer_struct [ ' activation' ]
5733 self .activation = get_activation (self .act )
34+ self .preact_noise = kwargs .pop ("preact_noise" , 0 )
35+ self .postact_noise = kwargs .pop ("postact_noise" , 0 )
5836 self ._set_hidden_state ()
5937
60- self .input_layer = LinearLayer (
61- positivity_constraints = positivity_constraints [0 ],
62- sparsity_constraints = sparsity_constraints [0 ],
63- output_dim = self .hidden_size ,
64- input_dim = kwargs .pop ("input_dim" , 1 ),
65- use_bias = self .layer_biases [0 ],
66- dist = self .layer_distributions [0 ],
67- mask = self .layer_masks [0 ],
68- learnable = learnable [0 ],
69- )
70- self .hidden_layer = HiddenLayer (
71- hidden_size = self .hidden_size ,
72- sparsity_constraints = sparsity_constraints [1 ],
73- positivity_constraints = positivity_constraints [1 ],
74- dist = self .layer_distributions [1 ],
75- use_bias = self .layer_biases [1 ],
76- scaling = kwargs .get ("scaling" , 1.0 ),
77- mask = self .layer_masks [1 ],
78- self_connections = kwargs .get ("self_connections" , False ),
79- learnable = learnable [1 ],
80- )
38+ self .input_layer = LinearLayer (layer_struct = layer_struct ['in_struct' ])
39+ self .hidden_layer = HiddenLayer (layer_struct = layer_struct ['hid_struct' ])
8140
8241 # INITIALIZATION
8342 # ==================================================================================================
@@ -93,17 +52,59 @@ def _set_hidden_state(self):
9352
9453 # FORWARD
9554 # ==================================================================================================
55+ def to (self , device ):
56+ """ Move the network to the device (cpu/gpu) """
57+ super ().to (device )
58+ self .input_layer .to (device )
59+ self .hidden_layer .to (device )
60+ self .hidden_state = self .hidden_state .to (device )
61+
62+ def forward (self , x ):
63+ """
64+ Forwardly update network
65+
66+ Inputs:
67+ - x: input, shape: (n_timesteps, batch_size, input_dim)
68+
69+ Returns:
70+ - states: shape: (n_timesteps, batch_size, hidden_size)
71+ """
72+ v_t = self ._reset_state ().to (x .device )
73+ fr_t = self .activation (v_t )
74+ # update hidden state and append to stacked_states
75+ stacked_states = []
76+ for i in range (x .size (0 )):
77+ fr_t , v_t = self ._recurrence (fr_t , v_t , x [i ])
78+ # append to stacked_states
79+ stacked_states .append (fr_t )
80+
81+ # if keeping the last state, save it to hidden_state
82+ if self .init_state == 'keep' :
83+ self .hidden_state = fr_t .detach ().clone () # TODO: haven't tested this yet
84+
85+ return torch .stack (stacked_states , dim = 0 )
86+
9687 def _reset_state (self ):
9788 if self .init_state == 'learn' or self .init_state == 'keep' :
9889 return self .hidden_state
9990 else :
10091 return torch .zeros (self .hidden_size )
10192
93+ def apply_plasticity (self ):
94+ """ Apply plasticity masks to the weight gradients """
95+ self .input_layer .apply_plasticity ()
96+ self .hidden_layer .apply_plasticity ()
97+
10298 def enforce_constraints (self ):
99+ """
100+ Enforce sparsity and excitatory/inhibitory constraints if applicable.
101+ This is by default automatically called after each forward pass,
102+ but can be called manually if needed
103+ """
103104 self .input_layer .enforce_constraints ()
104105 self .hidden_layer .enforce_constraints ()
105106
106- def recurrence (self , fr_t , v_t , u_t ):
107+ def _recurrence (self , fr_t , v_t , u_t ):
107108 """ Recurrence function """
108109 # through input layer
109110 v_in_u_t = self .input_layer (u_t ) # u_t @ W_in
@@ -126,55 +127,28 @@ def recurrence(self, fr_t, v_t, u_t):
126127 fr_t = fr_t + postact_epsilon
127128
128129 return fr_t , v_t
129-
130- def forward (self , input ):
131- """
132- Propogate input through the network.
133- @param input: shape=(seq_len, batch, input_dim), network input
134- @return stacked_states: shape=(seq_len, batch, hidden_size), stack of hidden layer status
135- """
136- v_t = self ._reset_state ().to (input .device )
137- fr_t = self .activation (v_t )
138- # update hidden state and append to stacked_states
139- stacked_states = []
140- for i in range (input .size (0 )):
141- fr_t , v_t = self .recurrence (fr_t , v_t , input [i ])
142- # append to stacked_states
143- stacked_states .append (fr_t )
144-
145- # if keeping the last state, save it to hidden_state
146- if self .init_state == 'keep' :
147- self .hidden_state = fr_t .detach ().clone () # TODO: haven't tested this yet
148-
149- return torch .stack (stacked_states , dim = 0 )
150130 # ==================================================================================================
151131
152132 # HELPER FUNCTIONS
153133 # ==================================================================================================
154- def to (self , device ):
155- """
156- Move the network to the device (cpu/gpu)
157- """
158- super ().to (device )
159- self .input_layer .to (device )
160- self .hidden_layer .to (device )
161- self .hidden_state = self .hidden_state .to (device )
134+ def plot_layers (self , ** kwargs ):
135+ """ Plot the weights matrix and distribution of each layer """
136+ self .input_layer .plot_layers ()
137+ self .hidden_layer .plot_layers ()
162138
163139 def print_layers (self ):
140+ """ Print the weights matrix and distribution of each layer """
164141 param_dict = {
165- "hidden_min" : self .hidden_state .min (),
166- "hidden_max" : self .hidden_state .max (),
167- "hidden_mean" : self .hidden_state .mean (),
142+ "init_hidden_min" : self .hidden_state .min (),
143+ "init_hidden_max" : self .hidden_state .max (),
168144 "preact_noise" : self .preact_noise ,
169145 "postact_noise" : self .postact_noise ,
170146 "activation" : self .act ,
171147 "alpha" : self .alpha ,
148+ "init_state" : self .init_state ,
149+ "init_state_learnable" : self .hidden_state .requires_grad ,
172150 }
173151 self .input_layer .print_layers ()
174152 print_dict ("Recurrence" , param_dict )
175153 self .hidden_layer .print_layers ()
176-
177- def plot_layers (self , ** kwargs ):
178- self .input_layer .plot_layers ()
179- self .hidden_layer .plot_layers ()
180- # ==================================================================================================
154+ # ==================================================================================================
0 commit comments