@@ -24,10 +24,11 @@ def test_REINFORCESynapse1():
2424 dkey = random .PRNGKey (1234 )
2525 dkey , * subkeys = random .split (dkey , 6 )
2626 dt = 1. # ms
27+ decay = 0.99
2728 # ---- build a simple Poisson cell system ----
2829 with Context (name ) as ctx :
2930 a = REINFORCESynapse (
30- name = "a" , shape = (1 ,1 ), act_fx = "tanh" , key = subkeys [0 ]
31+ name = "a" , shape = (1 ,1 ), decay = decay , act_fx = "tanh" , key = subkeys [0 ]
3132 )
3233
3334 evolve_process = (Process () >> a .evolve )
@@ -48,27 +49,9 @@ def clamp_rewards(x):
4849 def clamp_weights (x ):
4950 a .weights .set (x )
5051
51- # a.weights.set(jnp.ones((1, 1)) * 0.1)
52-
53- ## check pre-synaptic STDP only
54- # truth = jnp.array([[1.25]])
55- np .random .seed (42 )
56- ctx .reset ()
57- clamp_weights (jnp .ones ((1 , 2 )) * 2 )
58- clamp_rewards (jnp .ones ((1 , 1 )) * 3 )
59- clamp_inputs (jnp .ones ((1 , 1 )) * 0.5 )
60- ctx .adapt (t = 1. , dt = dt )
61- # assert_array_equal(a.dWeights.value, truth)
62- print (f"weights: { a .weights .value } " )
63- print (f"dWeights: { a .dWeights .value } " )
64- print (f"step_count: { a .step_count .value } " )
65- print (f"accumulated_gradients: { a .accumulated_gradients .value } " )
66- print (f"objective: { a .objective .value } " )
67-
68- np .random .seed (42 )
69- # JAX Grad output
52+ # Function definition
7053 _act = jax .nn .tanh
71- def fn (params : dict , inputs : jax .Array , outputs : jax .Array , seed ):
54+ def fn (params : dict , inputs : jax .Array , outputs : jax .Array ):
7255 W_mu , W_logstd = params
7356 activation = _act (inputs )
7457 mean = activation @ W_mu
@@ -81,24 +64,111 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed):
8164 return (- logp * outputs ).mean () * 1e-2
8265 grad_fn = jax .value_and_grad (fn )
8366
84- weights_mu = jnp .ones ((1 , 1 )) * 2
85- weights_logstd = jnp .ones ((1 , 1 )) * 2
86- inputs = jnp .ones ((1 , 1 )) * 0.5
87- outputs = jnp .ones ((1 , 1 )) * 3 # reward
88- objective , grads = grad_fn (
89- (weights_mu , weights_logstd ),
67+ # Some setups
68+ expected_weights_mu = jnp .asarray ([[0.13 ]])
69+ expected_weights_logstd = jnp .asarray ([[0.04 ]])
70+ expected_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )
71+ initial_ngclearn_weights = jnp .concatenate ([expected_weights_mu , expected_weights_logstd ], axis = - 1 )[None ]
72+ expected_gradient_list = []
73+ ctx .reset ()
74+
75+ # Run two steps (unrolled below as step 1 and step 2)
76+ step = 1
77+ # ---------------- Step 1 --------------------
78+ print (f"------------ [Step { step } ] ------------" )
79+ inputs = - 1 ** step * jnp .ones ((1 , 1 )) / 10 # * 0.5 * step / 10.0
80+ outputs = - 1 ** step * jnp .ones ((1 , 1 )) / 10 # * 3 * step / 10.0# reward
81+ # --------- ngclearn ---------
82+ clamp_weights (initial_ngclearn_weights )
83+ clamp_rewards (outputs )
84+ clamp_inputs (inputs )
85+ np .random .seed (42 )
86+ ctx .adapt (t = 1. , dt = dt )
87+ print (f"[ngclearn] objective: { a .objective .value } " )
88+ print (f"[ngclearn] weights: { a .weights .value } " )
89+ print (f"[ngclearn] dWeights: { a .dWeights .value } " )
90+ print (f"[ngclearn] step_count: { a .step_count .value } " )
91+ print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .value } " )
92+ # -------- Expectation ---------
93+ print ("--------------" )
94+ np .random .seed (42 )
95+ expected_objective , expected_grads = grad_fn (
96+ (expected_weights_mu , expected_weights_logstd ),
97+ inputs ,
98+ outputs ,
99+ )
100+ # NOTE: Viet: negate the JAX gradient, because the gradient in ngc-learn
101+ # is for gradient ascent, while the gradient from JAX is for gradient descent
102+ expected_grads = - jnp .concatenate ([expected_grads [0 ], expected_grads [1 ]], axis = - 1 )
103+ expected_gradient_list .append (expected_grads )
104+ print (f"[Expectation] expected_weights: { expected_weights } " )
105+ print (f"[Expectation] dWeights: { expected_grads } " )
106+ print (f"[Expectation] objective: { expected_objective } " )
107+ np .testing .assert_allclose (
108+ a .dWeights .value [0 ],
109+ expected_grads ,
110+ atol = 1e-8
111+ )
112+ np .testing .assert_allclose (
113+ a .objective .value ,
114+ expected_objective ,
115+ atol = 1e-8
116+ )
117+ print ()
118+
119+
120+ step = 2
121+ # ---------------- Step 2 --------------------
122+ print (f"------------ [Step { step } ] ------------" )
123+ inputs = - 1 ** step * jnp .ones ((1 , 1 )) / 10 # * 0.5 * step / 10.0
124+ outputs = - 1 ** step * jnp .ones ((1 , 1 )) / 10 # * 3 * step / 10.0# reward
125+ # --------- ngclearn ---------
126+ clamp_weights (initial_ngclearn_weights )
127+ clamp_rewards (outputs )
128+ clamp_inputs (inputs )
129+ np .random .seed (43 )
130+ ctx .adapt (t = 1. , dt = dt )
131+ print (f"[ngclearn] objective: { a .objective .value } " )
132+ print (f"[ngclearn] weights: { a .weights .value } " )
133+ print (f"[ngclearn] dWeights: { a .dWeights .value } " )
134+ print (f"[ngclearn] step_count: { a .step_count .value } " )
135+ print (f"[ngclearn] accumulated_gradients: { a .accumulated_gradients .value } " )
136+ # -------- Expectation ---------
137+ print ("--------------" )
138+ np .random .seed (43 )
139+ expected_objective , expected_grads = grad_fn (
140+ (expected_weights_mu , expected_weights_logstd ),
90141 inputs ,
91142 outputs ,
92- jax .random .key (42 )
93143 )
94- print (f"expected grads: { grads } " )
95- print (f"expected objective: { objective } " )
144+ # NOTE: Viet: negate the JAX gradient, because the gradient in ngc-learn
145+ # is for gradient ascent, while the gradient from JAX is for gradient descent
146+ expected_grads = - jnp .concatenate ([expected_grads [0 ], expected_grads [1 ]], axis = - 1 )
147+ expected_gradient_list .append (expected_grads )
148+ print (f"[Expectation] expected_weights: { expected_weights } " )
149+ print (f"[Expectation] dWeights: { expected_grads } " )
150+ print (f"[Expectation] objective: { expected_objective } " )
96151 np .testing .assert_allclose (
97152 a .dWeights .value [0 ],
98- # NOTE: Viet: negate the gradient because gradient in ngc-learn
99- # is gradient ascent, while gradient in JAX is gradient descent
100- - jnp .concatenate ([grads [0 ], grads [1 ]], axis = - 1 ),
153+ expected_grads ,
101154 atol = 1e-8
102- ) # NOTE: gradient is not exact due to different gradient computation, we need to inspect more closely
155+ )
156+ np .testing .assert_allclose (
157+ a .objective .value ,
158+ expected_objective ,
159+ atol = 1e-8
160+ )
161+ print ()
162+
163+ # Finally, check if the accumulated gradients are correct
164+ decay_list = jnp .asarray ([decay ** 1 , decay ** 0 ])
165+ expected_accumulated_gradients = jnp .mean (jnp .stack (expected_gradient_list , 0 ) * decay_list [:, None , None ], axis = 0 )
166+ np .testing .assert_allclose (
167+ a .accumulated_gradients .value [0 ],
168+ expected_accumulated_gradients ,
169+ atol = 1e-8
170+ )
171+
172+
173+ test_REINFORCESynapse1 ()
103174
104- # test_REINFORCESynapse1()
0 commit comments