# %%

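## Unit test for the RewardErrorCell component: verifies the online
## reward-prediction-error dynamics of advance_state (rpe = reward - mu,
## with mu tracked as an alpha-weighted moving average) and the
## end-of-episode EMA update applied by evolve.
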
from jax import numpy as jnp, random, jit
import numpy as np
np.random.seed(42)
from ngclearn.components import RewardErrorCell

from ngcsimlib.compilers.process import Process
from ngcsimlib.context import Context


def test_rewardErrorCell():
    np.random.seed(42)
    name = "reward_error_ctx"
    dkey = random.PRNGKey(42)
    dkey, *subkeys = random.split(dkey, 100)
    dt = 1.  # ms
    alpha = 0.1  # decay factor for the moving average of reward
    with Context(name) as ctx:
        a = RewardErrorCell(
            name="a", n_units=1, alpha=alpha, ema_window_len=10,
            use_online_predictor=True, batch_size=1
        )
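        ## wrap each cell routine in a Process and register it as a
        ## jit-compiled, name-addressable command on the context
        ## (invoked below as ctx.run / ctx.reset / ctx.evolve)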
        advance_process = (Process() >> a.advance_state)
        ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
        reset_process = (Process() >> a.reset)
        ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
        evolve_process = (Process() >> a.evolve)
        ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
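
        ## dynamic command used to clamp an external reward value onto the
        ## cell's `reward` compartment before each simulation step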
        @Context.dynamicCommand
        def clamp_reward(x):
            a.reward.set(x)

    ## input reward sequence
    reward_seq = jnp.array([[1.0, 0.5, 0.0, 2.0, 1.5, 0.0, 1.0, 0.5, 0.0, 1.0]])
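    ## shape (1, 10): one batch row, ten discrete time steps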

    ## expected outputs: the test exercises each routine of the cell
    ## (advance_state, evolve, reset); advance_state computes
    ##   rpe = reward - mu ; mu = mu * (1. - alpha) + reward * alpha
    expected_mu = np.zeros((1, 10))
    expected_rpe = np.zeros((1, 10))
    expected_accum_reward = np.zeros((1, 10))
    ## calculate the expected values step by step
    mu_t = 0.0
    accum_t = 0.0
    for t in range(10):
        reward_t = reward_seq[0, t]
        accum_t += reward_t
        expected_accum_reward[0, t] = np.asarray(accum_t)  ## accum_reward = accum_reward + reward
        expected_rpe[0, t] = np.asarray(reward_t - mu_t)  ## rpe = reward - mu
        mu_t = mu_t * (1. - alpha) + reward_t * alpha  ## mu = mu * (1. - alpha) + reward * alpha
        expected_mu[0, t] = np.asarray(mu_t)
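    ## hand-computed sanity values for the first two steps (alpha = 0.1):
    ##   t=0: reward=1.0 -> rpe = 1.0 - 0.0 = 1.0 ; mu = 0.9*0.0 + 0.1*1.0 = 0.1
    ##   t=1: reward=0.5 -> rpe = 0.5 - 0.1 = 0.4 ; mu = 0.9*0.1 + 0.1*0.5 = 0.14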

    mu_outs = []
    rpe_outs = []
    accum_reward_outs = []
    ctx.reset()
    for ts in range(reward_seq.shape[1]):
        reward_t = jnp.array([[reward_seq[0, ts]]])  ## get reward at time step ts
        ctx.clamp_reward(reward_t)
        ctx.run(t=ts * 1., dt=dt)
        mu_outs.append(a.mu.value)
        rpe_outs.append(a.rpe.value)
        accum_reward_outs.append(a.accum_reward.value)

    ## test the evolve routine, which folds the episode's average reward into mu
    ctx.evolve(t=10 * 1., dt=dt)
    final_mu = a.mu.value

    mu_outs = jnp.concatenate(mu_outs, axis=1)
    rpe_outs = jnp.concatenate(rpe_outs, axis=1)
    accum_reward_outs = jnp.concatenate(accum_reward_outs, axis=1)

    ## verify outputs match expected values
    np.testing.assert_allclose(mu_outs, expected_mu, atol=1e-5)
    np.testing.assert_allclose(rpe_outs, expected_rpe, atol=1e-5)
    np.testing.assert_allclose(accum_reward_outs, expected_accum_reward, atol=1e-5)

    ## verify final mu after evolve, which computes
    ##   r = accum_reward / n_ep_steps
    ##   mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
    expected_final_mu = (1. - 1. / 10) * mu_outs[0, -1] + (1. / 10) * (accum_reward_outs[0, -1] / 10)
    np.testing.assert_allclose(final_mu, expected_final_mu, atol=1e-5)
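    ## for this sequence accum_reward = 7.5, so r = 7.5 / 10 = 0.75 and, with the
    ## last online mu ~ 0.46853, expected_final_mu ~ 0.9*0.46853 + 0.1*0.75 ~ 0.49667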

# test_rewardErrorCell()