Skip to content

Commit 10dc640

Browse files
author
Alexander Ororbia
committed
minor revisions to input-encoders, revised phasor-cell w/ unit-test
1 parent adc74cf commit 10dc640

File tree

8 files changed

+86
-32
lines changed

8 files changed

+86
-32
lines changed

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class LatencyCell(JaxComponent):
142142
143143
num_steps: number of discrete time steps to consider for normalized latency
144144
code (only useful if "normalize" is set to True) (Default: 1)
145+
146+
batch_size: batch size dimension of this cell (Default: 1)
145147
"""
146148

147149
# Define Functions

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from ngclearn import resolver, Compartment
21
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngclearn.utils import tensorstats
42
from jax import numpy as jnp, random
5-
from ngcsimlib.logger import warn
3+
from ngclearn.utils import tensorstats
4+
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib.logger import info, warn
6+
7+
from ngcsimlib.compilers.process import transition
8+
#from ngcsimlib.component import Component
9+
from ngcsimlib.compartment import Compartment
10+
611

712
class PhasorCell(JaxComponent):
813
"""
@@ -12,8 +17,9 @@ class PhasorCell(JaxComponent):
1217
| inputs - input (takes in external signals)
1318
| --- Cell State Compartments: ---
1419
| key - JAX PRNG key
20+
| angles - current angle of phasor
1521
| --- Cell Output Compartments: ---
16-
| outputs - output
22+
| outputs - output of phasor cell
1723
| tols - time-of-last-spike
1824
1925
Args:
@@ -23,6 +29,8 @@ class PhasorCell(JaxComponent):
2329
2430
target_freq: maximum frequency (in Hertz) of this spike train
2531
(must be > 0.)
32+
33+
batch_size: batch size dimension of this cell (Default: 1)
2634
"""
2735

2836
# Define Functions
@@ -63,8 +71,7 @@ def validate(self, dt=None, **validation_kwargs):
6371
return False
6472
## check for unstable combinations of dt and target-frequency
6573
# meta-params
66-
events_per_timestep = (
67-
dt / 1000.) * self.target_freq ##
74+
events_per_timestep = (dt / 1000.) * self.target_freq ##
6875
# compute scaled probability
6976
if events_per_timestep > 1.:
7077
valid = False
@@ -78,9 +85,9 @@ def validate(self, dt=None, **validation_kwargs):
7885
)
7986
return valid
8087

88+
@transition(output_compartments=["outputs", "tols", "key", "angles"])
8189
@staticmethod
82-
def _advance_state(t, dt, target_freq, key,
83-
inputs, angles, tols, base_scale):
90+
def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
8491
ms_per_second = 1000 # ms/s
8592
events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
8693
ms_per_event = 1 / events_per_ms # ms/e
@@ -105,27 +112,13 @@ def _advance_state(t, dt, target_freq, key,
105112

106113
return outputs, tols, key, updated_angles
107114

108-
@resolver(_advance_state)
109-
def advance_state(self, outputs, tols, key, angles):
110-
self.outputs.set(outputs)
111-
self.tols.set(tols)
112-
self.key.set(key)
113-
self.angles.set(angles)
114-
115+
@transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"])
115116
@staticmethod
116-
def _reset(batch_size, n_units, key, target_freq):
117+
def reset(batch_size, n_units, key, target_freq):
117118
restVals = jnp.zeros((batch_size, n_units))
118119
key, subkey = random.split(key, 2)
119120
return restVals, restVals, restVals, restVals, key
120121

121-
@resolver(_reset)
122-
def reset(self, inputs, outputs, tols, angles, key):
123-
self.inputs.set(inputs)
124-
self.outputs.set(outputs)
125-
self.tols.set(tols)
126-
self.key.set(key)
127-
self.angles.set(angles)
128-
129122
def save(self, directory, **kwargs):
130123
file_name = directory + "/" + self.name + ".npz"
131124
jnp.savez(file_name, key=self.key.value)
@@ -154,7 +147,7 @@ def help(cls): ## component help function
154147
hyperparams = {
155148
"n_units": "Number of neuronal cells to model in this layer",
156149
"batch_size": "Batch size dimension of this component",
157-
"target_freq": "Maximum spike frequency of the train produced",
150+
"target_freq": "Maximum spike frequency of the (spike) train produced",
158151
}
159152
info = {cls.__name__: properties,
160153
"compartments": compartment_props,

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class PoissonCell(JaxComponent):
2727
n_units: number of cellular entities (neural population size)
2828
2929
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
30+
31+
batch_size: batch size dimension of this cell (Default: 1)
3032
"""
3133

