Commit 68d2435

update code and test
1 parent 701b501 · commit 68d2435

File tree: 2 files changed (+11, −5 lines)

ngclearn/components/synapses/modulated/REINFORCESynapse.py
7 additions, 4 deletions

@@ -67,7 +67,7 @@ class REINFORCESynapse(DenseSynapse):
     # Define Functions
     def __init__(
         self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
-        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, **kwargs
+        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, mu_out_min=-jnp.inf, mu_out_max=jnp.inf, **kwargs
     ) -> None:
         # This is because we have weights mu and weight log sigma
         input_dim, output_dim = shape
@@ -82,6 +82,8 @@ def __init__(
         # self.out_min = out_min
         # self.out_max = out_max
         self.mu_act_fx, self.dmu_act_fx = create_function(mu_act_fx if mu_act_fx is not None else "identity")
+        self.mu_out_min = mu_out_min
+        self.mu_out_max = mu_out_max

         ## Compartment setup
         self.dWeights = Compartment(self.weights.value * 0)
@@ -97,7 +99,7 @@ def __init__(
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))

     @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):
+    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
         # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
         W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
         # Forward pass
@@ -110,6 +112,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):
         # Sample using reparameterization trick
         epsilon = jax.random.normal(seed, fx_mean.shape)
         sample = epsilon * std + fx_mean
+        sample = jnp.clip(sample, mu_out_min, mu_out_max)
         outputs = sample # the actual action that we take
         # Compute log probability density of the Gaussian
         log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
@@ -144,10 +147,10 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):

     @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
         main_seed, sub_seed = jax.random.split(seed)
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx
+            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max
         )
         ## do a gradient ascent update/shift
         weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0
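
What the change does: the new mu_out_min/mu_out_max arguments bound the emitted action. The reparameterized Gaussian sample is clipped before it is stored in outputs and before its log-density is scored, and the infinite defaults make the clip a no-op, preserving prior behavior. The sketch below is a minimal, standalone rendering of this forward pass; the helper name sample_bounded_gaussian and the toy shapes are illustrative, not part of the ngclearn API.

import jax
import jax.numpy as jnp

def sample_bounded_gaussian(key, x, W_mu, W_logstd,
                            mu_out_min=-jnp.inf, mu_out_max=jnp.inf):
    # Linear readouts for the Gaussian policy's mean and log-std
    mean = x @ W_mu
    logstd = x @ W_logstd
    std = jnp.exp(jnp.clip(logstd, -10.0, 2.0))  # log-std clipped as in the test's reference fn
    # Reparameterization trick: sample = epsilon * std + mean
    epsilon = jax.random.normal(key, mean.shape)
    sample = epsilon * std + mean
    # New in this commit: clip the sample to [mu_out_min, mu_out_max]
    return jnp.clip(sample, mu_out_min, mu_out_max)

key = jax.random.PRNGKey(0)
x = jnp.ones((1, 3))
W_mu, W_logstd = jnp.zeros((3, 2)), jnp.zeros((3, 2))
action = sample_bounded_gaussian(key, x, W_mu, W_logstd,
                                 mu_out_min=-1.0, mu_out_max=1.0)
print(action.shape)  # (1, 2); every entry lies in [-1, 1]

One nuance worth noting: the patched _compute_update scores the clipped sample under the unclipped Gaussian density. That is exact for the infinite defaults and a common simplification for finite bounds, where it amounts to treating clipping as part of the environment rather than the policy.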

tests/components/synapses/modulated/test_REINFORCESynapse.py
4 additions, 1 deletion

@@ -26,13 +26,15 @@ def test_REINFORCESynapse1():
     dt = 1. # ms
     decay = 0.99
     initial_seed = 42
+    mu_out_min = -jnp.inf
+    mu_out_max = jnp.inf

     # ---- build a simple Poisson cell system ----
     with Context(name) as ctx:
         a = REINFORCESynapse(
             name="a", shape=(1,1), decay=decay,
             act_fx="tanh", key=subkeys[0], seed=initial_seed,
-            mu_act_fx="tanh"
+            mu_act_fx="tanh", mu_out_min=mu_out_min, mu_out_max=mu_out_max
         )

         evolve_process = (Process("evolve_proc") >> a.evolve)
@@ -64,6 +66,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
            logstd = activation @ W_logstd
            std = jnp.exp(logstd.clip(-10.0, 2.0))
            sample = jax.random.normal(seed, mean.shape) * std + mean
+           sample = jnp.clip(sample, mu_out_min, mu_out_max)
            logp = gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
            return (-logp * outputs).mean() * 1e-2
        grad_fn = jax.value_and_grad(fn)
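
The test derives its reference gradient with autodiff rather than by hand: wrap the negative reward-weighted log-likelihood in a pure function, stop the gradient through the sampled action, and let jax.value_and_grad produce the gradients that the component's dWeights must match. Below is a condensed, self-contained sketch of that pattern, assuming tanh activations as in the test's act_fx; surrogate_loss, the toy shapes, and the reward argument are illustrative, and the test's extra 1e-2 scaling is omitted.

import jax
import jax.numpy as jnp
from jax.scipy.stats.norm import logpdf as gaussian_logpdf

def surrogate_loss(params, inputs, reward, seed,
                   mu_out_min=-jnp.inf, mu_out_max=jnp.inf):
    # Forward pass of the Gaussian policy
    activation = jnp.tanh(inputs)
    mean = activation @ params["W_mu"]
    logstd = activation @ params["W_logstd"]
    std = jnp.exp(logstd.clip(-10.0, 2.0))
    sample = jax.random.normal(seed, mean.shape) * std + mean
    sample = jnp.clip(sample, mu_out_min, mu_out_max)
    # Score-function (REINFORCE) trick: treat the action as a constant so
    # autodiff recovers the likelihood-ratio gradient w.r.t. mean and std
    logp = gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
    return (-logp * reward).mean()

params = {"W_mu": jnp.zeros((1, 1)), "W_logstd": jnp.zeros((1, 1))}
loss, grads = jax.value_and_grad(surrogate_loss)(
    params, jnp.ones((1, 1)), 1.0, jax.random.PRNGKey(42))
# grads["W_mu"] and grads["W_logstd"] serve as the reference for dWeights

Because the clip is applied before stop_gradient, it only changes which action gets scored; the gradient still flows through mean and std via the log-density, exactly as in the unclipped case.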
