@@ -67,7 +67,8 @@ class REINFORCESynapse(DenseSynapse):
6767 # Define Functions
6868 def __init__ (
6969 self , name , shape , eta = 1e-4 , decay = 0.99 , weight_init = None , resist_scale = 1. , act_fx = None ,
70- p_conn = 1. , w_bound = 1. , batch_size = 1 , seed = None , mu_act_fx = None , mu_out_min = - jnp .inf , mu_out_max = jnp .inf , ** kwargs
70+ p_conn = 1. , w_bound = 1. , batch_size = 1 , seed = None , mu_act_fx = None , mu_out_min = - jnp .inf , mu_out_max = jnp .inf ,
71+ scalar_stddev = - 1.0 , ** kwargs
7172 ) -> None :
7273 # The weight matrix is doubled along the output axis: it stores both the mean (mu) weights and the log-sigma weights
7374 input_dim , output_dim = shape
@@ -84,6 +85,7 @@ def __init__(
8485 self .mu_act_fx , self .dmu_act_fx = create_function (mu_act_fx if mu_act_fx is not None else "identity" )
8586 self .mu_out_min = mu_out_min
8687 self .mu_out_max = mu_out_max
88+ self .scalar_stddev = scalar_stddev
8789
8890 ## Compartment setup
8991 self .dWeights = Compartment (self .weights .value * 0 )
@@ -99,7 +101,8 @@ def __init__(
99101 self .seed = Compartment (jax .random .PRNGKey (seed if seed is not None else 42 ))
100102
101103 @staticmethod
102- def _compute_update (dt , inputs , rewards , act_fx , weights , seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max ):
104+ def _compute_update (dt , inputs , rewards , act_fx , weights , seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max , scalar_stddev ):
105+ learning_stddev_mask = jnp .asarray (scalar_stddev <= 0.0 , dtype = jnp .float32 )
103106 # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
104107 W_mu , W_logstd = jnp .split (weights , 2 , axis = - 1 )
105108 # Forward pass
@@ -109,6 +112,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
109112 logstd = activation @ W_logstd
110113 clip_logstd = jnp .clip (logstd , - 10.0 , 2.0 )
111114 std = jnp .exp (clip_logstd )
115+ std = learning_stddev_mask * std + (1.0 - learning_stddev_mask ) * scalar_stddev # use the learned std when scalar_stddev <= 0, otherwise the fixed scalar value
112116 # Sample using reparameterization trick
113117 epsilon = jax .random .normal (seed , fx_mean .shape )
114118 sample = epsilon * std + fx_mean
@@ -139,6 +143,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
139143 dL_dstd * std
140144 )
141145 dL_dWlogstd = activation .T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
146+ dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # zero the log-std gradient when a fixed scalar stddev is in use
142147
143148 # Negate the gradients because ngc-learn applies updates via gradient ascent
144149 dW = jnp .concatenate ([- dL_dWmu , - dL_dWlogstd ], axis = - 1 )
@@ -147,10 +152,10 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
147152
148153 @transition (output_compartments = ["weights" , "dWeights" , "objective" , "outputs" , "accumulated_gradients" , "step_count" , "seed" ])
149154 @staticmethod
150- def evolve (dt , w_bound , inputs , rewards , act_fx , weights , eta , learning_mask , decay , accumulated_gradients , step_count , seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max ):
155+ def evolve (dt , w_bound , inputs , rewards , act_fx , weights , eta , learning_mask , decay , accumulated_gradients , step_count , seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max , scalar_stddev ):
151156 main_seed , sub_seed = jax .random .split (seed )
152157 dWeights , objective , outputs = REINFORCESynapse ._compute_update (
153- dt , inputs , rewards , act_fx , weights , sub_seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max
158+ dt , inputs , rewards , act_fx , weights , sub_seed , mu_act_fx , dmu_act_fx , mu_out_min , mu_out_max , scalar_stddev
154159 )
155160 ## do a gradient ascent update/shift
156161 weights = (weights + dWeights * eta ) * learning_mask + weights * (1.0 - learning_mask ) # update the weights only where learning_mask is 1.0
0 commit comments