
Commit ab6c716

Author: Alexander Ororbia
Message: revised stp-syn neurocog doc and updated stp-syn to use proper initializer
Parent: 276cb89

4 files changed (+90, -187 lines)


docs/tutorials/neurocog/mod_stdp.md (1 addition & 1 deletion)
````diff
@@ -18,7 +18,7 @@ Writing the above three parallel single synapse systems, including meta-parameters
 
 ```python
 from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
 ## import model-specific mechanisms
 from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse, RewardErrorCell, VarTrace)
````

docs/tutorials/neurocog/rate_cell.md (35 additions & 62 deletions)
````diff
@@ -1,23 +1,19 @@
 # Lecture 3A: The Rate Cell Model
 
-Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks, various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g. electrical
-current differential equations in spiking networks.
+Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks and various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g., electrical current differential equations in spiking networks.
 
 In this tutorial, we will study one of ngc-learn's workhorse in-built graded cell components, the rate cell ([RateCell](ngclearn.components.neurons.graded.rateCell)).
 
 ## Creating and Using a Rate Cell
 
 ### Instantiating the Rate Cell
 
-Let's go ahead and set up the controller for this lesson's simulation,
-where we will a dynamical system with only a single component,
-specifically the rate-cell (RateCell). Let's start with the file's header
-(or import statements):
+Let's go ahead and set up the controller for this lesson's simulation, where we will build a dynamical system with only a single component, specifically the rate-cell (RateCell). Let's start with the file's header (or import statements):
 
 ```python
 from jax import numpy as jnp, random, jit
-from ngclearn.utils import JaxProcess
-from ngcsimlib.context import Context
+
+from ngclearn import Context, MethodProcess
 ## import model-specific elements
 from ngclearn.components.neurons.graded.rateCell import RateCell
 ```
````
````diff
@@ -36,91 +32,67 @@ gamma = 1.
 
 with Context("Model") as model: ## model/simulation definition
     ## instantiate components (like cells)
-    cell = RateCell("z0", n_units=1, tau_m=tau_m, act_fx=act_fx,
-                    prior=("gaussian", gamma), integration_type="euler", key=subkeys[0])
+    cell = RateCell(
+        "z0", n_units=1, tau_m=tau_m, act_fx=act_fx, prior=("gaussian", gamma), integration_type="euler",
+        key=subkeys[0]
+    )
 
     ## instantiate desired core commands that drive the simulation
-    advance_process = (JaxProcess()
+    advance_process = (MethodProcess("advance")
                        >> cell.advance_state)
-    model.wrap_and_add_command(jit(advance_process.pure), name="advance")
-
-    reset_process = (JaxProcess()
+    reset_process = (MethodProcess("reset")
                      >> cell.reset)
-    model.wrap_and_add_command(jit(reset_process.pure), name="reset")
 
 
-    ## instantiate some non-jitted dynamic utility commands
-    @Context.dynamicCommand
-    def clamp(x):
-        cell.j.set(x)
+    ## instantiate utility commands
+    def clamp(x):
+        cell.j.set(x)
 ```
````
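For readers migrating older scripts, the gist of the hunk above is that `JaxProcess` plus `model.wrap_and_add_command(jit(...), name=...)` gives way to a named `MethodProcess` that is run directly. Below is a minimal end-to-end sketch assembled only from lines added in this commit; treat it as a sketch under those assumptions, not as canonical ngclearn documentation:

```python
from jax import random
from ngclearn import Context, MethodProcess
from ngclearn.components.neurons.graded.rateCell import RateCell

dkey = random.PRNGKey(1234)
dkey, *subkeys = random.split(dkey, 2)

with Context("Model") as model:
    cell = RateCell("z0", n_units=1, tau_m=10., key=subkeys[0])
    ## named processes replace the old wrap_and_add_command(...) registration
    advance_process = (MethodProcess("advance")
                       >> cell.advance_state)
    reset_process = (MethodProcess("reset")
                     >> cell.reset)

reset_process.run()               ## formerly model.reset()
advance_process.run(t=0., dt=1.)  ## formerly model.advance(t=0., dt=1.)
```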
````diff
 
-A notable argument to the rate-cell, beyond some of its differential equation
-constants (`tau_m` and `gamma`), is its activation function choice (default is
-the `identity`), which we have chosen to be a discrete pulse emitting function
-known as the `unit_threshold` (which outputs a value of one for any input that
-exceeds the threshold of one and zero for anything else).
+A notable argument to the rate-cell, beyond some of its differential equation constants (`tau_m` and `gamma`), is its activation function choice (default is the `identity`), which we have chosen to be a discrete pulse-emitting function known as the `unit_threshold` (which outputs a value of one for any input that exceeds the threshold of one and zero otherwise).
 
-Mathematically, under the hood, a rate-cell evolves according to the
-ordinary differential equation (ODE):
+Mathematically, under the hood, a rate-cell evolves according to the ordinary differential equation (ODE):
 
 $$
 \tau_m \frac{\partial \mathbf{z}}{\partial t} =
 -\gamma \text{prior}\big(\mathbf{z}\big) + (\mathbf{x} + \mathbf{x}_{td})
 $$
 
-where $\mathbf{x}$ is external input signal and $\mathbf{x}_{td}$ (default
-value is zero) is an optional additional input pressure signal (`td` stands for "top-down",
-its name motivated by predictive coding literature).
-A good way to understand this equation is in the context of two examples:
-1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the
-total electrical input into the cell from multiple injections produced
-by transmission across synapses ($\mathbf{x}_{td} = 0$)) and the $\text{prior}$
-is set to `gaussian` ($\gamma = 1$), yielding the equation
-$\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for
-a simple model of synaptic conductance, and
-2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections
-(or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$
-is set to be the sum of (top-down) pressures produced by an "upper" layer/group
-such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In
-this example, $0 \leq \gamma \leq 1$ and $\text{prior}$ could be set to one
-of any kind of kurtotic distribution to induce a soft form of sparsity in
-the dynamics, e.g., such as "cauchy" for the Cauchy distribution.
+where $\mathbf{x}$ is the external input signal and $\mathbf{x}_{td}$ (default value is zero) is an optional additional input pressure signal (`td` stands for "top-down", a name motivated by the predictive coding literature).
+A good way to understand this equation is in the context of two examples:
+1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the total electrical input into the cell from multiple injections produced by transmission across synapses ($\mathbf{x}_{td} = 0$) and the $\text{prior}$ is set to `gaussian` ($\gamma = 1$), yielding the equation $\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for a simple model of synaptic conductance, and
+2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections (or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$ is set to be the sum of (top-down) pressures produced by an "upper" layer/group, such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In this example, $0 \leq \gamma \leq 1$ and the $\text{prior}$ could be set to any kind of kurtotic distribution to induce a soft form of sparsity in the dynamics, e.g., "cauchy" for the Cauchy distribution.
 
````
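Since the tutorial selects `integration_type="euler"`, it may help to see what a single explicit Euler step of the ODE above amounts to numerically for the `gaussian` prior (where the prior term reduces to $-\gamma \mathbf{z}$). The following standalone sketch uses hypothetical helper names and is not ngclearn's internal implementation:

```python
import jax.numpy as jnp

def euler_step(z, x, x_td=0., tau_m=10., gamma=1., dt=1.):
    ## gaussian prior: the prior term contributes -gamma * z to dz/dt
    dz_dt = (-gamma * z + (x + x_td)) / tau_m
    return z + dz_dt * dt  ## explicit (forward) Euler update

z = jnp.zeros((1, 1))
for _ in range(50):  ## 50 steps of dt = 1 ms under a constant input of 1.006
    z = euler_step(z, x=jnp.ones((1, 1)) * 1.006)
print(z)  ## z has relaxed most of the way toward its fixed point z* = x
```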
````diff
 ### Simulating a Rate Cell
 
-Given our single rate-cell dynamical system above, let us write some code to use
-our `Rate` node and visualize its dynamics by feeding
-into it a pulse current (a piecewise input function that is an alternating
-sequence of intervals of where nothing is input and others where a non-zero
-value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically,
-we can plot the input current, the neuron's linear rate activity `z` and its
-nonlinear activity `phi(z)` as follows:
+Given our single rate-cell dynamical system above, let us write some code to use our `RateCell` node and visualize its dynamics by feeding into it a pulse current (a piecewise input function that alternates between intervals where nothing is input and intervals where a non-zero value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically, we can plot the input current, the neuron's linear rate activity `z`, and its nonlinear activity `phi(z)` as follows:
 
 ```python
 # create a synthetic electrical pulse current
-current = jnp.concatenate((jnp.zeros((1,10)),
-                           jnp.ones((1,50)) * 1.006,
-                           jnp.zeros((1,50)),
-                           jnp.ones((1,50)) * 1.006,
-                           jnp.zeros((1,50))), axis=1)
+current = jnp.concatenate(
+    (jnp.zeros((1,10)),
+     jnp.ones((1,50)) * 1.006,
+     jnp.zeros((1,50)),
+     jnp.ones((1,50)) * 1.006,
+     jnp.zeros((1,50))), axis=1
+)
 
 lin_out = []
 nonlin_out = []
 t_values = []
 
-model.reset()
+reset_process.run()
 t = 0.
 for ts in range(current.shape[1]):
     j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts
-    model.clamp(j_t)
-    model.advance(t=ts*1., dt=dt)
+    clamp(j_t)
+    advance_process.run(t=ts*1., dt=dt)
     t_values.append(t)
-    t += dt
+    t += dt ## advance time forward by dt milliseconds
 
     ## naively extract simple statistics at time ts and print them to I/O
-    linear_z = cell.z.value
-    nonlinear_z = cell.zF.value
+    linear_z = cell.z.get()
+    nonlinear_z = cell.zF.get()
     lin_out.append(linear_z)
     nonlin_out.append(nonlinear_z)
     print("\r {}: s {} ; v {}".format(ts, linear_z, nonlinear_z), end="")
````
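The hunk below picks up at the tail of the tutorial's plotting code (`ax.grid()`, `fig.savefig(...)`); the plotting lines themselves are unchanged and thus omitted from this diff. For orientation only, a figure consistent with the `t_values`, `lin_out`, and `nonlin_out` lists collected above could be produced roughly as follows (a sketch; the tutorial's actual matplotlib code may differ):

```python
import numpy as np
import matplotlib.pyplot as plt

## assumes t_values, lin_out, and nonlin_out were filled by the loop above
fig, ax = plt.subplots()
ax.plot(t_values, np.squeeze(np.asarray(lin_out)), label="z (linear activity)")
ax.plot(t_values, np.squeeze(np.asarray(nonlin_out)), label="phi(z) (nonlinear activity)")
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Activity")
ax.legend()
ax.grid()
fig.savefig("rate_cell_integration.jpg")
```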
````diff
@@ -148,10 +120,11 @@ ax.grid()
 fig.savefig("rate_cell_integration.jpg")
 ```
 
-which should yield a dynamics plot similar to the one below:
+which should yield a dynamics plot similar to the one below:
 
 <img src="../../images/tutorials/neurocog/rate_cell_integration.jpg" width="500" />
 
+
 <!-- footnotes -->
 [^1]: [Error neurons](ngclearn.components.neurons.graded.gaussianErrorCell)
 produce this kind of "top-down" value, which is technically the first derivative
````
