Skip to content

Commit f6d56a4

Browse files
author
Alexander Ororbia
committed
revised unit-tests to pass globally; some minor patches to phasor-cell and lif
1 parent d4dfe38 commit f6d56a4

File tree

11 files changed

+43
-100
lines changed

11 files changed

+43
-100
lines changed

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
4444
## Layer Size Setup
4545
self.batch_size = batch_size
4646
self.n_units = n_units
47-
_key, subkey = random.split(self.key.value, 2)
47+
_key, *subkey = random.split(self.key.value, 3)
4848
self.key.set(_key)
4949
## Compartment setup
5050
restVals = jnp.zeros((self.batch_size, self.n_units))
@@ -62,7 +62,7 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
6262
# alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
6363
# beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
6464

65-
self.base_scale = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
65+
self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.value.shape) / target_freq
6666

6767
def validate(self, dt=None, **validation_kwargs):
6868
valid = super().validate(**validation_kwargs)
@@ -95,11 +95,11 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
9595
angle_per_event = 2 * jnp.pi # rad / e
9696
angle_per_timestep = angle_per_event / time_step_per_event # rad / e
9797
# * e/ts -> rad / ts
98-
key, subkey = random.split(key, 2)
98+
key, *subkey = random.split(key, 3)
9999
# scatter = random.uniform(subkey, angles.shape, minval=0.5,
100100
# maxval=1.5) * base_scale
101101

102-
scatter = ((random.normal(subkey, angles.shape) * 0.2) + 1) * base_scale
102+
scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * base_scale
103103
scattered_update = angle_per_timestep * scatter
104104
scaled_scattered_update = scattered_update * inputs
105105

@@ -116,7 +116,7 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
116116
@staticmethod
117117
def reset(batch_size, n_units, key, target_freq):
118118
restVals = jnp.zeros((batch_size, n_units))
119-
key, subkey = random.split(key, 2)
119+
key, *subkey = random.split(key, 3)
120120
return restVals, restVals, restVals, restVals, key
121121

122122
def save(self, directory, **kwargs):

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
"""
2-
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
6-
from ngclearn import resolver, Component, Compartment
7-
from ngclearn.components.jaxComponent import JaxComponent
8-
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
9-
step_euler, step_rk2
10-
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
11-
triangular_estimator,
12-
straight_through_estimator)
13-
"""
141
from ngclearn.components.jaxComponent import JaxComponent
152
from jax import numpy as jnp, random, jit, nn
163
from functools import partial

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,34 @@
33
import numpy as np
44
np.random.seed(42)
55
from ngclearn.components import BernoulliCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
6+
#from ngcsimlib.compilers import compile_command, wrap_command
77
from numpy.testing import assert_array_equal
88

99
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
10+
#from ngcsimlib.component import Component
11+
#from ngcsimlib.compartment import Compartment
1212
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
13+
#from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515

1616
def test_bernoulliCell1():
17+
name = "bernoulli_ctx"
1718
## create seeding keys
1819
dkey = random.PRNGKey(1234)
1920
dkey, *subkeys = random.split(dkey, 6)
2021
dt = 1. # ms
2122
#T = 300 # ms
2223
# ---- build a simple Bernoulli cell system ----
23-
with Context("Circuit") as ctx:
24+
with Context(name) as ctx:
2425
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
2526

2627
advance_process = (Process()
2728
>> a.advance_state)
28-
ctx.wrap_and_add_command(advance_process.pure, name="run")
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
2930

3031
reset_process = (Process()
3132
>> a.reset)
32-
ctx.wrap_and_add_command(reset_process.pure, name="reset")
33+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3334

3435
## set up non-compiled utility commands
3536
@Context.dynamicCommand

tests/components/input_encoders/test_latencyCell.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515

