@@ -67,7 +67,7 @@ class REINFORCESynapse(DenseSynapse):
     # Define Functions
     def __init__(
         self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
-        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, **kwargs
+        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, mu_out_min=-jnp.inf, mu_out_max=jnp.inf, **kwargs
     ) -> None:
         # This is because we have weights mu and weight log sigma
         input_dim, output_dim = shape
@@ -82,6 +82,8 @@ def __init__(
         # self.out_min = out_min
         # self.out_max = out_max
         self.mu_act_fx, self.dmu_act_fx = create_function(mu_act_fx if mu_act_fx is not None else "identity")
+        self.mu_out_min = mu_out_min
+        self.mu_out_max = mu_out_max
 
         ## Compartment setup
         self.dWeights = Compartment(self.weights.value * 0)
@@ -97,7 +99,7 @@ def __init__(
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
 
     @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):
+    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
         # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
         W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
         # Forward pass
@@ -110,6 +112,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):
         # Sample using reparameterization trick
         epsilon = jax.random.normal(seed, fx_mean.shape)
         sample = epsilon * std + fx_mean
+        sample = jnp.clip(sample, mu_out_min, mu_out_max)
         outputs = sample  # the actual action that we take
         # Compute log probability density of the Gaussian
         log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
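The added jnp.clip line bounds the reparameterized sample before it is emitted as the action. A minimal standalone sketch of that sampling-and-clipping step (plain JAX; the mean, standard deviation, and bound values below are illustrative placeholders, not taken from the component):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
fx_mean = jnp.zeros((1, 4))          # post-activation mean of the Gaussian policy
std = jnp.ones((1, 4)) * 0.5         # standard deviation (exp of the log-std weights)
mu_out_min, mu_out_max = -1.0, 1.0   # assumed action bounds

# Reparameterization trick: sample = mean + std * eps, with eps ~ N(0, I)
epsilon = jax.random.normal(key, fx_mean.shape)
sample = epsilon * std + fx_mean
# Clamp the action into [mu_out_min, mu_out_max], as the new diff line does
sample = jnp.clip(sample, mu_out_min, mu_out_max)

Note that in the hunk above the log-probability is evaluated on the clipped sample under the unclipped Gaussian, so at the bounds the REINFORCE gradient is taken with respect to the clamped action.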
@@ -144,10 +147,10 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx):
 
     @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
         main_seed, sub_seed = jax.random.split(seed)
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx
+            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max
         )
         ## do a gradient ascent update/shift
         weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
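With these changes, action bounds can be set at construction time. A hedged usage sketch based only on the signature shown above (the import path, shape, and bound values are assumptions, not part of this patch):

import jax.numpy as jnp
# Assumed import path; adjust to wherever REINFORCESynapse lives in your codebase.
from reinforce_synapse import REINFORCESynapse

syn = REINFORCESynapse(
    name="policy",      # component name
    shape=(8, 2),       # (input_dim, output_dim); weights hold both mu and log-std
    eta=1e-4,           # learning rate for the gradient-ascent update
    mu_out_min=-1.0,    # lower bound applied to sampled actions
    mu_out_max=1.0,     # upper bound applied to sampled actions
)

Because the new arguments default to -jnp.inf and jnp.inf, existing code that omits them keeps the previous unbounded behavior.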