3234
@deprecate_args(max_freq="target_freq")

ngclearn/utils/patch_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create_patches(self, add_frame=False, center=True):
118118

119119

120120

121-
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True): ## scikit
121+
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234): ## scikit
122122
"""
123123
Generates a set of patches from an array/list of image arrays (via
124124
random sampling with replacement). This uses scikit-learn's patch creation
@@ -134,6 +134,8 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True):
134134
135135
center: centers each patch by subtracting the patch mean (per-patch)
136136
137+
seed: seed to control the random state of internal patch sampling
138+
137139
Returns:
138140
an array (D x (pH * pW)), where each row is a flattened patch sample
139141
"""
@@ -143,7 +145,7 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True):
143145
for s in range(_x_batch.shape[0]):
144146
xs = _x_batch[s, :]
145147
xs = xs.reshape(px, py)
146-
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches)#, random_state=69)
148+
patches = extract_patches_2d(xs, patch_size, max_patches=max_patches, random_state=seed)#, random_state=69)
147149
patches = np.reshape(patches, (len(patches), -1)) # flatten each patch in set
148150
if s > 0:
149151
p_batch = np.concatenate((p_batch,patches),axis=0)

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ def clamp(x):
5252
assert_array_equal(outs, x_seq)
5353
print(outs)
5454

55-
test_bernoulliCell1()
55+
#test_bernoulliCell1()

tests/components/input_encoders/test_latencyCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515

16-
def test_latencyCell():
16+
def test_latencyCell1():
1717
## create seeding keys
1818
dkey = random.PRNGKey(1234)
1919
dkey, *subkeys = random.split(dkey, 6)
@@ -69,4 +69,4 @@ def clamp(x):
6969
## output should equal input
7070
assert_array_equal(outs, targets)
7171

72-
test_latencyCell()
72+
#test_latencyCell1()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import PhasorCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
16+
def test_phasorCell1():
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
# T = 300 # ms
22+
# ---- build a simple Poisson cell system ----
23+
with Context("Circuit") as ctx:
24+
a = PhasorCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
25+
26+
advance_process = (Process()
27+
>> a.advance_state)
28+
ctx.wrap_and_add_command(advance_process.pure, name="run")
29+
30+
reset_process = (Process()
31+
>> a.reset)
32+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
33+
34+
## set up non-compiled utility commands
35+
@Context.dynamicCommand
36+
def clamp(x):
37+
a.inputs.set(x)
38+
39+
## input spike train
40+
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
41+
42+
outs = []
43+
ctx.reset()
44+
for ts in range(x_seq.shape[1]):
45+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
46+
ctx.clamp(x_t)
47+
ctx.run(t=ts * 1., dt=dt)
48+
outs.append(a.outputs.value)
49+
#print(a.outputs.value)
50+
outs = jnp.concatenate(outs, axis=1)
51+
52+
## output should equal input
53+
assert_array_equal(outs, x_seq)
54+
55+
test_phasorCell1()

tests/components/input_encoders/test_poissonCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1414

1515

16-
def test_poissonCell():
16+
def test_poissonCell1():
1717
## create seeding keys
1818
dkey = random.PRNGKey(1234)
1919
dkey, *subkeys = random.split(dkey, 6)
@@ -51,4 +51,4 @@ def clamp(x):
5151
## output should equal input
5252
assert_array_equal(outs, x_seq)
5353

54-
test_poissonCell()
54+
#test_poissonCell1()

0 commit comments

Comments
 (0)