Skip to content

Commit a685fcd

Browse files
committed
update reset methods
1 parent 01454f0 commit a685fcd

File tree

6 files changed

+73
-3
lines changed

6 files changed

+73
-3
lines changed

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
1313
"""
1414
A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution
15-
of a mismatch signal. Specifically, this cell operates as a factorized multivariate
15+
of a mismatch signal. Specifically, this cell operates as a factorized multivariate
1616
Bernoulli distribution.
1717
1818
| --- Cell Input Compartments: ---

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,32 @@ def advance_state(self, dt): ## compute Gaussian error cell output
108108
self.L.set(jnp.squeeze(L))
109109
self.mask.set(mask)
110110

111+
# @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
112+
# @staticmethod
113+
@compilable
114+
def reset(self, batch_size, shape, sigma_shape): ## reset core components/statistics
115+
_shape = (batch_size, shape[0])
116+
if len(shape) > 1:
117+
_shape = (batch_size, shape[0], shape[1], shape[2])
118+
restVals = jnp.zeros(_shape)
119+
dmu = restVals
120+
dtarget = restVals
121+
dSigma = jnp.zeros(sigma_shape)
122+
target = restVals
123+
mu = restVals
124+
modulator = mu + 1.
125+
L = 0. #jnp.zeros((1, 1))
126+
mask = jnp.ones(_shape)
127+
128+
self.dmu.set(dmu)
129+
self.dtarget.set(dtarget)
130+
self.dSigma.set(dSigma)
131+
self.target.set(target)
132+
self.mu.set(mu)
133+
self.modulator.set(modulator)
134+
self.L.set(L)
135+
self.mask.set(mask)
136+
111137
@classmethod
112138
def help(cls): ## component help function
113139
properties = {

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,26 @@ def advance_state(self, dt): ## compute Laplacian error cell output
103103
self.L.set(jnp.squeeze(L))
104104
self.mask.set(mask)
105105

106+
def reset(self, batch_size, n_units, scale_shape):
107+
restVals = jnp.zeros((batch_size, n_units))
108+
dshift = restVals
109+
dtarget = restVals
110+
dScale = jnp.zeros(scale_shape)
111+
target = restVals
112+
shift = restVals
113+
modulator = shift + 1.
114+
L = 0.
115+
mask = jnp.ones((batch_size, n_units))
116+
117+
self.dshift.set(dshift)
118+
self.dtarget.set(dtarget)
119+
self.dScale.set(dScale)
120+
self.target.set(target)
121+
self.shift.set(shift)
122+
self.modulator.set(modulator)
123+
self.L.set(L)
124+
self.mask.set(mask)
125+
106126
@classmethod
107127
def help(cls): ## component help function
108128
properties = {

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,18 @@ def advance_state(self, dt):
251251
self.z.set(z)
252252
self.zF.set(zF)
253253

254+
@compilable
255+
def reset(self, batch_size, shape): #n_units
256+
_shape = (batch_size, shape[0])
257+
if len(shape) > 1:
258+
_shape = (batch_size, shape[0], shape[1], shape[2])
259+
restVals = jnp.zeros(_shape)
260+
self.j.set(restVals)
261+
self.j_td.set(restVals)
262+
self.z.set(restVals)
263+
self.zF.set(restVals)
264+
265+
254266
def save(self, directory, **kwargs):
255267
## do a protected save of constants, depending on whether they are floats or arrays
256268
tau_m = (self.tau_m if isinstance(self.tau_m, float)

ngclearn/components/neurons/graded/rewardErrorCell.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ def evolve(self, dt):
9191
# Update compartment
9292
self.mu.set(mu)
9393

94+
@compilable
95+
def reset(self, batch_size, n_units):
96+
restVals = jnp.zeros((batch_size, n_units))
97+
mu = restVals
98+
rpe = restVals
99+
accum_reward = restVals
100+
n_ep_steps = jnp.zeros((batch_size, 1))
101+
self.mu.set(mu)
102+
self.rpe.set(rpe)
103+
self.accum_reward.set(accum_reward)
104+
self.n_ep_steps.set(n_ep_steps)
105+
94106
@classmethod
95107
def help(cls): ## component help function
96108
properties = {

ngclearn/components/other/expKernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def advance_state(self, t):
7171
# Get the variables
7272
inputs = self.inputs.get()
7373
tf = self.tf.get()
74-
74+
7575
s = inputs
7676
## update spike time window and corresponding window volume
7777
tf, epsp = _apply_kernel(
7878
tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1
7979
) #0:win_len-1)
80-
80+
8181
# Update compartments
8282
self.epsp.set(epsp)
8383
self.tf.set(tf)

0 commit comments

Comments
 (0)