 from ngclearn.utils import tensorstats
 from ngclearn.utils.model_utils import create_function

+def gaussian_logpdf(event, mean, stddev):
+    scale_sqrd = stddev ** 2
+    log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
+    quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean) ** 2 / scale_sqrd
+    return -0.5 * (log_normalizer + quadratic)

 class REINFORCESynapse(DenseSynapse):

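As a standalone illustration (outside the patch itself): the stop_gradient wrapping in the new gaussian_logpdf helper leaves the log-density value identical to the ordinary Gaussian log-pdf that the removed line further down computes, but it flips the sign of the gradient with respect to the mean, which is what the revised dlog_prob_dmean line below relies on. A minimal sketch, assuming scalar inputs, that checks this with jax.grad:

import jax
import jax.numpy as jnp

def gaussian_logpdf(event, mean, stddev):
    # helper as added above: stop_gradient(event - 2*mean) + mean evaluates to event - mean,
    # but only the trailing "+ mean" term carries gradient with respect to mean
    scale_sqrd = stddev ** 2
    log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
    quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean) ** 2 / scale_sqrd
    return -0.5 * (log_normalizer + quadratic)

def standard_logpdf(event, mean, stddev):
    # ordinary Gaussian log-density (the expression being replaced in _compute_update)
    return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(stddev) - 0.5 * ((event - mean) / stddev) ** 2

x, mu, sigma = 1.5, 0.5, 2.0
print(gaussian_logpdf(x, mu, sigma), standard_logpdf(x, mu, sigma))  # same value, approx -1.737
print(jax.grad(gaussian_logpdf, argnums=1)(x, mu, sigma))  # -(x - mu) / sigma**2 = -0.25
print(jax.grad(standard_logpdf, argnums=1)(x, mu, sigma))  #  (x - mu) / sigma**2 =  0.25
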
@@ -39,6 +44,8 @@ def __init__(
         # self.seed = Component(seed)
         self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
         self.decay = decay
+        self.step_count = Compartment(jnp.zeros(()))
+        self.learning_mask = Compartment(jnp.zeros(()))

     @staticmethod
     def _compute_update(dt, inputs, rewards, act_fx, weights):
@@ -53,41 +60,44 @@ def _compute_update(dt, inputs, rewards, act_fx, weights):
         sample = epsilon * std + mean
         outputs = sample  # the actual action that we take
         # Compute log probability density of the Gaussian
-        log_prob = -0.5 * jnp.log(2 * jnp.pi) - logstd - 0.5 * ((sample - mean) / std) ** 2
+        log_prob = gaussian_logpdf(sample, mean, std)
         log_prob = log_prob.sum(-1)
         # Compute objective (negative REINFORCE objective)
         objective = (-log_prob * rewards).mean() * 1e-2
         # Backward pass
         # Compute gradients manually based on the derivation
-        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * (sample-mu)/sigma^2
-        dlog_prob_dmean = (sample - mean) / (std ** 2)
+        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
+        dlog_prob_dmean = -(sample - mean) / (std ** 2)
         # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
         dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
         # Compute gradients with respect to weights
         # Using chain rule: dL/dW_mu = dL/dmu * dmu/dW_mu = dL/dmu * activation^T
         # Similarly for W_logstd
-        dL_dWmu = activation.T @ (-rewards[:, None] * dlog_prob_dmean) * 1e-2
-        dL_dWlstd = activation.T @ (-rewards[:, None] * dlog_prob_dlogstd) * 1e-2
+        # Gradient ascent instead of descent
+        dL_dWmu = activation.T @ (rewards[:, None] * dlog_prob_dmean) * 1e-2
+        dL_dWlstd = activation.T @ (rewards[:, None] * dlog_prob_dlogstd) * 1e-2
         # Update weights
         dW = jnp.concatenate([dL_dWmu, dL_dWlstd], axis=-1)
         # Finally, return metrics if needed
         return dW, objective, outputs

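As a standalone check (outside the patch itself): the dL_dWmu and dL_dWlstd lines above apply the usual linear-layer chain rule, dL/dW = activation^T @ dL/d(output). A short sketch with made-up shapes (batch 4, input_dim 3, output_dim 2), where g_out stands in for a per-sample term such as rewards[:, None] * dlog_prob_dmean:

import jax
import jax.numpy as jnp

activation = jax.random.normal(jax.random.PRNGKey(0), (4, 3))   # batch of hidden activities
W = jax.random.normal(jax.random.PRNGKey(1), (3, 2))            # weights mapping 3 -> 2
g_out = jax.random.normal(jax.random.PRNGKey(2), (4, 2))        # pretend dL/d(output) per sample

# scalar surrogate loss whose gradient with respect to the layer output is exactly g_out
loss = lambda W_: jnp.sum((activation @ W_) * g_out)
auto_grad = jax.grad(loss)(W)
manual_grad = activation.T @ g_out                               # the form used in _compute_update
print(jnp.allclose(auto_grad, manual_grad))                      # True
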
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients"])
+    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, decay, accumulated_gradients):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count):
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
             dt, inputs, rewards, act_fx, weights
         )
         ## do a gradient ascent update/shift
-        weights = weights + dWeights * eta
+        weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
         ## enforce non-negativity
-        eps = 0.01  # 0.001
+        eps = 0.0  # 0.01 # 0.001
         weights = jnp.clip(weights, eps, w_bound - eps)  # jnp.abs(w_bound))
-        accumulated_gradients = accumulated_gradients * decay + dWeights
-        return weights, dWeights, objective, outputs, accumulated_gradients
+        step_count += 1
+        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights  # EMA update of accumulated gradients
+        step_count = step_count * (1 - learning_mask)  # reset the step count to 0 when we have learned
+        return weights, dWeights, objective, outputs, accumulated_gradients, step_count

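As a standalone sketch (outside the patch itself): the evolve update above keeps a decayed running average of dWeights in accumulated_gradients and zeroes step_count whenever learning_mask is 1.0, so the average restarts after each masked learning step. The loop and constant gradient values below are purely illustrative:

import jax.numpy as jnp

def accumulate(accumulated_gradients, step_count, dWeights, decay, learning_mask):
    # mirrors the bookkeeping in evolve: running (decayed) mean of the gradients,
    # with the counter reset once a learning step has been applied
    step_count = step_count + 1
    accumulated_gradients = ((step_count - 1) / step_count * accumulated_gradients * decay
                             + 1.0 / step_count * dWeights)
    step_count = step_count * (1 - learning_mask)
    return accumulated_gradients, step_count

acc, n = jnp.zeros((2, 4)), jnp.zeros(())
for t, mask in enumerate([0.0, 0.0, 1.0]):  # learn on the third step
    acc, n = accumulate(acc, n, jnp.ones((2, 4)) * (t + 1), decay=1.0, learning_mask=mask)
print(acc)  # with decay=1.0, the running mean of the three dWeights: 2.0 everywhere
print(n)    # 0.0, the counter was reset by learning_mask on the last step
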
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients"])
+    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count"])
     @staticmethod
     def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
@@ -98,7 +108,8 @@ def reset(batch_size, shape):
         rewards = jnp.zeros((batch_size,))
         dWeights = jnp.zeros(shape)
         accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
-        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients
+        step_count = jnp.zeros(())
+        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count

     @classmethod
     def help(cls):  ## component help function