@@ -17,6 +17,50 @@ def gaussian_logpdf(event, mean, stddev):
     return -0.5 * (log_normalizer + quadratic)
 
 class REINFORCESynapse(DenseSynapse):
20+ """
21+ A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse
22+ uses Gaussian distributions for generating actions and performs gradient-based updates.
23+
24+ | --- Synapse Compartments: ---
25+ | inputs - input (takes in external signals)
26+ | outputs - output signals (sampled actions from Gaussian distribution)
27+ | weights - current value matrix of synaptic efficacies (contains both mean and log-std parameters)
28+ | dWeights - current delta matrix containing changes to be applied to synaptic efficacies
29+ | rewards - reward signals used to modulate weight updates (takes in external signals)
30+ | objective - scalar value of the current loss/objective
31+ | accumulated_gradients - exponential moving average of gradients for tracking learning progress
32+ | step_count - counter for number of learning steps
33+ | learning_mask - binary mask determining when learning occurs
34+ | seed - JAX PRNG key for random sampling
35+
36+ Args:
37+ name: the string name of this component
38+
39+ shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
40+ with number of inputs by number of outputs)
41+
42+ eta: learning rate for weight updates (Default: 1e-4)
43+
44+ decay: decay factor for computing exponential moving average of gradients (Default: 0.99)
45+
46+ weight_init: a kernel to drive initialization of this synaptic cable's values;
47+ typically a tuple with 1st element as a string calling the name of
48+ initialization to use
49+
50+ resist_scale: a fixed scaling factor to apply to synaptic transform
51+ (Default: 1.)
52+
53+ act_fx: activation function to apply to inputs (Default: "identity")
54+
55+ p_conn: probability of a connection existing (default: 1.); setting
56+ this to < 1. will result in a sparser synaptic structure
57+
58+ w_bound: upper bound for weight clipping (Default: 1.)
59+
60+ batch_size: batch size dimension of this component (Default: 1)
61+
62+ seed: random seed for reproducibility (Default: 42)
63+ """
 
     # Define Functions
     def __init__(
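
For orientation, here is a minimal construction sketch. The keyword names follow the docstring above; the import path, component name, and hyperparameter values are illustrative assumptions, not taken from the repository:

```python
from ngclearn.components import REINFORCESynapse  # assumed import path

# Hypothetical usage; keyword names follow the docstring above.
syn = REINFORCESynapse(
    name="policy",      # string name of this component
    shape=(10, 2),      # 10 inputs -> 2 sampled action dimensions
    eta=1e-4,           # learning rate (docstring default)
    decay=0.99,         # EMA decay for accumulated_gradients
    act_fx="identity",  # activation applied to inputs
    batch_size=1,
    seed=42,
)
# Because weights packs mean and log-std side by side, the underlying
# matrix is (10, 2 * 2) = (10, 4).
```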
@@ -41,17 +85,16 @@ def __init__(
         self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
         self.rewards = Compartment(jnp.zeros((batch_size,)))  # the normalized reward (r - r_hat); an input compartment
         self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")
-        # self.seed = Component(seed)
         self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
         self.decay = decay
         self.step_count = Compartment(jnp.zeros(()))
         self.learning_mask = Compartment(jnp.zeros(()))
-        # self.seed = Component(jnp.array(seed) if seed is not None else jnp.array(42, dtype=jnp.int32))
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
 
     @staticmethod
     def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
-        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)  # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
+        # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
+        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
         # Forward pass
         activation = act_fx(inputs)
         mean = activation @ W_mu
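
As a self-contained illustration of this forward pass (the shapes and the identity activation are assumptions made for the sketch), sampling the output action looks like:

```python
import jax
import jax.numpy as jnp

B, input_dim, output_dim = 1, 10, 2
weights = jnp.zeros((input_dim, output_dim * 2))  # packed (mean | log-std)
inputs = jnp.ones((B, input_dim))

W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
mean = inputs @ W_mu                  # act_fx = identity assumed here
std = jnp.exp(inputs @ W_logstd)
key = jax.random.PRNGKey(42)
outputs = mean + std * jax.random.normal(key, mean.shape)  # ~ N(mean, std)
```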
@@ -73,8 +116,6 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
 
         # Compute gradients manually based on the derivation
         # dL/dmu = -(r - r_hat) * dlog_prob/dmu = -(r - r_hat) * (sample - mu) / sigma^2
-        # -(sample - mean) instead of (sample - mean) because we are doing straight-through gradient in the log_prob function
-        # therefore, computation including the mean in such function does not contribute to the gradient
         dlog_prob_dmean = (sample - mean) / (std ** 2)
         dL_dmean = dL_dlogp * dlog_prob_dmean  # (B, A)
         dL_dWmu = activation.T @ dL_dmean
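
These hand-derived gradients can be sanity-checked against autodiff. A sketch, assuming (per the surrounding code) that `dL_dlogp = -(r - r_hat)` and that `sample` is held fixed rather than differentiated through:

```python
import jax
import jax.numpy as jnp

def logpdf(x, mean, std):
    # diagonal Gaussian log-density, summed over the action dimension
    return jnp.sum(-0.5 * (jnp.log(2.0 * jnp.pi * std ** 2)
                           + ((x - mean) / std) ** 2), axis=-1)

key = jax.random.PRNGKey(0)
activation = jax.random.normal(key, (4, 3))                      # (B, input_dim)
W_mu = jax.random.normal(jax.random.fold_in(key, 1), (3, 2))
sample = jax.random.normal(jax.random.fold_in(key, 2), (4, 2))   # fixed action draw
std = 0.5 * jnp.ones((4, 2))
r = jnp.ones((4,))                                               # normalized reward (r - r_hat)

loss = lambda W: jnp.sum(-r * logpdf(sample, activation @ W, std))
auto = jax.grad(loss)(W_mu)

dL_dmean = -r[:, None] * (sample - activation @ W_mu) / std ** 2  # manual derivation
manual = activation.T @ dL_dmean
assert jnp.allclose(auto, manual, atol=1e-5)
```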
@@ -130,18 +171,42 @@ def reset(batch_size, shape):
     @classmethod
     def help(cls):  ## component help function
         properties = {
-
+            "synapse_type": "REINFORCESynapse - implements a stochastic synaptic cable that uses "
+                            "the REINFORCE algorithm (policy gradient) to update weights based on rewards"
         }
         compartment_props = {
-
+            "inputs":
+                {"inputs": "Takes in external input signal values",
+                 "rewards": "Takes in reward signals used to modulate weight updates; typically a baseline-normalized reward (r - r_hat)"},
+            "states":
+                {"weights": "Synapse efficacy/strength parameter values (mean and log-std)",
+                 "dWeights": "Weight update values",
+                 "accumulated_gradients": "Exponential moving average of gradients over time",
+                 "step_count": "Counter for the number of learning steps",
+                 "learning_mask": "Binary mask determining when learning occurs",
+                 "seed": "A single integer used to build the initial JAX PRNG key for this component"},
+            "outputs":
+                {"outputs": "Output samples drawn from the Gaussian distribution",
+                 "objective": "Current value of the loss/objective function"},
         }
         hyperparams = {
-
+            "shape": "Shape of synaptic weight value matrix; number of inputs x number of outputs",
+            "eta": "Learning rate for weight updates",
+            "decay": "Decay factor for the exponential moving average of gradients",
+            "weight_init": "Initialization conditions for synaptic weight values",
+            "resist_scale": "Resistance level scaling factor applied to the output",
+            "act_fx": "Activation function to apply to inputs",
+            "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
+            "w_bound": "Upper bound for weight clipping",
+            "batch_size": "Batch size dimension of this component",
+            "seed": "Random seed for reproducibility"
         }
         info = {cls.__name__: properties,
                 "compartments": compartment_props,
-                # "dynamics": "outputs = [(W * Rscale) * inputs] ;"
-                #             "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
+                "dynamics": "mean = act_fx(inputs) @ W_mu; logstd = act_fx(inputs) @ W_logstd; "
+                            "outputs ~ N(mean, exp(logstd)); "
+                            "dW = -grad_reinforce(rewards, log_prob(outputs)). "
+                            "Check compute_update() for more details.",
                 "hyperparameters": hyperparams}
         return info
 
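The `log_prob` named in the dynamics string is the `gaussian_logpdf` helper from the first hunk. Assuming `log_normalizer` and `quadratic` are the standard terms implied by its return expression, the density can be checked against `jax.scipy.stats`:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def gaussian_logpdf(event, mean, stddev):
    # reconstruction of the helper above (an assumption from its return line)
    log_normalizer = jnp.log(2.0 * jnp.pi * stddev ** 2)
    quadratic = ((event - mean) / stddev) ** 2
    return -0.5 * (log_normalizer + quadratic)

x, mu, sigma = jnp.array(0.3), jnp.array(0.1), jnp.array(0.5)
assert jnp.allclose(gaussian_logpdf(x, mu, sigma), norm.logpdf(x, mu, sigma))
```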