Skip to content

Commit 588e3f5

Browse files
author
Alexander Ororbia
committed
fixed minor issues in input-encoders, further revisions to docs for v3
1 parent 043b0a8 commit 588e3f5

File tree

7 files changed

+36
-63
lines changed

7 files changed

+36
-63
lines changed

docs/tutorials/neurocog/hebbian.md

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,29 @@ Specifically, we will zoom in on two particular code snippets from
2121
below:
2222

2323
```python
24-
Wab = HebbianSynapse(name="Wab", shape=(1, 1), eta=1., signVal=-1.,
25-
wInit=("constant", 1., None), w_bound=0., key=subkeys[3])
24+
Wab = HebbianSynapse(
25+
name="Wab", shape=(1, 1), eta=1., signVal=-1., wInit=("constant", 1., None), w_bound=0., key=subkeys[3]
26+
)
2627

2728
# wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
28-
Wab.inputs << a.zF
29+
a.zF >> Wab.inputs
2930
# wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
30-
b.j << Wab.outputs
31+
Wab.outputs >> b.j
3132

3233
# wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab`
33-
Wab.pre << a.zF
34+
a.zF >> Wab.pre
3435
# wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab`
35-
Wab.post << b.zF
36+
b.zF >> Wab.post
3637
```
3738

3839
as well as (a bit later in the model construction code):
3940

4041
```python
41-
evolve_process = (JaxProcess()
42+
evolve_process = (MethodProcess()
4243
>> a.evolve)
43-
circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
4444

45-
advance_process = (JaxProcess()
45+
advance_process = (MethodProcess()
4646
>> a.advance_state)
47-
circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
4847
```
4948

5049
Notice that beyond wiring component `a`'s values into the synapse `Wab`'s input compartment
@@ -54,7 +53,7 @@ post-synaptic compartment `Wab.post`. These compartments are specifically
5453
used in `Wab`'s `evolve` call and are not strictly required to be exactly
5554
the same as its input and output compartments. Note that, if one wanted `pre`
5655
and `post` to be exactly identical to `inputs` and `outputs`, one would simply need
57-
to write `Wab.pre << Wab.inputs` and `Wab.post << Wab.outputs` in place
56+
to write `Wab.inputs >> Wab.pre` and `Wab.outputs >> Wab.post` in place
5857
of the pre- and post-synaptic compartment calls above.
5958

6059
The above snippets highlight two key aspects of functionality that a synapse

docs/tutorials/neurocog/input_cells.md

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ spike train over $100$ steps in time as follows:
3939

4040
```python
4141
from jax import numpy as jnp, random, jit
42-
from ngcsimlib.context import Context
43-
from ngclearn.utils import JaxProcess
42+
from ngclearn import Context, MethodProcess
4443

4544
from ngclearn.utils.viz.raster import create_raster_plot
4645
## import model-specific mechanisms
@@ -56,27 +55,24 @@ T = 100 ## number time steps to simulate
5655
with Context("Model") as model:
5756
cell = BernoulliCell("z0", n_units=10, key=subkeys[0])
5857

59-
advance_process = (JaxProcess()
58+
advance_process = (MethodProcess("advance_proc")
6059
>> cell.advance_state)
61-
model.wrap_and_add_command(jit(advance_process.pure), name="advance")
6260

63-
reset_process = (JaxProcess()
61+
reset_process = (MethodProcess("reset_proc")
6462
>> cell.reset)
65-
model.wrap_and_add_command(jit(reset_process.pure), name="reset")
6663

67-
68-
@Context.dynamicCommand
69-
def clamp(x):
70-
cell.inputs.set(x)
64+
def clamp(x):
65+
cell.inputs.set(x)
66+
7167

7268
probs = jnp.asarray([[0.8, 0.2, 0., 0.55, 0.9, 0, 0.15, 0., 0.6, 0.77]], dtype=jnp.float32)
7369
spikes = []
74-
model.reset()
70+
reset_process.run()
7571
for ts in range(T):
76-
model.clamp(probs)
77-
model.advance(t=ts * 1., dt=dt)
72+
clamp(probs)
73+
advance_process.run(t=ts * 1., dt=dt)
7874

79-
s_t = cell.outputs.value
75+
s_t = cell.outputs.get()
8076
spikes.append(s_t)
8177
spikes = jnp.concatenate(spikes, axis=0)
8278
create_raster_plot(spikes, plot_fname="input_cell_raster.jpg")
@@ -121,7 +117,7 @@ and by replacing the line that has the `BernoulliCell` call with the
121117
following line instead:
122118

123119
```python
124-
cell = PoissonCell("z0", n_units=10, max_freq=63.75, key=subkeys[0])
120+
cell = PoissonCell("z0", n_units=10, target_freq=63.75, key=subkeys[0])
125121
```
126122

127123
Running the code with the two above small modifications will
@@ -149,12 +145,12 @@ mu = 0.
149145
probs = jnp.asarray([[1.]],dtype=jnp.float32)
150146
for _ in range(n_trials):
151147
spikes = []
152-
model.reset()
148+
reset_process.run()
153149
for ts in range(T):
154-
model.clamp(probs)
155-
model.advance(t=ts*1., dt=dt)
150+
clamp(probs)
151+
advance_process.run(t=ts * 1., dt=dt)
156152

157-
s_t = cell.outputs.value
153+
s_t = cell.outputs.get()
158154
spikes.append(s_t)
159155
count = jnp.sum(jnp.concatenate(spikes, axis=0))
160156
mu += count

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class BernoulliCell(JaxComponent):
2626
batch_size: batch size dimension of this cell (Default: 1)
2727
"""
2828

29-
def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None):
29+
def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs):
3030
super().__init__(name=name, key=key)
3131

3232
## Layer Size Setup
@@ -80,20 +80,6 @@ def help(cls): ## component help function
8080
"hyperparameters": hyperparams}
8181
return info
8282

83-
# def __repr__(self):
84-
# comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
85-
# maxlen = max(len(c) for c in comps) + 5
86-
# lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
87-
# for c in comps:
88-
# stats = tensorstats(getattr(self, c).value)
89-
# if stats is not None:
90-
# line = [f"{k}: {v}" for k, v in stats.items()]
91-
# line = ", ".join(line)
92-
# else:
93-
# line = "None"
94-
# lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
95-
# return lines
96-
9783
if __name__ == '__main__':
9884
from ngcsimlib.context import Context
9985
with Context("Bar") as bar:

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,9 @@ class LatencyCell(JaxComponent):
144144
"""
145145

146146
def __init__(
147-
self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01,
148-
first_spike_time: float = 0., linearize: bool = False,
149-
normalize: bool = False, clip_spikes: bool = False,
150-
num_steps: float = 1., batch_size: int = 1,
151-
key: Union[jax.Array, None] = None
147+
self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, first_spike_time: float = 0.,
148+
linearize: bool = False, normalize: bool = False, clip_spikes: bool = False, num_steps: float = 1.,
149+
batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs
152150
):
153151
super().__init__(name=name, key=key)
154152

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,6 @@ def reset(self):
137137
self.angles.set(restVals)
138138
self.key.set(key)
139139

140-
141-
def save(self, directory, **kwargs):
142-
file_name = directory + "/" + self.name + ".npz"
143-
jnp.savez(file_name, key=self.key.value)
144-
145-
def load(self, directory, **kwargs):
146-
file_name = directory + "/" + self.name + ".npz"
147-
data = jnp.load(file_name)
148-
self.key.set(data['key'])
149-
150140
@classmethod
151141
def help(cls): ## component help function
152142
properties = {

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jax
44
from typing import Union
55

6+
from ngcsimlib import deprecate_args
67
from ngcsimlib.parser import compilable
78
from ngcsimlib.compartment import Compartment
89

@@ -29,8 +30,11 @@ class PoissonCell(JaxComponent):
2930
batch_size: batch size dimension of this cell (Default: 1)
3031
"""
3132

32-
def __init__(self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
33-
key: Union[jax.Array, None] = None):
33+
@deprecate_args(max_freq="target_freq")
34+
def __init__(
35+
self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
36+
key: Union[jax.Array, None] = None, **kwargs
37+
):
3438
super().__init__(name=name, key=key)
3539

3640
## Constrained Bernoulli meta-parameters

tests/components/input_encoders/test_poissonCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from jax import numpy as jnp, random, jit
22
import numpy as np
33
np.random.seed(42)
4-
from ngclearn.components import PoissonCell
4+
from ngclearn.components.input_encoders.poissonCell import PoissonCell
55
from numpy.testing import assert_array_equal
66

77
from ngclearn import MethodProcess, Context

0 commit comments

Comments
 (0)