# Lecture 3A: The Rate Cell Model
Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks and various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g., electrical current differential equations in spiking networks.
In this tutorial, we will study one of ngc-learn's workhorse in-built graded cell components, the rate cell ([RateCell](ngclearn.components.neurons.graded.rateCell)).
## Creating and Using a Rate Cell
### Instantiating the Rate Cell
Let's go ahead and set up the controller for this lesson's simulation, where we will build a dynamical system with only a single component, specifically the rate cell (`RateCell`). Let's start with the file's header (or import statements):
```python
from jax import numpy as jnp, random, jit
from ngclearn import Context, MethodProcess
## import model-specific elements
from ngclearn.components.neurons.graded.rateCell import RateCell
```

```python
## ... (meta-parameter definitions, e.g., dt, tau_m, and gamma = 1., appear here) ...

with Context("Model") as model: ## model/simulation definition
    ## ... (the RateCell `cell` and its reset/advance processes are built here) ...
    ## instantiate utility commands
    def clamp(x):
        cell.j.set(x)
```
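
For reference, the elided setup inside the `Context` block builds the cell itself and its simulation processes. A hedged sketch of the instantiation is shown below; the argument names (`n_units`, `prior`, `act_fx`) and the component name `"z0"` are assumptions based on this lesson's prose, so consult the `RateCell` API reference for the authoritative signature:

```python
## hypothetical sketch of the elided cell construction (argument names are assumptions)
cell = RateCell(
    name="z0", n_units=1, tau_m=tau_m,  ## membrane time constant (in ms)
    prior=("gaussian", gamma),          ## gaussian prior, scaled by gamma
    act_fx="unit_threshold"             ## discrete pulse-emitting activation
)
```

Once the model is built, clamping an input writes it into the cell's electrical current compartment `j`, e.g., `clamp(jnp.ones((1, 1)) * 0.5)` injects a constant current of `0.5`.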
A notable argument to the rate-cell, beyond some of its differential equation constants (`tau_m` and `gamma`), is its choice of activation function (the default is the `identity`); here, we have chosen a discrete pulse-emitting function known as the `unit_threshold`, which outputs a value of one for any input that exceeds the threshold of one and zero otherwise.
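
Functionally, `unit_threshold` behaves like the short sketch below (an illustration of the behavior just described, not ngc-learn's actual implementation):

```python
from jax import numpy as jnp

def unit_threshold(x):
    ## emit a one for any input exceeding the threshold of one; zero otherwise
    return (x > 1.).astype(jnp.float32)

print(unit_threshold(jnp.asarray([0.5, 1.006])))  ## prints [0. 1.]
```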
Mathematically, under the hood, a rate-cell evolves according to the following ordinary differential equation (ODE):

$$
\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\gamma \frac{\partial\, \text{prior}(\mathbf{z})}{\partial \mathbf{z}} + \left(\mathbf{x} + \mathbf{x}_{td}\right)
$$
where $\mathbf{x}$ is the external input signal and $\mathbf{x}_{td}$ (default value of zero) is an optional, additional input pressure signal (`td` stands for "top-down", a name motivated by the predictive coding literature).
A good way to understand this equation is in the context of two examples (a small numerical sketch follows this list):
1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the total electrical input into the cell from multiple injections produced by transmission across synapses ($\mathbf{x}_{td} = 0$) and the $\text{prior}$ is set to `gaussian` ($\gamma = 1$), yielding the equation $\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for a simple model of synaptic conductance, and
2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections (or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$ is set to the sum of (top-down) pressures produced by an "upper" layer/group, such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In this example, $0 \leq \gamma \leq 1$ and the $\text{prior}$ could be set to any kind of kurtotic distribution to induce a soft form of sparsity in the dynamics, e.g., `cauchy` for the Cauchy distribution.
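
To make the integration concrete, below is a minimal, self-contained Euler step for the special case in example 1 (a `gaussian` prior with $\gamma = 1$ and $\mathbf{x}_{td} = 0$, so the prior term reduces to $-\mathbf{z}$). This is an illustrative sketch, not ngc-learn's internal integrator, and the `tau_m` and `dt` values are arbitrary:

```python
from jax import numpy as jnp

def euler_step(z, x, tau_m=10., dt=1.):
    ## special case: tau_m dz/dt = -z + x  (gaussian prior, gamma = 1, x_td = 0)
    dz_dt = (-z + x) / tau_m
    return z + dz_dt * dt

z = jnp.zeros((1, 1))
for _ in range(100):  ## drive the cell with a constant input current of 1
    z = euler_step(z, jnp.ones((1, 1)))
print(z)  ## z has relaxed toward the input's fixed point at 1
```

Note how the leak term $-\mathbf{z}$ pulls the state back toward zero whenever the input is removed, which is exactly the relaxation behavior visible in the simulation below.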
### Simulating a Rate Cell
Given our single rate-cell dynamical system above, let us write some code to use our `RateCell` node and visualize its dynamics by feeding into it a pulse current (a piecewise input function that alternates between intervals where nothing is input and intervals where a non-zero value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically, we can plot the input current, the neuron's linear rate activity `z`, and its nonlinear activity `phi(z)` as follows:
```python
# create a synthetic electrical pulse current
current = jnp.concatenate(
    (jnp.zeros((1, 10)),
     jnp.ones((1, 50)) * 1.006,
     jnp.zeros((1, 50)),
     jnp.ones((1, 50)) * 1.006,
     jnp.zeros((1, 50))), axis=1
)

lin_out = []
nonlin_out = []
t_values = []

reset_process.run()
t = 0.
for ts in range(current.shape[1]):
    j_t = jnp.expand_dims(current[0, ts], axis=0) ## get data at time ts
    clamp(j_t)
    advance_process.run(t=ts * 1., dt=dt)
    t_values.append(t)
    t += dt ## advance time forward by dt milliseconds

    ## naively extract simple statistics at time ts and print them to I/O
    linear_z = cell.z.get()
    nonlinear_z = cell.zF.get()
    lin_out.append(linear_z)
    nonlin_out.append(nonlinear_z)
    print("\r{}: s {} ; v {}".format(ts, linear_z, nonlinear_z), end="")

## ... (matplotlib figure construction is elided in this excerpt) ...
ax.grid()
fig.savefig("rate_cell_integration.jpg")
```
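
The figure-construction code itself is elided from this excerpt; a minimal matplotlib sketch that fills that gap (ending with the `ax.grid()` and `fig.savefig` calls shown above; the series labels and axis names are assumptions) might look like:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
## the collected lists hold (1, 1)-shaped arrays, hence the squeeze calls
ax.plot(t_values, jnp.squeeze(jnp.asarray(lin_out)), label="z (linear activity)")
ax.plot(t_values, jnp.squeeze(jnp.asarray(nonlin_out)), label="phi(z) (nonlinear activity)")
ax.plot(t_values, jnp.squeeze(current), label="input current j")
ax.set_xlabel("Time (ms)")
ax.legend()
ax.grid()
fig.savefig("rate_cell_integration.jpg")
```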
which should yield a dynamics plot similar to the one below: