
Commit b30c8fe

refactor bernoulli, laplacian, and rewarderror cells
1 parent eaafb64 commit b30c8fe

File tree

3 files changed: +17 / -60 lines

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 5 additions & 19 deletions
@@ -3,6 +3,7 @@
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
 from ngclearn.utils.model_utils import sigmoid, d_sigmoid
+from ngcsimlib.compilers.process import transition
 
 class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -58,8 +59,9 @@ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None,
         self.modulator = Compartment(restVals + 1.0) # to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
+    @transition(output_compartments=["dp", "dtarget", "L", "mask"])
     @staticmethod
-    def _advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
+    def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
         # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional
         eps = 0.0001
@@ -89,15 +91,9 @@ def _advance_state(dt, p, target, modulator, mask, input_logits): ## compute Ber
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
         return dp, dtarget, jnp.squeeze(L), mask
 
-    @resolver(_advance_state)
-    def advance_state(self, dp, dtarget, L, mask):
-        self.dp.set(dp)
-        self.dtarget.set(dtarget)
-        self.L.set(L)
-        self.mask.set(mask)
-
+    @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
     @staticmethod
-    def _reset(batch_size, shape): ## reset core components/statistics
+    def reset(batch_size, shape): ## reset core components/statistics
         _shape = (batch_size, shape[0])
         if len(shape) > 1:
             _shape = (batch_size, shape[0], shape[1], shape[2])
@@ -111,16 +107,6 @@ def _reset(batch_size, shape): ## reset core components/statistics
         mask = jnp.ones(_shape) ## reset mask
         return dp, dtarget, target, p, modulator, L, mask
 
-    @resolver(_reset)
-    def reset(self, dp, dtarget, target, p, modulator, L, mask):
-        self.dp.set(dp)
-        self.dtarget.set(dtarget)
-        self.target.set(target)
-        self.p.set(p)
-        self.modulator.set(modulator)
-        self.L.set(L)
-        self.mask.set(mask)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
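Across all three files the change follows the same pattern: the hand-written @resolver methods that copied each returned value into its compartment are deleted, and the pure static update functions are instead tagged with @transition(output_compartments=[...]) so that wiring happens in one generic place. Below is a toy, self-contained sketch of that mechanic; it is a stand-in for ngcsimlib.compilers.process.transition, not the library's actual implementation, and Compartment here is a minimal hypothetical version.

class Compartment:
    """Minimal stand-in for ngclearn's Compartment: just holds a value."""
    def __init__(self, value=None):
        self.value = value

    def set(self, value):
        self.value = value

def transition(output_compartments):
    """Tag a pure update function with the compartments its return tuple feeds."""
    def wrap(fn):
        fn.output_compartments = list(output_compartments)
        return fn
    return wrap

def run_transition(component, fn, **inputs):
    """Generic replacement for the deleted per-cell @resolver methods."""
    outputs = fn(**inputs)
    outputs = outputs if isinstance(outputs, tuple) else (outputs,)
    for name, value in zip(fn.output_compartments, outputs):
        getattr(component, name).set(value)  # was: self.dp.set(dp), self.L.set(L), ...

With that in place, each component only declares which compartments a transition writes; the copy-into-compartment loop lives once in the framework instead of being repeated in every cell.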

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 5 additions & 21 deletions
@@ -2,6 +2,7 @@
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
 from ngclearn.utils import tensorstats
+from ngcsimlib.compilers.process import transition
 
 class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
     """
@@ -66,8 +67,9 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
         self.modulator = Compartment(restVals + 1.0) ## to be set/consumed
         self.mask = Compartment(restVals + 1.0)
 
+    @transition(output_compartments=["dshift", "dtarget", "dScale", "L", "mask"])
     @staticmethod
-    def _advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplacian error cell output
+    def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplacian error cell output
         # Moves Laplacian cell dynamics one step forward. Specifically, this routine emulates the error unit
         # behavior of the local cost functional:
         # FIXME: Currently, below does: L(targ, shift) = -||targ - shift||_1/scale
@@ -85,16 +87,9 @@ def _advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplac
         mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
         return dshift, dtarget, dScale, jnp.squeeze(L), mask
 
-    @resolver(_advance_state)
-    def advance_state(self, dshift, dtarget, dScale, L, mask):
-        self.dshift.set(dshift)
-        self.dtarget.set(dtarget)
-        self.dScale.set(dScale)
-        self.L.set(L)
-        self.mask.set(mask)
-
+    @transition(output_compartments=["dshift", "dtarget", "dScale", "target", "shift", "modulator", "L", "mask"])
     @staticmethod
-    def _reset(batch_size, n_units, scale_shape):
+    def reset(batch_size, n_units, scale_shape):
         restVals = jnp.zeros((batch_size, n_units))
         dshift = restVals
         dtarget = restVals
@@ -106,17 +101,6 @@ def _reset(batch_size, n_units, scale_shape):
         mask = jnp.ones((batch_size, n_units))
         return dshift, dtarget, dScale, target, shift, modulator, L, mask
 
-    @resolver(_reset)
-    def reset(self, dshift, dtarget, dScale, target, shift, modulator, L, mask):
-        self.dshift.set(dshift)
-        self.dtarget.set(dtarget)
-        self.dScale.set(dScale)
-        self.target.set(target)
-        self.shift.set(shift)
-        self.modulator.set(modulator)
-        self.L.set(L)
-        self.mask.set(mask)
-
     @classmethod
     def help(cls): ## component help function
         properties = {
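The body of advance_state sits outside these hunks; purely as an illustration of the cost named in the FIXME comment, L(targ, shift) = -||targ - shift||_1 / scale, a hedged sketch of that functional and its subgradients could look like the following. The gradient signs are an assumption, and this ignores dScale, the modulator, and the mask handling that the real cell performs.

from jax import numpy as jnp

def laplacian_error(target, shift, scale=1.):
    # L(targ, shift) = -||targ - shift||_1 / scale, plus its subgradients
    diff = target - shift
    L = -jnp.sum(jnp.abs(diff)) / scale
    dshift = jnp.sign(diff) / scale    # dL/dshift: pushes shift toward target
    dtarget = -jnp.sign(diff) / scale  # dL/dtarget
    return dshift, dtarget, L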

ngclearn/components/neurons/graded/rewardErrorCell.py

Lines changed: 7 additions & 20 deletions
@@ -1,6 +1,7 @@
 from ngclearn import resolver, Component, Compartment
 from ngclearn.components.jaxComponent import JaxComponent
 from jax import numpy as jnp, jit
+from ngcsimlib.compilers.process import transition
 from ngclearn.utils import tensorstats
 
 class RewardErrorCell(JaxComponent): ## Reward prediction error cell
@@ -50,8 +51,9 @@ def __init__(self, name, n_units, alpha, ema_window_len=10,
         self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
         self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken
 
+    @transition(output_compartments=["mu", "rpe", "n_ep_steps", "accum_reward"])
     @staticmethod
-    def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
+    def advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
                        n_ep_steps, accum_reward):
         ## compute/update RPE and predictor values
         accum_reward = accum_reward + reward
@@ -61,41 +63,26 @@ def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
         n_ep_steps = n_ep_steps + 1
         return mu, rpe, n_ep_steps, accum_reward
 
-    @resolver(_advance_state)
-    def advance_state(self, mu, rpe, n_ep_steps, accum_reward):
-        self.mu.set(mu)
-        self.rpe.set(rpe)
-        self.n_ep_steps.set(n_ep_steps)
-        self.accum_reward.set(accum_reward)
-
+    @transition(output_compartments=["mu"])
     @staticmethod
-    def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
+    def evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
                 accum_reward):
         if use_online_predictor:
             ## total episodic reward signal
             r = accum_reward/n_ep_steps
             mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
         return mu
 
-    @resolver(_evolve)
-    def evolve(self, mu):
-        self.mu.set(mu)
-
+    @transition(output_compartments=["mu", "rpe", "accum_reward", "n_ep_steps"])
     @staticmethod
-    def _reset(batch_size, n_units):
+    def reset(batch_size, n_units):
         restVals = jnp.zeros((batch_size, n_units))
         mu = restVals
         rpe = restVals
         accum_reward = restVals
         n_ep_steps = jnp.zeros((batch_size, 1))
         return mu, rpe, accum_reward, n_ep_steps
 
-    @resolver(_reset)
-    def reset(self, mu, rpe, accum_reward, n_ep_steps):
-        self.mu.set(mu)
-        self.rpe.set(rpe)
-        self.accum_reward.set(accum_reward)
-        self.n_ep_steps.set(n_ep_steps)
 
     @classmethod
     def help(cls): ## component help function
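The evolve hunk shows the predictor update in full: the mean episodic reward r = accum_reward / n_ep_steps is folded into mu as an exponential moving average with window ema_window_len. A small numeric check of that update, using made-up values:

from jax import numpy as jnp

ema_window_len = 10
accum_reward = jnp.asarray([[6.0]])  # reward summed over one episode
n_ep_steps = jnp.asarray([[20.0]])   # steps taken in that episode
mu = jnp.asarray([[0.1]])            # current reward predictor

r = accum_reward / n_ep_steps  # mean episodic reward = 0.3
mu = (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r
print(mu)  # 0.9 * 0.1 + 0.1 * 0.3 = 0.12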
