 from ngcsimlib.compilers.process import transition
 from ngcsimlib.component import Component
 from ngcsimlib.compartment import Compartment
+from ngclearn.utils.model_utils import clip, d_clip
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -110,7 +111,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
     mean = activation @ W_mu
     fx_mean = mu_act_fx(mean)
     logstd = activation @ W_logstd
-    clip_logstd = jnp.clip(logstd, -10.0, 2.0)
+    clip_logstd = clip(logstd, -10.0, 2.0)
     std = jnp.exp(clip_logstd)
     std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev  # masking trick
     # Sample using reparameterization trick
@@ -137,11 +138,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
     dlog_prob_dlogstd = -1.0 / std + (sample - fx_mean) ** 2 / std ** 3
     dL_dstd = dL_dlogp * dlog_prob_dlogstd
     # Apply gradient clipping for logstd
-    dL_dlogstd = jnp.where(
-        (logstd <= -10.0) | (logstd >= 2.0),
-        0.0,  # Zero gradient when clipped
-        dL_dstd * std
-    )
+    dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
     dL_dWlogstd = activation.T @ dL_dlogstd  # (I, B) @ (B, A) = (I, A)
     dL_dWlogstd = dL_dWlogstd * learning_stddev_mask  # there is no learning for the scalar stddev
 
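For reference, this refactor is behavior-preserving under a natural reading of the two helpers: `d_clip(x, lo, hi)` is the elementwise derivative of `clip`, i.e. 1 where `lo < x < hi` and 0 where the value is pinned at a bound, so multiplying by it reproduces the zero-gradient masking that the removed `jnp.where` expression performed explicitly. A minimal sketch of the equivalence (the helper bodies below are assumptions for illustration, not ngclearn's actual `model_utils` implementations):

```python
import jax.numpy as jnp

def clip(x, lo, hi):
    # Assumed behavior: elementwise clamp, same as jnp.clip.
    return jnp.clip(x, lo, hi)

def d_clip(x, lo, hi):
    # Assumed behavior: derivative of clip w.r.t. x -- 1.0 strictly
    # inside (lo, hi), 0.0 where the output is pinned at a bound.
    return ((x > lo) & (x < hi)).astype(jnp.float32)

logstd = jnp.array([-12.0, -3.0, 0.0, 3.0])
dL_dstd = jnp.ones_like(logstd)
std = jnp.exp(clip(logstd, -10.0, 2.0))

# Old formulation (removed in the diff): zero the gradient where clipped.
old = jnp.where((logstd <= -10.0) | (logstd >= 2.0), 0.0, dL_dstd * std)
# New formulation: the clip derivative acts as the same 0/1 mask.
new = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
assert jnp.allclose(old, new)
```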