Skip to content

Commit 195e3db

Browse files
author
Alexander Ororbia
committed
sketch of ifcell test
1 parent 5ab8564 commit 195e3db

File tree

2 files changed

+28
-41
lines changed

2 files changed

+28
-41
lines changed

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from ngclearn.components.jaxComponent import JaxComponent
2-
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
2+
from jax import numpy as jnp, random, nn, Array, jit
43
from ngclearn.utils import tensorstats
54
from ngcsimlib import deprecate_args
6-
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
86
step_euler, step_rk2
9-
# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10-
# triangular_estimator,
11-
# straight_through_estimator)
7+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
8+
triangular_estimator,
9+
straight_through_estimator)
1210

1311
from ngcsimlib.parser import compilable
1412
from ngcsimlib.compartment import Compartment
@@ -135,7 +133,7 @@ def __init__(
135133
display_name="Refractory Time Period", units="ms")
136134
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
137135
units="ms") ## time-of-last-spike
138-
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
136+
#self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
139137

140138
@compilable
141139
def advance_state(

tests/components/neurons/spiking/test_IFCell.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,68 +2,57 @@
22
from ngcsimlib.context import Context
33
import numpy as np
44
np.random.seed(42)
5-
from ngclearn.components import IFCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
7-
from numpy.testing import assert_array_equal
85

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
6+
from ngclearn import Context, MethodProcess
7+
from ngclearn.components.neurons.spiking.IFCell import IFCell
8+
from numpy.testing import assert_array_equal
149

1510
def test_IFCell1():
1611
name = "if_ctx"
1712
## create seeding keys
1813
dkey = random.PRNGKey(1234)
1914
dkey, *subkeys = random.split(dkey, 6)
2015
dt = 1. # ms
21-
trace_increment = 0.1
2216
# ---- build a simple Poisson cell system ----
2317
with Context(name) as ctx:
2418
a = IFCell(
2519
name="a", n_units=1, tau_m=5., resist_m=10., key=subkeys[0]
2620
)
2721

28-
#"""
29-
advance_process = (Process("advance_proc")
22+
# """
23+
advance_process = (MethodProcess("advance_proc")
3024
>> a.advance_state)
31-
#ctx.wrap_and_add_command(advance_process.pure, name="run")
32-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
25+
# ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3326

34-
reset_process = (Process("reset_proc")
27+
reset_process = (MethodProcess("reset_proc")
3528
>> a.reset)
36-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
37-
#"""
38-
39-
"""
40-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
41-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
42-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
43-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
44-
"""
45-
29+
# ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
30+
# """
4631
## set up non-compiled utility commands
47-
@Context.dynamicCommand
48-
def clamp(x):
49-
a.j.set(x)
32+
# @Context.dynamicCommand
33+
# def clamp(x):
34+
# a.j.set(x)
35+
36+
def clamp(x):
37+
a.j.set(x)
5038

5139
## input spike train
5240
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32)
5341
## desired output/epsp pulses
5442
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
5543

5644
outs = []
57-
ctx.reset()
45+
reset_process.run() # ctx.reset()
5846
for ts in range(x_seq.shape[1]):
5947
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
60-
ctx.clamp(x_t)
61-
ctx.run(t=ts * 1., dt=dt)
62-
outs.append(a.s.value)
48+
clamp(x_t) # ctx.clamp(x_t)
49+
advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
50+
outs.append(a.s.get())
6351
outs = jnp.concatenate(outs, axis=1)
64-
print(outs)
65-
52+
# print(outs)
53+
# print(y_seq)
54+
6655
## output should equal input
6756
assert_array_equal(outs, y_seq)
6857

69-
#test_IFCell1()
58+
test_IFCell1()

0 commit comments

Comments
 (0)