@@ -15,7 +15,7 @@ class REINFORCESynapse(DenseSynapse):
1515
1616 # Define Functions
1717 def __init__ (
18- self , name , shape , eta = 1e-4 , weight_init = None , resist_scale = 1. , act_fx = None ,
18+ self , name , shape , eta = 1e-4 , decay = 0.99 , weight_init = None , resist_scale = 1. , act_fx = None ,
1919 p_conn = 1. , w_bound = 1. , batch_size = 1 , ** kwargs
2020 ):
2121 # This is because we have weights mu and weight log sigma
@@ -37,7 +37,8 @@ def __init__(
3737 self .rewards = Compartment (jnp .zeros ((batch_size ,))) # the normalized reward (r - r_hat), input compartment
3838 self .act_fx , self .dact_fx = create_function (act_fx if act_fx is not None else "identity" )
3939 # self.seed = Component(seed)
40-
40+ self .accumulated_gradients = Compartment (jnp .zeros ((input_dim , output_dim * 2 )))
41+ self .decay = decay
4142
4243 @staticmethod
4344 def _compute_update (dt , inputs , rewards , act_fx , weights ):
@@ -72,9 +73,9 @@ def _compute_update(dt, inputs, rewards, act_fx, weights):
7273 # Finally, return metrics if needed
7374 return dW , objective , outputs
7475
75- @transition (output_compartments = ["weights" , "dWeights" , "objective" , "outputs" ])
76+ @transition (output_compartments = ["weights" , "dWeights" , "objective" , "outputs" , "accumulated_gradients" ])
7677 @staticmethod
77- def evolve (dt , w_bound , inputs , rewards , act_fx , weights , eta ):
78+ def evolve (dt , w_bound , inputs , rewards , act_fx , weights , eta , decay , accumulated_gradients ):
7879 dWeights , objective , outputs = REINFORCESynapse ._compute_update (
7980 dt , inputs , rewards , act_fx , weights
8081 )
@@ -83,9 +84,10 @@ def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta):
8384 ## enforce non-negativity
8485 eps = 0.01 # 0.001
8586 weights = jnp .clip (weights , eps , w_bound - eps ) # jnp.abs(w_bound))
86- return weights , dWeights , objective , outputs
87+ accumulated_gradients = accumulated_gradients * decay + dWeights
88+ return weights , dWeights , objective , outputs , accumulated_gradients
8789
88- @transition (output_compartments = ["inputs" , "outputs" , "objective" , "rewards" , "dWeights" ])
90+ @transition (output_compartments = ["inputs" , "outputs" , "objective" , "rewards" , "dWeights" , "accumulated_gradients" ])
8991 @staticmethod
9092 def reset (batch_size , shape ):
9193 preVals = jnp .zeros ((batch_size , shape [0 ]))
@@ -95,7 +97,8 @@ def reset(batch_size, shape):
9597 objective = jnp .zeros (())
9698 rewards = jnp .zeros ((batch_size ,))
9799 dWeights = jnp .zeros (shape )
98- return inputs , outputs , objective , rewards , dWeights
100+ accumulated_gradients = jnp .zeros ((shape [0 ], shape [1 ] * 2 ))
101+ return inputs , outputs , objective , rewards , dWeights , accumulated_gradients
99102
100103 @classmethod
101104 def help (cls ): ## component help function
@@ -110,8 +113,8 @@ def help(cls): ## component help function
110113 }
111114 info = {cls .__name__ : properties ,
112115 "compartments" : compartment_props ,
113- "dynamics" : "outputs = [(W * Rscale) * inputs] ;"
114- "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i" ,
116+ # "dynamics": "outputs = [(W * Rscale) * inputs] ;"
117+ # "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
115118 "hyperparameters" : hyperparams }
116119 return info
117120
0 commit comments