 from ngclearn.utils.model_utils import create_function
 
 def gaussian_logpdf(event, mean, stddev):
-    scale_sqrd = stddev ** 2
-    log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
-    quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean) ** 2 / scale_sqrd
-    return -0.5 * (log_normalizer + quadratic)
+    # scale_sqrd = stddev ** 2
+    # log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
+    # quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean)**2 / scale_sqrd
+    # return -0.5 * (log_normalizer + quadratic)
+    return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(stddev) - 0.5 * ((jax.lax.stop_gradient(event - 2 * mean) + mean) / stddev) ** 2
 
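The new one-line return is algebraically the same Gaussian log-density as the commented-out version, but the stop_gradient(event - 2 * mean) + mean construction flips the sign of the gradient with respect to mean. A minimal standalone check of that sign flip against the plain log-pdf, using made-up scalar values (not part of the commit above):

# Standalone sanity check (hypothetical values; not part of the commit).
import jax
import jax.numpy as jnp

def straight_through_logpdf(event, mean, stddev):
    z = (jax.lax.stop_gradient(event - 2 * mean) + mean) / stddev
    return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(stddev) - 0.5 * z ** 2

def plain_logpdf(event, mean, stddev):
    z = (event - mean) / stddev
    return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(stddev) - 0.5 * z ** 2

g_st = jax.grad(straight_through_logpdf, argnums=1)(1.5, 0.2, 0.7)
g_plain = jax.grad(plain_logpdf, argnums=1)(1.5, 0.2, 0.7)
print(g_st, -g_plain)  # equal: the mean-gradient comes out negated
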
 class REINFORCESynapse(DenseSynapse):
 
     # Define Functions
     def __init__(
         self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
-        p_conn=1., w_bound=1., batch_size=1, **kwargs
+        p_conn=1., w_bound=1., batch_size=1, seed=None, **kwargs
     ):
         # This is because we have weights mu and weight log sigma
         input_dim, output_dim = shape
@@ -46,48 +47,61 @@ def __init__(
         self.decay = decay
         self.step_count = Compartment(jnp.zeros(()))
         self.learning_mask = Compartment(jnp.zeros(()))
+        # self.seed = Component(jnp.array(seed) if seed is not None else jnp.array(42, dtype=jnp.int32))
+        self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
 
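The integer seed argument is turned into a JAX PRNG key once at construction and stored in a compartment so it can be threaded through the update below. A standalone sketch of the same fallback pattern (hypothetical caller values, not ngc-learn API):

# Standalone sketch of the key construction (not part of the commit).
import jax
seed = None                                                  # caller passed no seed
key = jax.random.PRNGKey(seed if seed is not None else 42)   # falls back to the default 42
print(key)
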
     @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights):
+    def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
         W_mu, W_logstd = jnp.split(weights, 2, axis=-1)  # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
         # Forward pass
         activation = act_fx(inputs)
         mean = activation @ W_mu
         logstd = activation @ W_logstd
-        std = jnp.exp(logstd.clip(-10.0, 2.0))
+        clip_logstd = jnp.clip(logstd, -10.0, 2.0)
+        std = jnp.exp(clip_logstd)
         # Sample using reparameterization trick
-        epsilon = jnp.asarray(np.random.normal(0, 1, mean.shape))
+        epsilon = jax.random.normal(seed, mean.shape)
         sample = epsilon * std + mean
         outputs = sample  # the actual action that we take
         # Compute log probability density of the Gaussian
-        log_prob = gaussian_logpdf(sample, mean, std)
-        log_prob = log_prob.sum(-1)
+        log_prob = gaussian_logpdf(sample, mean, std).sum(-1)
         # Compute objective (negative REINFORCE objective)
         objective = (-log_prob * rewards).mean() * 1e-2
+
         # Backward pass
+        batch_size = inputs.shape[0]  # B
+        dL_dlogp = -rewards[:, None] * 1e-2 / batch_size  # (B, 1)
+
         # Compute gradients manually based on the derivation
         # dL/dmu = -(r - r_hat) * dlog_prob/dmu = -(r - r_hat) * -(sample - mu)/sigma^2
         # -(sample - mean) instead of (sample - mean) because we use a straight-through gradient in the log_prob function;
         # the stop_gradient-wrapped term there does not contribute to the gradient, only the outer "+ mean" path does
         dlog_prob_dmean = -(sample - mean) / (std ** 2)
+        dL_dmean = dL_dlogp * dlog_prob_dmean  # (B, A)
+        dL_dWmu = activation.T @ dL_dmean
+
         # dL/dlog(sigma) = -(r - r_hat) * dlog_prob/dlog(sigma) = -(r - r_hat) * (((sample - mu)/sigma)^2 - 1)
-        dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
-        # Compute gradients with respect to weights
-        # Using chain rule: dL/dW_mu = dL/dmu * dmu/dW_mu = dL/dmu * activation^T
-        # Similarly for W_logstd
-        # Gradient ascent instead of descent
-        dL_dWmu = activation.T @ (rewards[:, None] * dlog_prob_dmean) * 1e-2
-        dL_dWlstd = activation.T @ (rewards[:, None] * dlog_prob_dlogstd) * 1e-2
-        # Update weights
-        dW = jnp.concatenate([dL_dWmu, dL_dWlstd], axis=-1)
+        dlog_prob_dlogstd = (sample - mean) ** 2 / std ** 3 - 1.0 / std
+        dL_dstd = dL_dlogp * dlog_prob_dlogstd
+        # Apply gradient clipping for logstd
+        dL_dlogstd = jnp.where(
+            (logstd <= -10.0) | (logstd >= 2.0),
+            0.0,  # Zero gradient when clipped
+            dL_dstd * std
+        )
+        dL_dWlogstd = activation.T @ dL_dlogstd  # (I, B) @ (B, A) = (I, A)
+
+        # Update weights; negate the gradient because ngc-learn applies gradient ascent
+        dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
         # Finally, return metrics if needed
        return dW, objective, outputs
 
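The manual backward pass leans on the Gaussian score-function identities quoted in the comments, d log N/d mu = (x - mu)/sigma^2 and d log N/d log sigma = ((x - mu)/sigma)^2 - 1; the added code computes the derivative with respect to sigma and multiplies by sigma afterwards (dL_dstd * std), which is the same quantity. A standalone check of both identities against jax.grad of the plain log-pdf, with made-up scalar values (not part of the commit):

# Standalone check of the score-function identities (hypothetical values; not part of the commit).
import jax
import jax.numpy as jnp

def gauss_logpdf_ref(x, mean, logstd):
    std = jnp.exp(logstd)
    return -0.5 * jnp.log(2 * jnp.pi) - logstd - 0.5 * ((x - mean) / std) ** 2

x, mean, logstd = 0.8, -0.3, 0.5
std = jnp.exp(logstd)

d_mean, d_logstd = jax.grad(gauss_logpdf_ref, argnums=(1, 2))(x, mean, logstd)
print(d_mean, (x - mean) / std ** 2)             # matches (x - mu) / sigma^2
print(d_logstd, ((x - mean) / std) ** 2 - 1.0)   # matches ((x - mu) / sigma)^2 - 1
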
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count"])
+    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed):
+        main_seed, sub_seed = jax.random.split(seed)
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights
+            dt, inputs, rewards, act_fx, weights, sub_seed
         )
         ## do a gradient ascent update/shift
         weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
@@ -97,9 +111,9 @@ def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, de
         step_count += 1
         accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights  # EMA update of accumulated gradients
         step_count = step_count * (1 - learning_mask)  # reset the step count to 0 when we have learned
-        return weights, dWeights, objective, outputs, accumulated_gradients, step_count
+        return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed
 
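evolve now splits the stored key each step, hands one half to _compute_update for the jax.random.normal draw and returns the other half to the seed compartment, so consecutive calls draw fresh noise without the hidden np.random state of the old code; the accumulated_gradients line above meanwhile keeps a decayed running average of dWeights across those steps. A standalone sketch of the split-and-carry key pattern (made-up shapes, not ngc-learn API):

# Standalone sketch of the split-and-carry PRNG pattern (hypothetical shapes; not part of the commit).
import jax

key = jax.random.PRNGKey(42)
for step in range(3):
    key, sub = jax.random.split(key)       # carry `key`, consume `sub`
    eps = jax.random.normal(sub, (2, 4))   # fresh noise each step
    print(step, float(eps.sum()))
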
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count"])
+    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
     @staticmethod
     def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
@@ -111,7 +125,8 @@ def reset(batch_size, shape):
         dWeights = jnp.zeros(shape)
         accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
         step_count = jnp.zeros(())
-        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count
+        seed = jax.random.PRNGKey(42)
+        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed
 
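Note that reset re-initializes the seed compartment to a fixed jax.random.PRNGKey(42) rather than re-using the constructor's seed, so the post-reset noise sequence is identical across runs and instances. A standalone illustration (hypothetical, not part of the commit):

# Two keys built from the same constant produce identical draws.
import jax
k1 = jax.random.PRNGKey(42)
k2 = jax.random.PRNGKey(42)
print(bool((jax.random.normal(k1, (3,)) == jax.random.normal(k2, (3,))).all()))  # True
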
     @classmethod
     def help(cls):  ## component help function