|
| 1 | +# %% |
| 2 | + |
1 | 3 | from jax import numpy as jnp, random, jit |
2 | 4 | from functools import partial |
3 | 5 | from ngclearn.utils import tensorstats |
4 | | -from ngclearn import resolver, Component, Compartment |
| 6 | +# from ngclearn import resolver, Component, Compartment |
| 7 | +from ngcsimlib.compartment import Compartment |
| 8 | +from ngcsimlib.compilers.process import transition |
5 | 9 | from ngclearn.components.jaxComponent import JaxComponent |
6 | 10 | from ngclearn.utils.model_utils import create_function, threshold_soft, \ |
7 | 11 | threshold_cauchy |
@@ -191,8 +195,9 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit |
191 | 195 | self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure |
192 | 196 | self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity |
193 | 197 |
|
| 198 | + @transition(output_compartments=["j", "j_td", "z", "zF"]) |
194 | 199 | @staticmethod |
195 | | - def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, |
| 200 | + def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, |
196 | 201 | resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z): |
197 | 202 | #if tau_m > 0.: |
198 | 203 | if is_stateful: |
@@ -220,27 +225,15 @@ def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, |
220 | 225 | zF = fx(z) |
221 | 226 | return j, j_td, z, zF |
222 | 227 |
|
223 | | - @resolver(_advance_state) |
224 | | - def advance_state(self, j, j_td, z, zF): |
225 | | - self.j.set(j) |
226 | | - self.j_td.set(j_td) |
227 | | - self.z.set(z) |
228 | | - self.zF.set(zF) |
229 | | - |
| 228 | + @transition(output_compartments=["j", "j_td", "z", "zF"]) |
230 | 229 | @staticmethod |
231 | | - def _reset(batch_size, shape): #n_units |
| 230 | + def reset(batch_size, shape): #n_units |
232 | 231 | _shape = (batch_size, shape[0]) |
233 | 232 | if len(shape) > 1: |
234 | 233 | _shape = (batch_size, shape[0], shape[1], shape[2]) |
235 | 234 | restVals = jnp.zeros(_shape) |
236 | 235 | return tuple([restVals for _ in range(4)]) |
237 | 236 |
|
238 | | - @resolver(_reset) |
239 | | - def reset(self, j, zF, j_td, z): |
240 | | - self.j.set(j) # electrical current |
241 | | - self.zF.set(zF) # rate-coded output - activity |
242 | | - self.j_td.set(j_td) # top-down electrical current - pressure |
243 | | - self.z.set(z) # rate activity |
244 | 237 |
|
245 | 238 | def save(self, directory, **kwargs): |
246 | 239 | ## do a protected save of constants, depending on whether they are floats or arrays |
|
0 commit comments