Skip to content

Commit e302af2

Browse files
committed
Update REINFORCESynapse to use the clip and d_clip helpers from ngclearn.utils.model_utils
1 parent 24963e5 commit e302af2

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
2 2
from ngcsimlib.compilers.process import transition
3 3
from ngcsimlib.component import Component
4 4
from ngcsimlib.compartment import Compartment
5+
from ngclearn.utils.model_utils import clip, d_clip
5 6
import jax
6 7
import jax.numpy as jnp
7 8
import numpy as np
@@ -110,7 +111,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
110 111
mean = activation @ W_mu
111 112
fx_mean = mu_act_fx(mean)
112 113
logstd = activation @ W_logstd
113-
clip_logstd = jnp.clip(logstd, -10.0, 2.0)
114+
clip_logstd = clip(logstd, -10.0, 2.0)
114 115
std = jnp.exp(clip_logstd)
115 116
std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick
116 117
# Sample using reparameterization trick
@@ -137,11 +138,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
137 138
dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3
138 139
dL_dstd = dL_dlogp * dlog_prob_dlogstd
139 140
# Apply gradient clipping for logstd
140-
dL_dlogstd = jnp.where(
141-
(logstd <= -10.0) | (logstd >= 2.0),
142-
0.0, # Zero gradient when clipped
143-
dL_dstd * std
144-
)
141+
dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
145 142
dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
146 143
dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev
147 144

0 commit comments

Comments
 (0)