Skip to content

Commit 0a803ff

Browse files
author
Alexander Ororbia
committed
revised and add unit-test for varTrace
1 parent 10dc640 commit 0a803ff

File tree

2 files changed

+76
-17
lines changed

2 files changed

+76
-17
lines changed

ngclearn/components/other/varTrace.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
12
from jax import numpy as jnp, random, jit
23
from functools import partial
3-
from ngclearn import resolver, Component, Compartment
4-
from ngclearn.components.jaxComponent import JaxComponent
54
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
7+
8+
from ngcsimlib.compilers.process import transition
9+
#from ngcsimlib.component import Component
10+
from ngcsimlib.compartment import Compartment
611

712
@partial(jit, static_argnums=[4])
813
def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.):
@@ -54,13 +59,17 @@ class VarTrace(JaxComponent): ## low-pass filter
5459
a_delta: value to increment a trace by in presence of a spike; note if set
5560
to a value <= 0, then a piecewise gated trace will be used instead
5661
62+
gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1)
63+
5764
decay_type: string indicating the decay type to be applied to ODE
5865
integration; low-pass filter configuration
5966
6067
:Note: string values that this can be (Default: "exp") are:
6168
1) `'lin'` = linear trace filter, i.e., decay = x_tr + (-x_tr) * (dt/tau_tr);
6269
2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
6370
3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)
71+
72+
batch_size: batch size dimension of this cell (Default: 1)
6473
"""
6574

6675
# Define Functions
@@ -83,38 +92,28 @@ def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
8392
self.outputs = Compartment(restVals) # output compartment
8493
self.trace = Compartment(restVals)
8594

95+
@transition(output_compartments=["outputs", "trace"])
8696
@staticmethod
87-
def _advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
97+
def advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
8898
decayFactor = 0.
8999
if "exp" in decay_type:
90100
decayFactor = jnp.exp(-dt/tau_tr)
91101
elif "lin" in decay_type:
92102
decayFactor = (1. - dt/tau_tr)
93-
94103
_x_tr = gamma_tr * trace * decayFactor
95104
if a_delta > 0.:
96105
_x_tr = _x_tr + inputs * a_delta
97106
else:
98107
_x_tr = _x_tr * (1. - inputs) + inputs
99-
108+
trace = _x_tr
100109
return trace, trace
101110

102-
@resolver(_advance_state)
103-
def advance_state(self, outputs, trace):
104-
self.outputs.set(outputs)
105-
self.trace.set(trace)
106-
111+
@transition(output_compartments=["inputs", "outputs", "trace"])
107112
@staticmethod
108-
def _reset(batch_size, n_units):
113+
def reset(batch_size, n_units):
109114
restVals = jnp.zeros((batch_size, n_units))
110115
return restVals, restVals, restVals
111116

112-
@resolver(_reset)
113-
def reset(self, inputs, outputs, trace):
114-
self.inputs.set(inputs)
115-
self.outputs.set(outputs)
116-
self.trace.set(trace)
117-
118117
@classmethod
119118
def help(cls): ## component help function
120119
properties = {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import VarTrace
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
16+
def test_varTrace1():
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
trace_increment = 0.1
22+
# ---- build a simple Poisson cell system ----
23+
with Context("Circuit") as ctx:
24+
a = VarTrace(
25+
name="a", n_units=1, a_delta=trace_increment, decay_type="step", tau_tr=1.,
26+
key=subkeys[0]
27+
)
28+
29+
advance_process = (Process()
30+
>> a.advance_state)
31+
ctx.wrap_and_add_command(advance_process.pure, name="run")
32+
33+
reset_process = (Process()
34+
>> a.reset)
35+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
36+
37+
## set up non-compiled utility commands
38+
@Context.dynamicCommand
39+
def clamp(x):
40+
a.inputs.set(x)
41+
42+
## input spike train
43+
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
44+
## desired output pulses
45+
y_seq = x_seq * trace_increment
46+
47+
outs = []
48+
ctx.reset()
49+
for ts in range(x_seq.shape[1]):
50+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
51+
ctx.clamp(x_t)
52+
ctx.run(t=ts * 1., dt=dt)
53+
outs.append(a.outputs.value)
54+
outs = jnp.concatenate(outs, axis=1)
55+
#print(outs)
56+
57+
## output should equal input
58+
assert_array_equal(outs, y_seq)
59+
60+
#test_varTrace1()

0 commit comments

Comments
 (0)