Skip to content

Commit 1110d98

Browse files
committed
update documentation
1 parent 91da161 commit 1110d98

File tree

1 file changed

+75
-10
lines changed

1 file changed

+75
-10
lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,50 @@ def gaussian_logpdf(event, mean, stddev):
1717
return - 0.5 * (log_normalizer + quadratic)
1818

1919
class REINFORCESynapse(DenseSynapse):
20+
"""
21+
A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse
22+
uses Gaussian distributions for generating actions and performs gradient-based updates.
23+
24+
| --- Synapse Compartments: ---
25+
| inputs - input (takes in external signals)
26+
| outputs - output signals (sampled actions from Gaussian distribution)
27+
| weights - current value matrix of synaptic efficacies (contains both mean and log-std parameters)
28+
| dWeights - current delta matrix containing changes to be applied to synaptic efficacies
29+
| rewards - reward signals used to modulate weight updates (takes in external signals)
30+
| objective - scalar value of the current loss/objective
31+
| accumulated_gradients - exponential moving average of gradients for tracking learning progress
32+
| step_count - counter for number of learning steps
33+
| learning_mask - binary mask determining when learning occurs
34+
| seed - JAX PRNG key for random sampling
35+
36+
Args:
37+
name: the string name of this component
38+
39+
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
40+
with number of inputs by number of outputs)
41+
42+
eta: learning rate for weight updates (Default: 1e-4)
43+
44+
decay: decay factor for computing exponential moving average of gradients (Default: 0.99)
45+
46+
weight_init: a kernel to drive initialization of this synaptic cable's values;
47+
typically a tuple with 1st element as a string calling the name of
48+
initialization to use
49+
50+
resist_scale: a fixed scaling factor to apply to synaptic transform
51+
(Default: 1.)
52+
53+
act_fx: activation function to apply to inputs (Default: "identity")
54+
55+
p_conn: probability of a connection existing (default: 1.); setting
56+
this to < 1. will result in a sparser synaptic structure
57+
58+
w_bound: upper bound for weight clipping (Default: 1.)
59+
60+
batch_size: batch size dimension of this component (Default: 1)
61+
62+
seed: random seed for reproducibility (Default: 42)
63+
"""
2064

2165
# Define Functions
2266
def __init__(
@@ -41,17 +85,16 @@ def __init__(
4185
self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
4286
self.rewards = Compartment(jnp.zeros((batch_size,))) # the normalized reward (r - r_hat), input compartment
4387
self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")
44-
# self.seed = Component(seed)
4588
self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
4689
self.decay = decay
4790
self.step_count = Compartment(jnp.zeros(()))
4891
self.learning_mask = Compartment(jnp.zeros(()))
49-
# self.seed = Component(jnp.array(seed) if seed is not None else jnp.array(42, dtype=jnp.int32))
5092
self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
5193

5294
@staticmethod
5395
def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
54-
W_mu, W_logstd = jnp.split(weights, 2, axis=-1) # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
96+
# (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
97+
W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
5598
# Forward pass
5699
activation = act_fx(inputs)
57100
mean = activation @ W_mu
@@ -73,8 +116,6 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
73116

74117
# Compute gradients manually based on the derivation
75118
# dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
76-
# -(sample - mean) instead of (sample - mean) because we are doing straight-through gradient in the log_prob function
77-
# therefore, computation including the mean in such function does not contribute to the gradient
78119
dlog_prob_dmean = (sample - mean) / (std ** 2)
79120
dL_dmean = dL_dlogp * dlog_prob_dmean # (B, A)
80121
dL_dWmu = activation.T @ dL_dmean
@@ -130,18 +171,42 @@ def reset(batch_size, shape):
130171
@classmethod
131172
def help(cls): ## component help function
132173
properties = {
133-
174+
"synapse_type": "REINFORCESynapse - implements a stochastic synaptic cable that uses "
175+
"the REINFORCE algorithm (policy gradient) to update weights based on rewards"
134176
}
135177
compartment_props = {
136-
178+
"inputs":
179+
{"inputs": "Takes in external input signal values",
180+
"rewards": "Takes in reward signals for modulating weight updates. The reward is often normalized by baseline reward (r - r_hat)"},
181+
"states":
182+
{"weights": "Synapse efficacy/strength parameter values (mean and log-std)",
183+
"dWeights": "Weight update values",
184+
"accumulated_gradients": "EMA of gradients over time",
185+
"step_count": "Counter for learning steps",
186+
"learning_mask": "Binary mask determining when learning occurs",
187+
"seed": "a single integer as initial jax PRNG key for this component"},
188+
"outputs":
189+
{"outputs": "Output samples from Gaussian distribution",
190+
"objective": "Current value of the loss/objective function"},
137191
}
138192
hyperparams = {
139-
193+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
194+
"eta": "Learning rate for weight updates",
195+
"decay": "Decay factor for EMA of gradients",
196+
"weight_init": "Initialization conditions for synaptic weight values",
197+
"resist_scale": "Resistance level scaling factor applied to output",
198+
"act_fx": "Activation function to apply to inputs",
199+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
200+
"w_bound": "Upper bound for weight clipping",
201+
"batch_size": "Batch size dimension of this component",
202+
"seed": "Random seed for reproducibility"
140203
}
141204
info = {cls.__name__: properties,
142205
"compartments": compartment_props,
143-
# "dynamics": "outputs = [(W * Rscale) * inputs] ;"
144-
# "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
206+
"dynamics": "mean = act_fx(inputs) @ W_mu; logstd = act_fx(inputs) @ W_logstd; "
207+
"outputs ~ N(mean, exp(logstd)); "
208+
"dW = -grad_reinforce(rewards, log_prob(outputs)). ",
209+
"Check compute_update() for more details."
145210
"hyperparameters": hyperparams}
146211
return info
147212

0 commit comments

Comments
 (0)