1616
def test_latencyCell1():
17+
name = "latency_ctx"
1718
## create seeding keys
1819
dkey = random.PRNGKey(1234)
1920
dkey, *subkeys = random.split(dkey, 6)
2021
T = 50 # 100 #5 ## number of simulation steps to run
2122
dt = 1. # 0.1 # ms ## compute integration time constant
2223
tau = 1.
2324
# ---- build a simple Poisson cell system ----
24-
with Context("Circuit") as ctx:
25+
with Context(name) as ctx:
2526
a = LatencyCell(
2627
"a", n_units=4, tau=tau, threshold=0.01, linearize=True,
2728
normalize=True, num_steps=T, clip_spikes=False
@@ -30,13 +31,13 @@ def test_latencyCell1():
3031
## create and compile core simulation commands
3132
advance_process = (Process()
3233
>> a.advance_state)
33-
ctx.wrap_and_add_command(advance_process.pure, name="advance")
34+
ctx.wrap_and_add_command(jit(advance_process.pure), name="advance")
3435
calc_spike_times_process = (Process()
3536
>> a.calc_spike_times)
36-
ctx.wrap_and_add_command(calc_spike_times_process.pure, name="calc_spike_times")
37+
ctx.wrap_and_add_command(jit(calc_spike_times_process.pure), name="calc_spike_times")
3738
reset_process = (Process()
3839
>> a.reset)
39-
ctx.wrap_and_add_command(reset_process.pure, name="reset")
40+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
4041

4142
## set up non-compiled utility commands
4243
@Context.dynamicCommand

tests/components/input_encoders/test_phasorCell.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,25 @@
33
import numpy as np
44
np.random.seed(42)
55
from ngclearn.components import PhasorCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
6+
#from ngcsimlib.compilers import compile_command, wrap_command
77
from numpy.testing import assert_array_equal
88

99
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
10+
#from ngcsimlib.component import Component
11+
#from ngcsimlib.compartment import Compartment
12+
#from ngcsimlib.context import Context
13+
#from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515

1616
def test_phasorCell1():
17+
name = "phasor_ctx"
1718
## create seeding keys
1819
dkey = random.PRNGKey(1234)
1920
dkey, *subkeys = random.split(dkey, 6)
2021
dt = 1. # ms
2122
# T = 300 # ms
2223
# ---- build a simple Poisson cell system ----
23-
with Context("Circuit") as ctx:
24+
with Context(name) as ctx:
2425
a = PhasorCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
2526

2627
advance_process = (Process()
@@ -52,4 +53,4 @@ def clamp(x):
5253
## output should equal input
5354
assert_array_equal(outs, x_seq)
5455

55-
test_phasorCell1()
56+
#test_phasorCell1()

tests/components/input_encoders/test_poissonCell.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,23 @@
1414

1515

1616
def test_poissonCell1():
17+
name = "poisson_ctx"
1718
## create seeding keys
1819
dkey = random.PRNGKey(1234)
1920
dkey, *subkeys = random.split(dkey, 6)
2021
dt = 1. # ms
2122
# T = 300 # ms
2223
# ---- build a simple Poisson cell system ----
23-
with Context("Circuit") as ctx:
24+
with Context(name) as ctx:
2425
a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
2526

2627
advance_process = (Process()
2728
>> a.advance_state)
28-
ctx.wrap_and_add_command(advance_process.pure, name="run")
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
2930

3031
reset_process = (Process()
3132
>> a.reset)
32-
ctx.wrap_and_add_command(reset_process.pure, name="reset")
33+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3334

3435
## set up non-compiled utility commands
3536
@Context.dynamicCommand

tests/components/neurons/graded/test_rateCell.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

tests/components/neurons/spiking/test_LIFCell.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515
def test_LIFCell1():
16+
name = "lif_ctx"
1617
## create seeding keys
1718
dkey = random.PRNGKey(1234)
1819
dkey, *subkeys = random.split(dkey, 6)
1920
dt = 1. # ms
2021
trace_increment = 0.1
2122
# ---- build a simple Poisson cell system ----
22-
with Context("Circuit") as ctx:
23+
with Context(name) as ctx:
2324
a = LIFCell(
2425
name="a", n_units=1, tau_m=5., resist_m=30., key=subkeys[0]
2526
)
@@ -65,4 +66,4 @@ def clamp(x):
6566
## output should equal input
6667
assert_array_equal(outs, y_seq)
6768

68-
test_LIFCell1()
69+
#test_LIFCell1()

tests/components/neurons/spiking/test_sLIFCell.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515
def test_sLIFCell1():
16+
name = "slif_ctx"
1617
## create seeding keys
1718
dkey = random.PRNGKey(1234)
1819
dkey, *subkeys = random.split(dkey, 6)
1920
dt = 1. # ms
2021
trace_increment = 0.1
2122
# ---- build a simple Poisson cell system ----
22-
with Context("Circuit") as ctx:
23+
with Context(name) as ctx:
2324
a = SLIFCell(
2425
name="a", n_units=1, tau_m=50., resist_m=10., thr=0.3, key=subkeys[0]
2526
)

tests/components/other/test_expKernel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,25 @@
1313
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515
def test_expKernel1():
16+
name = "expKernel_ctx"
1617
## create seeding keys
1718
dkey = random.PRNGKey(1234)
1819
dkey, *subkeys = random.split(dkey, 6)
1920
dt = 1. # ms
2021
trace_increment = 0.1
2122
# ---- build a simple Poisson cell system ----
22-
with Context("Circuit") as ctx:
23+
with Context(name) as ctx:
2324
a = ExpKernel(
2425
name="a", n_units=1, dt=1., tau_w=500., nu=4., key=subkeys[0]
2526
)
2627

2728
advance_process = (Process()
2829
>> a.advance_state)
29-
ctx.wrap_and_add_command(advance_process.pure, name="run")
30+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3031

3132
reset_process = (Process()
3233
>> a.reset)
33-
ctx.wrap_and_add_command(reset_process.pure, name="reset")
34+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3435

3536
## set up non-compiled utility commands
3637
@Context.dynamicCommand

0 commit comments

Comments
 (0)