Skip to content

Commit adc74cf

Browse files
committed
refactored Rate Cell
1 parent 74840d9 commit adc74cf

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
# %%
2+
13
from jax import numpy as jnp, random, jit
24
from functools import partial
35
from ngclearn.utils import tensorstats
4-
from ngclearn import resolver, Component, Compartment
6+
# from ngclearn import resolver, Component, Compartment
7+
from ngcsimlib.compartment import Compartment
8+
from ngcsimlib.compilers.process import transition
59
from ngclearn.components.jaxComponent import JaxComponent
610
from ngclearn.utils.model_utils import create_function, threshold_soft, \
711
threshold_cauchy
@@ -191,8 +195,9 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
191195
self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
192196
self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
193197

198+
@transition(output_compartments=["j", "j_td", "z", "zF"])
194199
@staticmethod
195-
def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
200+
def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
196201
resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
197202
#if tau_m > 0.:
198203
if is_stateful:
@@ -220,27 +225,15 @@ def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
220225
zF = fx(z)
221226
return j, j_td, z, zF
222227

223-
@resolver(_advance_state)
224-
def advance_state(self, j, j_td, z, zF):
225-
self.j.set(j)
226-
self.j_td.set(j_td)
227-
self.z.set(z)
228-
self.zF.set(zF)
229-
228+
@transition(output_compartments=["j", "j_td", "z", "zF"])
230229
@staticmethod
231-
def _reset(batch_size, shape): #n_units
230+
def reset(batch_size, shape): #n_units
232231
_shape = (batch_size, shape[0])
233232
if len(shape) > 1:
234233
_shape = (batch_size, shape[0], shape[1], shape[2])
235234
restVals = jnp.zeros(_shape)
236235
return tuple([restVals for _ in range(4)])
237236

238-
@resolver(_reset)
239-
def reset(self, j, zF, j_td, z):
240-
self.j.set(j) # electrical current
241-
self.zF.set(zF) # rate-coded output - activity
242-
self.j_td.set(j_td) # top-down electrical current - pressure
243-
self.z.set(z) # rate activity
244237

245238
def save(self, directory, **kwargs):
246239
## do a protected save of constants, depending on whether they are floats or arrays

0 commit comments

Comments
 (0)