+# %%
+
 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
+from ngcsimlib.logger import info
 from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 from ngclearn.utils.model_utils import clip, d_clip
 import jax
 import jax.numpy as jnp
@@ -17,11 +19,59 @@ def gaussian_logpdf(event, mean, stddev):
     quadratic = (event - mean)**2 / scale_sqrd
     return -0.5 * (log_normalizer + quadratic)
 
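+# A quick numerical sanity check (sketch): gaussian_logpdf should agree with
+# jax.scipy.stats.norm.logpdf, which parameterizes the Gaussian by loc/scale:
+#
+#   from jax.scipy.stats import norm
+#   x, mu, sd = jnp.array([0.5]), jnp.array([0.0]), jnp.array([2.0])
+#   assert jnp.allclose(gaussian_logpdf(x, mu, sd), norm.logpdf(x, loc=mu, scale=sd))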
+
+def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+    learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
+    # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
+    W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
+    # Forward pass
+    activation = act_fx(inputs)
+    mean = activation @ W_mu
+    fx_mean = mu_act_fx(mean)
+    logstd = activation @ W_logstd
+    clip_logstd = clip(logstd, -10.0, 2.0)
+    std = jnp.exp(clip_logstd)
+    std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev  # masking trick (see below)
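+    # Masking trick: scalar_stddev > 0 gives mask = 0 and a fixed, user-specified stddev;
+    # scalar_stddev <= 0 gives mask = 1 and keeps the learned exp(logstd) from above.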
+    # Sample using the reparameterization trick
+    epsilon = jax.random.normal(seed, fx_mean.shape)
+    sample = epsilon * std + fx_mean
+    sample = jnp.clip(sample, mu_out_min, mu_out_max)
+    outputs = sample  # the actual action that we take
+    # Compute the log probability density of the Gaussian
+    log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
+    # Compute the objective (negative REINFORCE objective)
+    objective = (-log_prob * rewards).mean() * 1e-2
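+    # Score-function (REINFORCE) estimator: grad E[R] = E[R * grad log pi(a|s)], so
+    # minimizing mean(-log_prob * rewards) yields that gradient in expectation; the
+    # 1e-2 factor is a fixed scaling of the learning signal.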
+
+    # Backward pass
+    batch_size = inputs.shape[0]  # B
+    dL_dlogp = -rewards[:, None] * 1e-2 / batch_size  # (B, 1)
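+    # Per-sample d(objective)/d(log_prob): since objective = mean(-log_prob * rewards) * 1e-2,
+    # each row contributes -reward * 1e-2 / B.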
+
+    # Compute gradients manually based on the derivation
+    # dL/dmu = -r * dlog_prob/dmu = -r * (sample - mu)/sigma^2 (no baseline r_hat is used here)
+    dlog_prob_dfxmean = (sample - fx_mean) / (std**2)
+    dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean)  # (B, A)
+    dL_dWmu = activation.T @ dL_dmean
+
+    # dL/dlog(sigma) = -r * dlog_prob/dlog(sigma) = -r * (((sample - mu)/sigma)^2 - 1)
+    dlog_prob_dstd = -1.0 / std + (sample - fx_mean)**2 / std**3  # this line is dlog_prob/dsigma
+    dL_dstd = dL_dlogp * dlog_prob_dstd
+    # Apply the clip-aware gradient for logstd (d_clip zeroes the gradient where logstd saturated)
+    dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
+    dL_dWlogstd = activation.T @ dL_dlogstd  # (I, B) @ (B, A) = (I, A)
+    dL_dWlogstd = dL_dWlogstd * learning_stddev_mask  # no learning for the scalar stddev
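+    # Derivation recap: log N(x; mu, sigma) = -0.5*log(2*pi*sigma^2) - (x - mu)^2/(2*sigma^2),
+    # so dlog_prob/dsigma = -1/sigma + (x - mu)^2/sigma^3; with sigma = exp(logstd), the chain
+    # rule dlog_prob/dlogstd = dlog_prob/dsigma * sigma supplies the trailing "* std" factor above.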
+
+    # Negate the gradient because weight updates in ngc-learn follow gradient ascent
+    dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
+    # Finally, return metrics if needed
+    return dW, objective, outputs
+
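+# Minimal usage sketch (illustrative only): toy shapes, identity functions standing in
+# for the component's configured act_fx/mu_act_fx, and a fixed stddev of 0.1.
+#
+#   key = jax.random.PRNGKey(0)
+#   W = jnp.zeros((4, 3 * 2))              # (input_dim, output_dim * 2)
+#   x = jnp.ones((8, 4))                   # (batch, input_dim)
+#   r = jnp.ones((8,))                     # one reward per sample
+#   ident = lambda v: v
+#   d_ident = lambda v: jnp.ones_like(v)
+#   dW, obj, acts = _compute_update(1., x, r, ident, W, key, ident, d_ident, -1.0, 1.0, 0.1)
+
+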
 class REINFORCESynapse(DenseSynapse):
     """
     A stochastic synapse implementing the REINFORCE algorithm (a policy-gradient method). This synapse
     uses Gaussian distributions to generate actions and performs gradient-based updates.
-
+
     | --- Synapse Compartments: ---
     | inputs - input (takes in external signals)
     | outputs - output signals (sampled actions from Gaussian distribution)
@@ -89,7 +139,7 @@ def __init__(
         self.scalar_stddev = scalar_stddev
 
         ## Compartment setup
-        self.dWeights = Compartment(self.weights.value * 0)
+        self.dWeights = Compartment(self.weights.get() * 0)
         # self.eta = Compartment(jnp.ones((1, 1)) * eta)  ## global learning rate # For eligibility traces later
         self.objective = Compartment(jnp.zeros(()))
         self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
@@ -101,72 +151,50 @@ def __init__(
         self.learning_mask = Compartment(jnp.zeros(()))
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
 
-    @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
-        learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
-        # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
-        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
-        # Forward pass
-        activation = act_fx(inputs)
-        mean = activation @ W_mu
-        fx_mean = mu_act_fx(mean)
-        logstd = activation @ W_logstd
-        clip_logstd = clip(logstd, -10.0, 2.0)
-        std = jnp.exp(clip_logstd)
-        std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev  # masking trick
-        # Sample using reparameterization trick
-        epsilon = jax.random.normal(seed, fx_mean.shape)
-        sample = epsilon * std + fx_mean
-        sample = jnp.clip(sample, mu_out_min, mu_out_max)
-        outputs = sample  # the actual action that we take
-        # Compute log probability density of the Gaussian
-        log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
-        # Compute objective (negative REINFORCE objective)
-        objective = (-log_prob * rewards).mean() * 1e-2
-
-        # Backward pass
-        batch_size = inputs.shape[0]  # B
-        dL_dlogp = -rewards[:, None] * 1e-2 / batch_size  # (B, 1)
-
-        # Compute gradients manually based on the derivation
-        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
-        dlog_prob_dfxmean = (sample - fx_mean) / (std**2)
-        dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean)  # (B, A)
-        dL_dWmu = activation.T @ dL_dmean
-
-        # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
-        dlog_prob_dlogstd = -1.0 / std + (sample - fx_mean)**2 / std**3
-        dL_dstd = dL_dlogp * dlog_prob_dlogstd
-        # Apply gradient clipping for logstd
-        dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
-        dL_dWlogstd = activation.T @ dL_dlogstd  # (I, B) @ (B, A) = (I, A)
-        dL_dWlogstd = dL_dWlogstd * learning_stddev_mask  # there is no learning for the scalar stddev
-
-        # Update weights, negate the gradient because gradient ascent in ngc-learn
-        dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
-        # Finally, return metrics if needed
-        return dW, objective, outputs
-
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
-    @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+
+    # @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
+    # @staticmethod
+    @compilable
+    def evolve(self, dt):
+
+        # Get compartment values
+        weights = self.weights.get()
+        dWeights = self.dWeights.get()
+        objective = self.objective.get()
+        outputs = self.outputs.get()
+        accumulated_gradients = self.accumulated_gradients.get()
+        step_count = self.step_count.get()
+        seed = self.seed.get()
+        inputs = self.inputs.get()
+        rewards = self.rewards.get()
+        learning_mask = self.learning_mask.get()  # learning_mask is a Compartment; read its value like the others
+
+        # Main logic
         main_seed, sub_seed = jax.random.split(seed)
-        dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
+        dWeights, objective, outputs = _compute_update(
+            dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx, self.mu_out_min, self.mu_out_max, self.scalar_stddev
         )
         ## do a gradient ascent update/shift
-        weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
+        weights = (weights + dWeights * self.eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
         ## enforce non-negativity
         eps = 0.0  # 0.01 # 0.001
-        weights = jnp.clip(weights, eps, w_bound - eps)  # jnp.abs(w_bound))
+        weights = jnp.clip(weights, eps, self.w_bound - eps)  # jnp.abs(w_bound))
         step_count += 1
-        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights  # EMA update of accumulated gradients
-        step_count = step_count * (1 - learning_mask)  # reset the step count to 0 when we have learned
-        return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed
+        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights  # EMA update of accumulated gradients
+        step_count = step_count * (1 - learning_mask)  # reset the step count to 0 when we have learned
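+        # EMA recursion: with t = step_count, acc_t = ((t - 1)/t) * decay * acc_{t-1} + (1/t) * dW_t,
+        # i.e. a decayed running average of per-step gradients; zeroing step_count whenever
+        # learning_mask is 1 restarts the average after a weight update.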
+
+        # Set updated compartment values
+        self.weights.set(weights)
+        self.dWeights.set(dWeights)
+        self.objective.set(objective)
+        self.outputs.set(outputs)
+        self.accumulated_gradients.set(accumulated_gradients)
+        self.step_count.set(step_count)
+        self.seed.set(main_seed)
 
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
-    @staticmethod
-    def reset(batch_size, shape):
+    # @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
+    # @staticmethod
+    @compilable
+    def reset(self, batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
         postVals = jnp.zeros((batch_size, shape[1]))
         inputs = preVals
@@ -177,7 +205,17 @@ def reset(batch_size, shape):
         accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
         step_count = jnp.zeros(())
         seed = jax.random.PRNGKey(42)
-        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed
+
+        # Set reset compartment values
+        self.inputs.set(inputs)
+        self.outputs.set(outputs)
+        self.objective.set(objective)
+        self.rewards.set(rewards)
+        self.dWeights.set(dWeights)
+        self.accumulated_gradients.set(accumulated_gradients)
+        self.step_count.set(step_count)
+        self.seed.set(seed)
+
 
     @classmethod
     def help(cls):  ## component help function
@@ -223,15 +261,27 @@ def help(cls): ## component help function
         return info
 
     def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
+        comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
         maxlen = max(len(c) for c in comps) + 5
         lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
         for c in comps:
-            stats = tensorstats(getattr(self, c).value)
+            stats = tensorstats(getattr(self, c).get())
             if stats is not None:
                 line = [f"{k}: {v}" for k, v in stats.items()]
                 line = ", ".join(line)
             else:
                 line = "None"
             lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
         return lines
+
+
+if __name__ == '__main__':
+    from ngcsimlib.context import Context
+    with Context("Bar") as bar:
+        syn = REINFORCESynapse(
+            name="reinforce_syn",
+            shape=(3, 2)
+        )
+        # Wab = syn.weights.get()
+        print(syn)
+
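+        # A hedged sketch of driving the component by hand (compartment names follow the
+        # declarations above; the exact compile/run semantics depend on ngcsimlib, so
+        # treat this as illustrative rather than the library's canonical API):
+        # syn.inputs.set(jnp.ones((1, 3)))
+        # syn.rewards.set(jnp.ones((1,)))
+        # syn.evolve(dt=1.)   # hypothetical direct call; a compiled process may be required
+        # print(syn.outputs.get())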