Skip to content

Commit d850f9d

Browse files
author
Alexander Ororbia
committed
revised and added unit-test for exp-kernel
1 parent 0a803ff commit d850f9d

File tree

2 files changed

+72
-16
lines changed

2 files changed

+72
-16
lines changed

ngclearn/components/other/expKernel.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from jax import numpy as jnp, jit
2-
from functools import partial
3-
from ngclearn import resolver, Component, Compartment
41
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit
3+
from functools import partial
54
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
7+
8+
from ngcsimlib.compilers.process import transition
9+
#from ngcsimlib.component import Component
10+
from ngcsimlib.compartment import Compartment
611

712
@partial(jit, static_argnums=[5,6])
813
def _apply_kernel(tf_curr, s, t, tau_w, win_len, krn_start, krn_end):
@@ -40,6 +45,8 @@ class ExpKernel(JaxComponent): ## exponential kernel
4045
nu: (ms, spike time interval for window)
4146
4247
tau_w: spike window time constant (in micro-secs, or nano-s)
48+
49+
batch_size: batch size dimension of this cell (Default: 1)
4350
"""
4451

4552
# Define Functions
@@ -60,31 +67,22 @@ def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs)
6067
## window of spike times
6168
self.tf = Compartment(jnp.zeros((self.win_len, self.batch_size, self.n_units)))
6269

70+
@transition(output_compartments=["epsp", "tf"])
6371
@staticmethod
64-
def _advance_state(t, tau_w, win_len, inputs, tf):
72+
def advance_state(t, tau_w, win_len, inputs, tf):
6573
s = inputs
6674
## update spike time window and corresponding window volume
6775
tf, epsp = _apply_kernel(tf, s, t, tau_w, win_len, krn_start=0,
6876
krn_end=win_len-1) #0:win_len-1)
6977
return epsp, tf
7078

71-
@resolver(_advance_state)
72-
def advance_state(self, epsp, tf):
73-
self.epsp.set(epsp)
74-
self.tf.set(tf)
75-
79+
@transition(output_compartments=["inputs", "epsp", "tf"])
7680
@staticmethod
77-
def _reset(batch_size, n_units, win_len):
81+
def reset(batch_size, n_units, win_len):
7882
restVals = jnp.zeros((batch_size, n_units))
7983
restTensor = jnp.zeros([win_len, batch_size, n_units], jnp.float32) # tf
8084
return restVals, restVals, restTensor # inputs, epsp, tf
8185

82-
@resolver(_reset)
83-
def reset(self, inputs, epsp, tf):
84-
self.inputs.set(inputs)
85-
self.epsp.set(epsp)
86-
self.tf.set(tf)
87-
8886
@classmethod
8987
def help(cls): ## component help function
9088
properties = {
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import ExpKernel
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
def test_expKernel1():
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
trace_increment = 0.1
21+
# ---- build a simple Poisson cell system ----
22+
with Context("Circuit") as ctx:
23+
a = ExpKernel(
24+
name="a", n_units=1, dt=1., tau_w=500., nu=4., key=subkeys[0]
25+
)
26+
27+
advance_process = (Process()
28+
>> a.advance_state)
29+
ctx.wrap_and_add_command(advance_process.pure, name="run")
30+
31+
reset_process = (Process()
32+
>> a.reset)
33+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
34+
35+
## set up non-compiled utility commands
36+
@Context.dynamicCommand
37+
def clamp(x):
38+
a.inputs.set(x)
39+
40+
## input spike train
41+
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
42+
## desired output/epsp pulses
43+
y_seq = jnp.asarray([[0., 1., 0.998002, 0.996008, 1.9940181]], dtype=jnp.float32)
44+
45+
outs = []
46+
ctx.reset()
47+
for ts in range(x_seq.shape[1]):
48+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
49+
ctx.clamp(x_t)
50+
ctx.run(t=ts * 1., dt=dt)
51+
outs.append(a.epsp.value)
52+
outs = jnp.concatenate(outs, axis=1)
53+
#print(outs)
54+
55+
## output should equal input
56+
assert_array_equal(outs, y_seq)
57+
58+
#test_expKernel1()

0 commit comments

Comments
 (0)