Commit 0d7c24b

working reinforce synapse
1 parent cac5207 commit 0d7c24b

File tree

1 file changed

+116 -66 lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 116 additions & 66 deletions
@@ -1,7 +1,9 @@
+# %%
+
 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
+from ngcsimlib.logger import info
 from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 from ngclearn.utils.model_utils import clip, d_clip
 import jax
 import jax.numpy as jnp
@@ -17,11 +19,59 @@ def gaussian_logpdf(event, mean, stddev):
     quadratic = (event - mean)**2 / scale_sqrd
     return - 0.5 * (log_normalizer + quadratic)
 
+
+def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+    learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
+    # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
+    W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
+    # Forward pass
+    activation = act_fx(inputs)
+    mean = activation @ W_mu
+    fx_mean = mu_act_fx(mean)
+    logstd = activation @ W_logstd
+    clip_logstd = clip(logstd, -10.0, 2.0)
+    std = jnp.exp(clip_logstd)
+    std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick
+    # Sample using reparameterization trick
+    epsilon = jax.random.normal(seed, fx_mean.shape)
+    sample = epsilon * std + fx_mean
+    sample = jnp.clip(sample, mu_out_min, mu_out_max)
+    outputs = sample # the actual action that we take
+    # Compute log probability density of the Gaussian
+    log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
+    # Compute objective (negative REINFORCE objective)
+    objective = (-log_prob * rewards).mean() * 1e-2
+
+    # Backward pass
+    batch_size = inputs.shape[0] # B
+    dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1)
+
+    # Compute gradients manually based on the derivation
+    # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
+    dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2)
+    dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A)
+    dL_dWmu = activation.T @ dL_dmean
+
+    # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
+    dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3
+    dL_dstd = dL_dlogp * dlog_prob_dlogstd
+    # Apply gradient clipping for logstd
+    dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
+    dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
+    dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev
+
+    # Update weights, negate the gradient because gradient ascent in ngc-learn
+    dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
+    # Finally, return metrics if needed
+    return dW, objective, outputs
+
+
+
 class REINFORCESynapse(DenseSynapse):
     """
     A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse
     uses Gaussian distributions for generating actions and performs gradient-based updates.
-
+
     | --- Synapse Compartments: ---
     | inputs - input (takes in external signals)
     | outputs - output signals (sampled actions from Gaussian distribution)
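For reference, the hand-derived backward pass in _compute_update follows directly from differentiating the Gaussian log-density returned by gaussian_logpdf. A sketch of the algebra, writing a for the sampled action, mu for fx_mean and sigma for std (the same quantities named in the code comments):

    \log \mathcal{N}(a; \mu, \sigma) = -\tfrac{1}{2}\left[\log(2\pi\sigma^2) + \frac{(a-\mu)^2}{\sigma^2}\right]

    \frac{\partial \log \mathcal{N}}{\partial \mu} = \frac{a-\mu}{\sigma^2}, \qquad
    \frac{\partial \log \mathcal{N}}{\partial \sigma} = \frac{(a-\mu)^2}{\sigma^3} - \frac{1}{\sigma}, \qquad
    \frac{\partial \log \mathcal{N}}{\partial \log\sigma} = \sigma\,\frac{\partial \log \mathcal{N}}{\partial \sigma} = \frac{(a-\mu)^2}{\sigma^2} - 1

dlog_prob_dfxmean and dlog_prob_dlogstd compute the mu- and sigma-derivatives above; the extra factor of std applied (together with d_clip) when forming dL_dlogstd is the sigma that converts the sigma-derivative into the log-sigma derivative, i.e. the chain rule through logstd.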
@@ -89,7 +139,7 @@ def __init__(
         self.scalar_stddev = scalar_stddev
 
         ## Compartment setup
-        self.dWeights = Compartment(self.weights.value * 0)
+        self.dWeights = Compartment(self.weights.get() * 0)
         # self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate # For eligiblity traces later
         self.objective = Compartment(jnp.zeros(()))
         self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
@@ -101,72 +151,50 @@ def __init__(
         self.learning_mask = Compartment(jnp.zeros(()))
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
 
-    @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
-        learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
-        # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
-        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
-        # Forward pass
-        activation = act_fx(inputs)
-        mean = activation @ W_mu
-        fx_mean = mu_act_fx(mean)
-        logstd = activation @ W_logstd
-        clip_logstd = clip(logstd, -10.0, 2.0)
-        std = jnp.exp(clip_logstd)
-        std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick
-        # Sample using reparameterization trick
-        epsilon = jax.random.normal(seed, fx_mean.shape)
-        sample = epsilon * std + fx_mean
-        sample = jnp.clip(sample, mu_out_min, mu_out_max)
-        outputs = sample # the actual action that we take
-        # Compute log probability density of the Gaussian
-        log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
-        # Compute objective (negative REINFORCE objective)
-        objective = (-log_prob * rewards).mean() * 1e-2
-
-        # Backward pass
-        batch_size = inputs.shape[0] # B
-        dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1)
-
-        # Compute gradients manually based on the derivation
-        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
-        dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2)
-        dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A)
-        dL_dWmu = activation.T @ dL_dmean
-
-        # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
-        dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3
-        dL_dstd = dL_dlogp * dlog_prob_dlogstd
-        # Apply gradient clipping for logstd
-        dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
-        dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
-        dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev
-
-        # Update weights, negate the gradient because gradient ascent in ngc-learn
-        dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
-        # Finally, return metrics if needed
-        return dW, objective, outputs
-
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
-    @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+
+    # @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
+    # @staticmethod
+    @compilable
+    def evolve(self, dt):
+
+        # Get compartment values
+        weights = self.weights.get()
+        dWeights = self.dWeights.get()
+        objective = self.objective.get()
+        outputs = self.outputs.get()
+        accumulated_gradients = self.accumulated_gradients.get()
+        step_count = self.step_count.get()
+        seed = self.seed.get()
+        inputs = self.inputs.get()
+        rewards = self.rewards.get()
+
+        # Main logic
         main_seed, sub_seed = jax.random.split(seed)
-        dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
+        dWeights, objective, outputs = _compute_update(
+            dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx, self.mu_out_min, self.mu_out_max, self.scalar_stddev
         )
         ## do a gradient ascent update/shift
-        weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0
+        weights = (weights + dWeights * self.eta) * self.learning_mask + weights * (1.0 - self.learning_mask) # update the weights only where learning_mask is 1.0
         ## enforce non-negativity
         eps = 0.0 # 0.01 # 0.001
-        weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
+        weights = jnp.clip(weights, eps, self.w_bound - eps) # jnp.abs(w_bound))
         step_count += 1
-        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
-        step_count = step_count * (1 - learning_mask) # reset the step count to 0 when we have learned
-        return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed
+        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
+        step_count = step_count * (1 - self.learning_mask) # reset the step count to 0 when we have learned
+
+        # Set updated compartment values
+        self.weights.set(weights)
+        self.dWeights.set(dWeights)
+        self.objective.set(objective)
+        self.outputs.set(outputs)
+        self.accumulated_gradients.set(accumulated_gradients)
+        self.step_count.set(step_count)
+        self.seed.set(main_seed)
 
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
-    @staticmethod
-    def reset(batch_size, shape):
+    # @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
+    # @staticmethod
+    @compilable
+    def reset(self, batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         inputs = preVals
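Written out, the accumulated_gradients update in evolve is a decayed running average of the per-step weight updates; with k = step_count (after the increment), lambda = decay and Delta W_k = dWeights, it reads roughly

    \bar{g}_k = \frac{k-1}{k}\,\lambda\,\bar{g}_{k-1} + \frac{1}{k}\,\Delta W_k

so the buffer tracks a decay-weighted mean of the REINFORCE updates until step_count is reset by learning_mask.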
@@ -177,7 +205,17 @@ def reset(batch_size, shape):
         accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
         step_count = jnp.zeros(())
         seed = jax.random.PRNGKey(42)
-        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed
+
+
+        self.inputs.set(inputs)
+        self.outputs.set(outputs)
+        self.objective.set(objective)
+        self.rewards.set(rewards)
+        self.dWeights.set(dWeights)
+        self.accumulated_gradients.set(accumulated_gradients)
+        self.step_count.set(step_count)
+        self.seed.set(seed)
+
 
     @classmethod
     def help(cls): ## component help function
@@ -223,15 +261,27 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
             if stats is not None:
                 line = [f"{k}: {v}" for k, v in stats.items()]
                 line = ", ".join(line)
             else:
                 line = "None"
             lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
         return lines
+
+
+if __name__ == '__main__':
+    from ngcsimlib.context import Context
+    with Context("Bar") as bar:
+        syn = REINFORCESynapse(
+            name="reinforce_syn",
+            shape=(3, 2)
+        )
+        # Wab = syn.weights.get()
+        print(syn)
+

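Because the update in _compute_update is derived by hand rather than taken from autodiff, a quick way to sanity-check it is to compare it against jax.grad of the same scalar objective. The snippet below is a minimal standalone sketch, not part of the commit: it assumes an identity act_fx, a tanh mu_act_fx and illustrative shapes, and it omits the output clipping and scalar-stddev masking that the component adds on top.

import jax
import jax.numpy as jnp

in_dim, out_dim, batch = 3, 2, 4
k_w, k_x, k_r, k_s = jax.random.split(jax.random.PRNGKey(0), 4)
weights = 0.1 * jax.random.normal(k_w, (in_dim, 2 * out_dim))  # [W_mu | W_logstd]
inputs = jax.random.normal(k_x, (batch, in_dim))
rewards = jax.random.normal(k_r, (batch,))

def objective(weights, seed):
    W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
    act = inputs                                  # act_fx assumed to be identity here
    mean = jnp.tanh(act @ W_mu)                   # mu_act_fx assumed to be tanh here
    std = jnp.exp(jnp.clip(act @ W_logstd, -10.0, 2.0))
    eps = jax.random.normal(seed, mean.shape)
    # REINFORCE treats the sampled action as a constant, so block gradients through it
    sample = jax.lax.stop_gradient(eps * std + mean)
    log_prob = -0.5 * (jnp.log(2 * jnp.pi * std**2) + (sample - mean)**2 / std**2)
    return (-log_prob.sum(-1) * rewards).mean() * 1e-2

# The commit builds dW from -dL_dWmu and -dL_dWlogstd, i.e. the negated gradient of this
# objective, so -jax.grad(objective) should agree with it under matching settings.
autodiff_grad = jax.grad(objective)(weights, k_s)
print(autodiff_grad.shape)  # (in_dim, 2 * out_dim)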