diff --git a/README.md b/README.md
index 7e43e7fa..1e7412fd 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,7 @@ $ python install -e .
**Version:**
-2.0.0
+2.0.1
Author:
Alexander G. Ororbia II
diff --git a/docs/images/museum/rat_accuracy.jpg b/docs/images/museum/rat_accuracy.jpg
new file mode 100644
index 00000000..6236db35
Binary files /dev/null and b/docs/images/museum/rat_accuracy.jpg differ
diff --git a/docs/images/museum/rat_rewards.jpg b/docs/images/museum/rat_rewards.jpg
new file mode 100644
index 00000000..8c42e1b0
Binary files /dev/null and b/docs/images/museum/rat_rewards.jpg differ
diff --git a/docs/images/museum/ratmaze.png b/docs/images/museum/ratmaze.png
new file mode 100644
index 00000000..65c1006f
Binary files /dev/null and b/docs/images/museum/ratmaze.png differ
diff --git a/docs/images/museum/real_ratmaze.jpg b/docs/images/museum/real_ratmaze.jpg
new file mode 100644
index 00000000..5925c09f
Binary files /dev/null and b/docs/images/museum/real_ratmaze.jpg differ
diff --git a/docs/images/tutorials/neurocog/GEC.png b/docs/images/tutorials/neurocog/GEC.png
new file mode 100644
index 00000000..47ec531f
Binary files /dev/null and b/docs/images/tutorials/neurocog/GEC.png differ
diff --git a/docs/images/tutorials/neurocog/SingleGEC.png b/docs/images/tutorials/neurocog/SingleGEC.png
new file mode 100644
index 00000000..8af5cc58
Binary files /dev/null and b/docs/images/tutorials/neurocog/SingleGEC.png differ
diff --git a/docs/images/tutorials/neurocog/alphasyn.jpg b/docs/images/tutorials/neurocog/alphasyn.jpg
new file mode 100644
index 00000000..037cf868
Binary files /dev/null and b/docs/images/tutorials/neurocog/alphasyn.jpg differ
diff --git a/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg b/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg
new file mode 100644
index 00000000..84c84023
Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg differ
diff --git a/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg b/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg
new file mode 100644
index 00000000..ef29c41e
Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg differ
diff --git a/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg b/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg
new file mode 100644
index 00000000..3a137c8b
Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg differ
diff --git a/docs/images/tutorials/neurocog/exp2syn.jpg b/docs/images/tutorials/neurocog/exp2syn.jpg
new file mode 100644
index 00000000..d32f3bfc
Binary files /dev/null and b/docs/images/tutorials/neurocog/exp2syn.jpg differ
diff --git a/docs/images/tutorials/neurocog/expsyn.jpg b/docs/images/tutorials/neurocog/expsyn.jpg
new file mode 100755
index 00000000..1b9b2fc3
Binary files /dev/null and b/docs/images/tutorials/neurocog/expsyn.jpg differ
diff --git a/docs/installation.md b/docs/installation.md
index cda569a2..b70db6c6 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -6,13 +6,13 @@ without a GPU.
Setup: ngc-learn,
in its entirety (including its supporting utilities),
requires that you ensure that you have installed the following base dependencies in
-your system. Note that this library was developed and tested on Ubuntu 22.04 (and 18.04).
+your system. Note that this library was developed and tested on Ubuntu 22.04 (earlier releases were developed/tested on 18.04/20.04).
Specifically, ngc-learn requires:
* Python (>=3.10)
-* ngcsimlib (>=0.3.b4), (official page)
+* ngcsimlib (>=1.0.0), (official page)
* NumPy (>=1.26.0)
* SciPy (>=1.7.0)
-* JAX (>= 0.4.18; and jaxlib>=0.4.18)
+* JAX (>= 0.4.28; and jaxlib>=0.4.28)
* Matplotlib (>=3.4.2), (for `ngclearn.utils.viz`)
* Scikit-learn (>=1.3.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`)
@@ -45,7 +45,7 @@ $ git clone https://github.com/NACLab/ngc-learn.git
$ cd ngc-learn
```
-2. (Optional; only for GPU version) Install JAX for either CUDA 11 or 12 , depending
+2. (Optional; only for GPU version) Install JAX for CUDA 12, depending
on your system setup. Follow the
installation instructions
on the official JAX page to properly install the CUDA 11 or 12 version.
diff --git a/docs/modeling/neurons.md b/docs/modeling/neurons.md
index 1ec2600b..4babf8f7 100644
--- a/docs/modeling/neurons.md
+++ b/docs/modeling/neurons.md
@@ -234,7 +234,7 @@ fast spiking (FS), low-threshold spiking (LTS), and resonator (RZ) neurons.
This cell models dynamics over voltage `v` and three channels/gates (related to
potassium and sodium activation/inactivation). This sophisticated cell system is,
as a result, a set of four coupled differential equations and is driven by an appropriately configured set of biophysical constants/coefficients (default values of which have been set according to relevant source work).
-(Note that this cell supports either Euler or midpoint method / RK-2 integration.)
+(Note that this cell supports Euler, midpoint (RK-2), or RK-4 integration.)
```{eval-rst}
.. autoclass:: ngclearn.components.HodgkinHuxleyCell
diff --git a/docs/modeling/synapses.md b/docs/modeling/synapses.md
index 6bf89394..470446e9 100644
--- a/docs/modeling/synapses.md
+++ b/docs/modeling/synapses.md
@@ -60,6 +60,34 @@ This synapse performs a deconvolutional transform of its input signals. Note tha
## Dynamic Synapse Types
+### Exponential Synapse
+
+This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothed through an exponential kernel.
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.ExponentialSynapse
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
+### Alpha Synapse
+
+This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothed through a kernel that models more realistic rise and fall times of synaptic conductance.
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.AlphaSynapse
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
### Short-Term Plasticity (Dense) Synapse
This synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that it engages in short-term plasticity (STP), meaning that its efficacy values change as a function of its inputs/time (and simulated consumed resources), but it does not provide any long-term form of plasticity/adjustment.
diff --git a/docs/museum/index.rst b/docs/museum/index.rst
index 6f534770..25f75dfc 100644
--- a/docs/museum/index.rst
+++ b/docs/museum/index.rst
@@ -18,3 +18,4 @@ relevant, referenced publicly available ngc-learn simulation code.
snn_dc
snn_bfa
sindy
+ rl_snn
diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md
new file mode 100644
index 00000000..24d11412
--- /dev/null
+++ b/docs/museum/rl_snn.md
@@ -0,0 +1,132 @@
+# Reinforcement Learning through a Spiking Controller
+
+In this exhibit, we will see how to construct a simple biophysical model for
+reinforcement learning with a spiking neural network and modulated
+spike-timing-dependent plasticity.
+This model incorporates mechanisms from several different models, including
+the constrained RL-centric SNN of [1] as well as simplifications
+made with respect to the model of [2]. The model code for this
+exhibit can be found
+[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn).
+
+## Modeling Operant Conditioning through Modulation
+
+Operant conditioning refers to the idea that environmental stimuli can either increase or decrease the occurrence of (voluntary) behaviors; in other words, positive stimuli can lead to future repeats of a certain behavior whereas negative stimuli can punish, i.e., decrease, its future occurrences. Ultimately, operant conditioning shapes voluntary behavior through consequences: actions followed by rewards are repeated while actions followed by negative/punishing outcomes diminish.
+
+In this lesson, we will model a very simple case of operant conditioning for a neuronal motor circuit used to engage in the navigation of a simple maze.
+The maze's design will be the rat T-maze and the "rat" will be allowed to move, at a particular point in the maze, in one of four directions (up/North, down/South, left/West, and right/East). A positive reward will be supplied to our rat neuronal circuit if it makes progress towards the direction of food (placed in the upper right corner of the T-maze) and a negative reward will be provided if it fails to make progress/gets stuck, i.e., a dense reward function will be employed. For the exhibit code that goes with this lesson, an implementation of this T-maze environment is provided, modeled in the same style/with the same agent API as OpenAI gymnasium.
+
+### Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP)
+
+Although [spike-timing-dependent plasticity](../tutorials/neurocog/stdp.md) (STDP) and [reward-modulated STDP](../tutorials/neurocog/mod_stdp.md) (MSTDP) are covered and analyzed in detail in the ngc-learn set of tutorials, we will briefly review the evolution
+of synaptic strengths as prescribed by modulated STDP with eligibility traces here. In effect, STDP prescribes changes
+in synaptic strength according to the idea that neurons that fire together, wire together, except that timing matters
+(a temporal interpretation of basic Hebbian learning). This means that, assuming we are able to record the spike times
+of the pre-synaptic and post-synaptic neurons (that a synaptic cable connects), we can, at any time-step $t$, produce an
+adjustment $\Delta W_{ij}(t)$ to a synapse via the following pair of correlational rules:
+
+$$
+\Delta W_{ij}(t) = A^+ \big(x_i s_j \big) - A^- \big(s_i x_j \big)
+$$
+
+where $s_j$ is the spike recorded at time $t$ of the post-synaptic neuron $j$ (and $x_j$ is an exponentially-decaying trace that tracks its spiking history) and $s_i$ is the spike recorded at time $t$ of the pre-synaptic neuron $i$ (and $x_i$ is an exponentially-decaying trace that tracks its pulse history). STDP, as shown in a very simple format above, can effectively be described as balancing two types of alterations to a synaptic efficacy -- long-term potentiation (the first term, which increases synaptic strength) and long-term depression (the second term, which decreases synaptic strength).
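+
+As a quick illustration of the pair of rules above, the following is a minimal sketch (using raw JAX arrays rather than ngc-learn's trace-STDP components; the array shapes and the values of $A^+$/$A^-$ are assumptions chosen only for demonstration) of how a full update matrix could be assembled from spikes and traces at a single time-step:
+
+```python
+from jax import numpy as jnp
+
+A_plus, A_minus = 1e-2, 1.2e-2         ## LTP / LTD scaling factors (illustrative)
+s_pre  = jnp.array([[0., 1., 0.]])     ## pre-synaptic spikes at time t   (1 x n_pre)
+x_pre  = jnp.array([[0.3, 0.7, 0.1]])  ## pre-synaptic traces             (1 x n_pre)
+s_post = jnp.array([[1., 0.]])         ## post-synaptic spikes at time t  (1 x n_post)
+x_post = jnp.array([[0.5, 0.2]])       ## post-synaptic traces            (1 x n_post)
+## dW[i, j] = A+ (x_i s_j) - A- (s_i x_j), written as a pair of outer products
+dW = A_plus * (x_pre.T @ s_post) - A_minus * (s_pre.T @ x_post)
+```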
+
+Modulated STDP is a three-factor variant of STDP that multiplies the final synaptic update by a third signal, e.g., the modulatory signal is often a reward (dopamine) intensity value (resulting in reward-modulated STDP). However, given that reward signals might be delayed or not arrive/be available at every single time-step, it is common practice to extend a synapse to maintain a second value called an "eligibility trace", which is effectively another exponentially-decaying trace/filter (instantiated as an ODE that can be integrated via the Euler method or related tools) that is constructed to track a sequence of STDP updates applied across a window of time. Once a reward/modulator signal becomes available, the current trace is multiplied by the modulator to produce a change in synaptic efficacy.
+In essence, this update becomes:
+
+$$
+\Delta W_{ij} = \nu E_{ij}(t) r(t), \; \text{where } \; \tau_e \frac{\partial E_{ij}(t)}{\partial t} = -E_{ij}(t) + \Delta W_{ij}(t)
+$$
+
+where $r(t)$ is the dopamine supplied at some time $t$ and $\nu$ is some non-negative global learning rate. Note that MSTDP with eligibility traces (MSTDP-ET) is agnostic to the choice of local STDP/Hebbian update used to produce $\Delta W_{ij}(t)$ (for example, one could replace the trace-based STDP rule we presented above with BCM or a variant of weight-dependent STDP).
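+
+A corresponding sketch of the eligibility-trace bookkeeping (again with raw JAX arrays and illustrative constants; `dW` stands in for whatever local STDP update was computed at time $t$, e.g., via the sketch above) might look like:
+
+```python
+from jax import numpy as jnp
+
+nu, tau_e, dt = 1., 25., 1.     ## learning rate, eligibility time constant (ms), step size (ms) -- illustrative
+W = jnp.zeros((3, 2))           ## synaptic efficacy matrix (n_pre x n_post)
+E = jnp.zeros((3, 2))           ## eligibility trace, same shape as W
+dW = jnp.ones((3, 2)) * 1e-3    ## stand-in for the local STDP update at time t
+r_t = 0.5                       ## dopamine/reward signal at time t (zero whenever no reward arrives)
+E = E + dt * (-E + dW) / tau_e  ## Euler-integrated eligibility trace ODE
+W = W + nu * E * r_t            ## reward-modulated change in synaptic efficacy
+```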
+
+## The Spiking Neural Circuit Model
+
+In this exhibit, we build one of the simplest possible spiking neural networks (SNNs) one could design to tackle a simple maze navigation problem such as the rat T-maze; specifically, a three-layer SNN where the first layer is a Poisson encoder and the second and third layers contain sets of recurrent leaky integrate-and-fire (LIF) neurons. The recurrence in our model is to be non-plastic and constructed such that a form of lateral competition is induced among the LIF units, i.e., LIF neurons will be driven by a scaled Hollow-matrix initialized recurrent weight matrix (which will multiply spikes encountered at time $t - \Delta t$ by negative values), which will (quickly yet roughly) approximate the effect of inhibitory neurons. For the synapses that transmit pulses from the sensory/input layer to the second layer, we will opt for a non-plastic sparse mixture of excitatory and inhibitory strength values (much as in the model of [1]) to produce a reasonable encoding of the input Poisson spike trains. For the synapses that transmit pulses from the second layer to the third (control/action) layer, we will employ MSTDP-ET (as shown in the previous section) to adjust the non-negative efficacies in order to learn a basic reactive policy. We will call this very simple neuronal model the "reinforcement learning SNN" (RL-SNN).
+
+The SNN circuit will be provided raw pixels of the T-maze environment (however, this view is a global view of the
+entire maze, as opposed to something more realistic such as an egocentric view of the sensory space), where a cross
+"+" marks its current location and an "X" marks the location of the food substance/goal state. Shown below is an
+image to the left depicting a real-world rat T-maze whereas to the right is our implementation/simulation of the
+T-maze problem (and what our SNN circuit sees at the very start of an episode of the navigation problem).
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +-------------------------------------------------+------------------------------------------------+
+ | .. image:: ../images/museum/real_ratmaze.jpg | .. image:: ../images/museum/ratmaze.png |
+ | :width: 250px | :width: 200px |
+ | :align: center | :align: center |
+ +-------------------------------------------------+------------------------------------------------+
+```
+
+## Running the RL-SNN Model
+
+To fit the RL-SNN model described above, go to the `exhibits/rl_snn`
+sub-folder (this step assumes that you have git cloned the model museum repo
+code), and execute the RL-SNN's simulation script from the command line as follows:
+
+```console
+$ ./sim.sh
+```
+which will execute a simulation of the MSTDP-adapted SNN on the T-maze problem, specifically executing four uniquely-seeded trial runs (i.e., four different "rat agents"), and produce two plots, one containing a smoothed curve of episodic rewards over time and another containing a smoothed task accuracy curve (as in, did the rat reach the goal-state and obtain the food substance or not). You should obtain plots that look roughly like the two below.
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +-----------------------------------------------+-----------------------------------------------+
+ | .. image:: ../images/museum/rat_rewards.jpg | .. image:: ../images/museum/rat_accuracy.jpg |
+ | :width: 400px | :width: 400px |
+ | :align: center | :align: center |
+ +-----------------------------------------------+-----------------------------------------------+
+```
+
+Notice that we have provided a random agent baseline (i.e., uniform random selection of one of the four possible
+directions to move at each step in an episode) to contrast the SNN rat motor circuit's performance with. As you may
+observe, the SNN circuit ultimately becomes conditioned to taking actions akin to the optimal policy -- go North/up
+if it perceives itself (marked as a cross "+") at the bottom of the T-maze, then go East/right once it has reached the top
+of the T-maze, continuing right upon perceiving the food item (goal state marked as an "X").
+
+The code has been configured to also produce a small video/GIF of the final episode `episode200.gif`, where the MSTDP
+weight changes have been disabled and the agent must solely rely on its memory of the uncovered policy to get to the
+goal state.
+
+### Some Important Limitations
+
+While the above MSTDP-ET-driven motor circuit model is useful and provides a simple model of operant conditioning in
+the context of a very simple maze navigation task, it is important to identify the assumptions/limitations of the
+above setup. Some important limitations or simplifications that have been made to obtain a consistently working
+RL-SNN model include:
+1. As mentioned earlier, the sensory input contains a global view of the maze navigation problem, i.e., a 2D birds-eye
+ view of the agent, its goal (the food substance), and its environment. More realistic, but far more difficult
+ versions of this problem would need to consider an ego-centric view (making the problem a partially observable
+ Markov decision process), a more realistic 3D representation of the environment, as well as more complex maze
+ sizes and shapes for the agent/rat model to navigate.
+2. The reward is only delayed with respect to the agent's stimulus processing window, meaning that the agent essentially
+ receives a dopamine signal after an action is taken. If we ignore the SNN's stimulus processing time between video
+ frames of the actual navigation problem, we can view our agent above as tackling what is known in reinforcement
+ learning as the dense reward problem. A far more complex, yet more cognitively realistic, version of the problem
+ is to administer a sparse reward, i.e., the rat motor circuit only receives a useful dopamine/reward stimulus at the
+ end of an episode as opposed to after each action. The above MSTDP-ET model would struggle to solve the sparse
+ reward problem and more sophisticated models would be required in order to achieve successful outcomes, i.e.,
+ appealing to models of memory/cognitive maps, more intelligent forms of exploration, etc.
+3. The SNN circuit itself only permits plastic synapses in its control layer (i.e., the synaptic connections between
+ the second layer and third output/control layer). The bottom layer is non-plastic and fixed, meaning that the
+ agent model is dependent on the quality of the random initialization of the input-to-hidden encoding layer. The
+ input-to-hidden synapses could be adapted with STDP (or MSTDP); however, the agent will not always successfully
+ and stably converge to a consistent policy as the encoding layer's effectiveness is highly dependent on how much
+ of the environment the agent initially sees/explores (if the agent gets "stuck" at any point, STDP will tend to
+ fill up the bottom layer receptive fields with redundant information and make it more difficult for the control
+ layer to learn the consequences of taking different actions).
+
+
+## References
+[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse
+and delayed rewards with a multilayer spiking neural network." 2020 International
+Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit
+recognition using spike-timing-dependent plasticity." Frontiers in computational
+neuroscience 9 (2015): 99.
+
diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md
index 17ef4fee..04426d70 100644
--- a/docs/museum/sindy.md
+++ b/docs/museum/sindy.md
@@ -8,8 +8,7 @@ Flow diagrams lack clear directional indicators
Inconsistent color schemes across visualizations
-->
-
-# Sparse Identification of Non-linear Dynamical Systems (SINDy)[1]
+# Sparse Identification of Non-linear Dynamical Systems (SINDy)
In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will:
@@ -28,19 +27,24 @@ SINDy is a data-driven algorithm that discovers the governing behavior of a dyna
### SINDy Dynamics
-If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount.
-$$$
+If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount:
+
+$$
d\mathbf{X} = \mathbf{Ẋ}(t)~dt
-$$$
+$$
+
SINDy models the derivative[^1] (a linear operation) as linear transformations with:
[^1]: The derivative is a linear operation that acts on $dt$ and gives a differential that is the linearized approximation of the taylor series of the function.
-$$$
+
+$$
\frac{d\mathbf{X}(t)}{dt} = \mathbf{Ẋ}(t) = \mathbf{f}(\mathbf{X}(t))
-$$$
+$$
+
SINDy assumes that this linear operation, i.e., $\mathbf{f}(\mathbf{X}(t))$, is a matrix multiplication that linearly combines the relevant predictors in order to describe the system's equation.
-$$$
+
+$$
\mathbf{f}(\mathbf{X}(t)) = \mathbf{\Theta}(\mathbf{X})~\mathbf{W}
-$$$
+$$
Given a group of candidate functions within the library $\mathbf{\Theta}(\mathbf{X})$, the coefficients in $\mathbf{W}$ that choose the library terms are to be **sparse**. In other words, there are only a few functions that exist in the system's differential equation. Given these assumptions, SINDy solves a sparse regression problem in order to find the $\mathbf{W}$ that maps the library of selected terms to each feature of the system being identified. SINDy imposes parsimony constraints over the resulting symbolic regression (i.e., genetic programming) to describe a dynamical system's behavior with as few terms as possible. In order to select a sparse set of the given features, the model adds the LASSO regularizarion penalty (i.e., an L1 norm constraint) to the regression problem and solves the sparse regression or solves the regression problem via STLSQ. We will describe STLSQ in third step of the SINDy dynamics/process.
@@ -48,206 +52,101 @@ In essence, SINDy's dynamics can be presented in three main phases, visualized i
------------------------------------------------------------------------------------------
-
-
-
+
**Figure 1:** **The flow of the three phases in SINDy.** **Phase-1)** Data collection: capturing system states that are changing in time and creating the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 to compute its derivative with respect to time. **Phase-3)** Solving the sparse regression problem.
-
| - ## Phase 1: Collecting Dataset → $\mathbf{X}_{(m \times n)}$ This phase involves gathering the raw data points representing the system's states across time. In this example, this means capturing the $x$, $y$, and $z$ coordinates of the system's states. Here, $m$ represents the number of data points (number of the snapshots/length of time) and $n$ is the system's dimensionality. - | -
-
- |
-
+
+
+
+
-| - ## Phase 2: Processing - | -
-
- |
-|||||
+
+
### 2.A: Making the Library → $\mathbf{\Theta}_{(m \times p)}$
In this step, using the dataset collected in phase 1, given pre-defined function terms, we construct a dictionary of candidate predictors for identifying the target system's differential equations. These functions form the columns of our library matrix $\mathbf{\Theta}(\mathbf{X})$ and $p$ is the number of candidate predictors. To identify the dynamical structure of the system, this library of candidate functions appears in the regression problem to propose the model's structure that will later serve as the coefficient matrix for weighting the functions according to the problem setup. We assume sparse models will be sufficient to identify the system and do this through sparsification (LASSO or thresholding weights) in order decide which structure best describes the system's behavior using predictors.
Given a set of time-series measurements of a dynamical system state variables ($\mathbf{X}_{(m \times n)}$) we construct the following:
-Library of Candidate Functions: $\Theta(\mathbf{X}) = [\mathbf{1} \quad \mathbf{X} \quad \mathbf{X}^2 \quad \mathbf{X}^3 \quad \sin(\mathbf{X}) \quad \cos(\mathbf{X}) \quad ...]$
- |
-
-
- |
-|||||
-
+
+
### 2.B: Compute State Derivatives → $\mathbf{Ẋ}_{(m \times n)}$
Given a set of time-series measurements of a dynamical system's state variables $\mathbf{X}_{(m \times n)}$, we next construct the derivative matrix: $\mathbf{Ẋ}_{(m \times n)}$ (computed numerically). In this step, using the dataset collected in phase 1, we compute the derivatives of each state variable with respect to time. In this example, we compute $ẋ$, $ẏ$, and $ż$ in order to capture how the system evolves over time.
- |
-
-
- |
-|||||
| - ## Phase 3: Solving Sparse Regression Problem → $\mathbf{W_s}_{(p \times n)}$ Solving the resulting sparse regression (SR) problem that results from the phases/steps above can be done using various method such as Lasso, STLSQ, Elastic Net, as well as many other schemes. Here, we describe the STLSQ approach to solve the SR problem according to the SINDy process. - | - -
-
- |
-
-
-|||||||||||||||
-
+
+
### Solving Sparse Regression by Sequential Thresholding Least Squares (STLSQ)
-
- |
-||
| - ### Sequential Thresholding Least Square (STLSQ) - | -||
|
-
- |
-||
| #### 3.A: Least Square method (LSQ) → $\mathbf{W}$ This step entails finding library coefficients by solving the following regression problem $\mathbf{Ẋ} = \mathbf{\Theta}\mathbf{W}$ analytically $\mathbf{W} = (\mathbf{\Theta}^T \mathbf{\Theta})^{-1} \mathbf{\Theta}^T \mathbf{Ẋ}$ - | -
-
- |
-|
-
+
+
#### 3.B: Thresholding → $\mathbf{W_s}$
This step entails sparsifying $\mathbf{W}$ by keeping only some of the terms within $\mathbf{W}$, particularly those that correspond to the effective terms in the library.
- |
-
-
- |
-|
-
+
+
+
#### 3.C: Masking → $\mathbf{\Theta_s}$
This step sparsifies $\mathbf{\Theta}$ by keeping only the corresponding terms in $\mathbf{W}$ that remain (from the prior step).
- |
-
-
- |
-|
+
#### 3.D: Repeat A → B → C until convergence
We continue to solve LSQ with the sparse matrix $\mathbf{\Theta_s}$ and $\mathbf{W_s}$ and find a new $\mathbf{W}$, repeating steps B and C until convergence.
- |
-
-
- |
-|
-
-
-
-
## Code: Simulating SINDy
We finally present ngc-learn code below for using and simulating the SINDy process to identify several dynamical systems.
-
-
```python
-
-
-
import numpy as np
import jax.numpy as jnp
from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary
@@ -335,31 +224,14 @@ for dim in range(dX.shape[1]):
coeff = jnp.where(jnp.abs(coef) >= threshold, coef, 0.)
print(f"coefficients for dimension {dim+1}: \n", coeff.T)
-
-
-
```
-
-
-
-
## Results: System Identification
Running the above code should produce results similar to the findings we present next.
-| - Model - | -- Results - | - -
|---|---|
| - ## Oscillator +## Oscillator True model's equation \ $\mathbf{ẋ} = \mu_1\mathbf{x} + \sigma \mathbf{xy}$ \ @@ -380,19 +252,13 @@ $\mathbf{ż} = \mu_2\mathbf{z} - (\omega + \alpha \mathbf{y} + \beta \mathbf{z}) [ 0. -0.009 0. -2.99 4.99 1.99 0. 0. 0. 0.]] ``` - | -
-
- |
-
+which should produce the following results:
+
+
+
+
- ## Lorenz
+## Lorenz
True model's equation \
$\mathbf{ẋ} = 10(\mathbf{y} - \mathbf{x})$ \
@@ -400,7 +266,6 @@ $\mathbf{ẏ} = \mathbf{x}(28 - \mathbf{z}) - \mathbf{y}$ \
$\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$
-
```python
--- SINDy results ----
ẋ = 9.969 𝑦 -9.966 𝑥
@@ -413,19 +278,12 @@ $\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$
[-2.656 0. 0. 0. 0. 0. 0. 0.996 0.]]
```
- |
-
-
- |
-
-
- ## Linear-2D
+which should produce the following results:
+
+
+
+
+## Linear-2D
True model's equation \
$\mathbf{ẋ} = -0.1\mathbf{x} + 2.0\mathbf{y}$ \
@@ -440,21 +298,14 @@ $\mathbf{ẏ} = -2.0\mathbf{x} - 0.1\mathbf{y}$
[[ 1.999 0. -0.100 0. 0.]
[-0.099 0. -1.999 0. 0.]]
```
+
+which should produce the following results:
+
+
- |
-
-
- |
-
| - ## Linear-3D +## Linear-3D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x} + 2\mathbf{y}$ \ @@ -473,22 +324,13 @@ $\mathbf{ż} = -0.3\mathbf{z}$ [-0.299 0. 0. 0. 0. 0. 0. 0. 0.]] ``` - | -
-
- |
-
+
+
+
- ## Cubic-2D
+## Cubic-2D
True model's equation \
$\mathbf{ẋ} = -0.1\mathbf{x}^3 + 2.0\mathbf{y}^3$ \
@@ -504,16 +346,10 @@ $\mathbf{ẏ} = -2.0\mathbf{x}^3 - 0.1\mathbf{y}^3$
[ 0. 0. -0.099 0. 0. 0. 0. 0. -1.99]]
```
- |
-
-
- |
-
+
## References
diff --git a/docs/source/ngclearn.components.neurons.graded.rst b/docs/source/ngclearn.components.neurons.graded.rst
index ec39a795..d62a5b7e 100644
--- a/docs/source/ngclearn.components.neurons.graded.rst
+++ b/docs/source/ngclearn.components.neurons.graded.rst
@@ -36,14 +36,6 @@ ngclearn.components.neurons.graded.rateCell module
:undoc-members:
:show-inheritance:
-ngclearn.components.neurons.graded.rateCellOld module
------------------------------------------------------
-
-.. automodule:: ngclearn.components.neurons.graded.rateCellOld
- :members:
- :undoc-members:
- :show-inheritance:
-
ngclearn.components.neurons.graded.rewardErrorCell module
---------------------------------------------------------
diff --git a/docs/source/ngclearn.components.synapses.hebbian.rst b/docs/source/ngclearn.components.synapses.hebbian.rst
index d892d868..778e59ec 100644
--- a/docs/source/ngclearn.components.synapses.hebbian.rst
+++ b/docs/source/ngclearn.components.synapses.hebbian.rst
@@ -36,14 +36,6 @@ ngclearn.components.synapses.hebbian.hebbianSynapse module
:undoc-members:
:show-inheritance:
-ngclearn.components.synapses.hebbian.hebbianSynapseOld module
--------------------------------------------------------------
-
-.. automodule:: ngclearn.components.synapses.hebbian.hebbianSynapseOld
- :members:
- :undoc-members:
- :show-inheritance:
-
ngclearn.components.synapses.hebbian.traceSTDPSynapse module
------------------------------------------------------------
diff --git a/docs/source/ngclearn.components.synapses.rst b/docs/source/ngclearn.components.synapses.rst
index c1ccc6c4..43791098 100644
--- a/docs/source/ngclearn.components.synapses.rst
+++ b/docs/source/ngclearn.components.synapses.rst
@@ -23,6 +23,14 @@ ngclearn.components.synapses.STPDenseSynapse module
:undoc-members:
:show-inheritance:
+ngclearn.components.synapses.alphaSynapse module
+------------------------------------------------
+
+.. automodule:: ngclearn.components.synapses.alphaSynapse
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.components.synapses.denseSynapse module
------------------------------------------------
@@ -31,6 +39,22 @@ ngclearn.components.synapses.denseSynapse module
:undoc-members:
:show-inheritance:
+ngclearn.components.synapses.doubleExpSynapse module
+----------------------------------------------------
+
+.. automodule:: ngclearn.components.synapses.doubleExpSynapse
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+ngclearn.components.synapses.exponentialSynapse module
+------------------------------------------------------
+
+.. automodule:: ngclearn.components.synapses.exponentialSynapse
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.components.synapses.staticSynapse module
-------------------------------------------------
diff --git a/docs/tutorials/neurocog/dynamic_synapses.md b/docs/tutorials/neurocog/dynamic_synapses.md
new file mode 100644
index 00000000..bc708264
--- /dev/null
+++ b/docs/tutorials/neurocog/dynamic_synapses.md
@@ -0,0 +1,421 @@
+# Lecture 4A: Dynamic Synapses and Conductance
+
+In this lesson, we will study dynamic synapses, or synaptic cable components in
+ngc-learn that evolve on fast time-scales in response to their pre-synaptic inputs.
+These types of chemical synapse components are useful for modeling time-varying
+conductance which ultimately drives electrical current input into neuronal units
+(such as spiking cells). Here, we will learn how to build three important types of dynamic synapses in
+ngc-learn -- the exponential, the alpha, and the double-exponential synapse -- and visualize
+the time-course of their resulting conductances. In addition, we will then
+construct and study a small neuronal circuit involving a leaky integrator that
+is driven by exponential synapses relaying pulses from an excitatory and an
+inhibitory population of Poisson input encoding cells.
+
+## Synaptic Conductance Modeling
+
+Synapse models are typically used to model the post-synaptic response produced by
+action potentials (or pulses) at a pre-synaptic terminal. Assuming an electrical
+response (as opposed to a chemical one, e.g., an influx of calcium), such models seek
+to emulate the time-course of what is known as post-synaptic receptor conductance. Note
+that these dynamic synapse models will end up being a bit more sophisticated than the strength
+value matrices we might initially employ (as in synapse components such as the
+[DenseSynapse](ngclearn.components.synapses.denseSynapse)).
+
+Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse),
+the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place
+the code you will write within.
+For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior.
+This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics):
+
+```python
+from jax import numpy as jnp, random, jit
+from ngcsimlib.context import Context
+from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoupleExpSynapse
+
+from ngcsimlib.compilers.process import Process
+from ngcsimlib.context import Context
+import ngclearn.utils.weight_distribution as dist
+
+
+dkey = random.PRNGKey(1234) ## creating seeding keys for synapses
+dkey, *subkeys = random.split(dkey, 6)
+dt = 0.1 # ms ## integration time step
+T = 8. # ms ## total simulation duration
+
+## ---- build a three-synapse comparison system ----
+with Context("dual_syn_system") as ctx:
+ Wexp = ExponentialSynapse( ## exponential dynamic synapse
+ name="Wexp", shape=(1, 1), tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
+ weight_init=dist.constant(value=1.), key=subkeys[0]
+ )
+ Walpha = AlphaSynapse( ## alpha dynamic synapse
+ name="Walpha", shape=(1, 1), tau_decay=1., g_syn_bar=1., syn_rest=0., resist_scale=1.,
+ weight_init=dist.constant(value=1.), key=subkeys[0]
+ )
+ Wexp2 = DoupleExpSynapse(
+ name="Wexp2", shape=(1, 1), tau_rise=1., tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
+ weight_init=dist.constant(value=1.), key=subkeys[0]
+ )
+
+ ## set up basic simulation process calls
+ advance_process = (Process("advance_proc")
+ >> Wexp.advance_state
+ >> Walpha.advance_state
+ >> Wexp2.advance_state)
+ ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+
+ reset_process = (Process("reset_proc")
+ >> Wexp.reset
+ >> Walpha.reset
+ >> Wexp2.reset)
+ ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+```
+
+where we notice in the above we have instantiated three different kinds of chemical synapse components
+that we will run side-by-side in order to extract their produced conductance values in response to
+the exact same input stream. For both the exponential and the alpha synapse, there are at least three important hyper-parameters to configure:
+1. `tau_decay` ($\tau_{\text{decay}}$): the synaptic conductance decay time constant (for the double-exponential synapse, we also have `tau_rise`);
+2. `g_syn_bar` ($\bar{g}_{\text{syn}}$): the maximal conductance value produced by each pulse transmitted
+ across this synapse; and,
+3. `syn_rest` ($E_{rest}$): the (post-synaptic) reversal potential for this synapse -- note that this value
+ determines the direction of current flow through the synapse, yielding a synapse with an
+ excitatory nature for non-negative values of `syn_rest` or a synapse with an inhibitory
+ nature for negative values of `syn_rest`.
+
+The flow of electrical current from a pre-synaptic neuron to a post-synaptic one is often modeled under the assumption that pre-synaptic pulses result in impermanent (transient; lasting for a short period of time) changes in the conductance of a post-synaptic neuron. As a result, the resulting conductance dynamics $g_{\text{syn}}(t)$ -- or the effect (conductance changes in the post-synaptic membrane) of a transmitter binding to and opening post-synaptic receptors -- of each of the synapses that you have built above can be simulated in ngc-learn according to one or more ordinary differential equations (ODEs), which themselves iteratively model different waveform equations of the time-course of synaptic conductance.
+For the exponential synapse, the dynamics adhere to the following ODE:
+
+$$
+\frac{\partial g_{\text{syn}}(t)}{\partial t} = -g_{\text{syn}}(t)/\tau_{\text{syn}} + \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k})
+$$
+
+where the conductance (for a post-synaptic unit) output of this synapse is driven by a sum over all of its incoming
+pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an exponential kernel (i.e., a low-pass filter).
+On the other hand, for the alpha synapse, the dynamics adhere to the following coupled set of ODEs:
+
+$$
+\frac{\partial h_{\text{syn}}(t)}{\partial t} &= -h_{\text{syn}}(t)/\tau_{\text{syn}} + \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k}) \\
+\frac{\partial g_{\text{syn}}(t)}{\partial t} &= -g_{\text{syn}}(t)/\tau_{\text{syn}} + h_{\text{syn}}(t)/\tau_{\text{syn}}
+$$
+
+where $h_{\text{syn}}(t)$ is an intermediate variable that operates in service of driving the conductance variable $g_{\text{syn}}(t)$ itself.
+The double-exponential (or difference of exponentials) synapse model looks similar to the alpha synapse except that the
+rise and fall/decay of its conductance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$,
+as follows:
+
+$$
+\frac{\partial h_{\text{syn}}(t)}{\partial t} &= -h_{\text{syn}}(t)/\tau_{\text{rise}} + \big(\frac{1}{\tau_{\text{rise}}} - \frac{1}{\tau_{\text{decay}}} \big) \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k}) \\
+\frac{\partial g_{\text{syn}}(t)}{\partial t} &= -g_{\text{syn}}(t)/\tau_{\text{decay}} + h_{\text{syn}}(t) .
+$$
+
+Finally, we seek to model the electrical current that results from the amount of neurotransmitter released at previous time steps.
+Thus, for any of these three synapses, the changes in conductance are finally converted (via Ohm's law) to electrical current to produce the final derived variable $j_{\text{syn}}(t)$:
+
+$$
+j_{\text{syn}}(t) = g_{\text{syn}}(t) (v(t) - E_{\text{rev}})
+$$
+
+where $E_{\text{rev}}$ is the post-synaptic reversal potential of the ion channels that mediate the synaptic current; this is typically set to $E_{\text{rev}} = 0$ (millivolts; mV) for the case of excitatory changes and $E_{\text{rev}} = -75$ (mV) for the case of inhibitory changes. $v(t)$ is the voltage/membrane potential of the post-synaptic cell that the synaptic cable wires into, meaning that the conductance models above are voltage-dependent (in ngc-learn, if you want voltage-independent conductance, then set `syn_rest = None`).
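+
+To make the connection between the ODEs above and an actual simulation loop concrete, below is a minimal, self-contained sketch of Euler-integrating a single exponential synapse by hand and converting its conductance to a current via Ohm's law (plain Python, outside of ngc-learn's component system; all constant values are illustrative assumptions rather than the library's defaults):
+
+```python
+dt, tau_syn, g_bar, E_rev = 0.1, 3., 1., 0.  ## step (ms), decay constant (ms), max conductance, reversal potential (mV)
+v = -65.                                     ## a fixed post-synaptic membrane potential (mV), for illustration
+g_syn = 0.                                   ## conductance state
+for step in range(80):                       ## simulate 8 ms
+    spike = 1. if step == 10 else 0.         ## single pre-synaptic pulse arriving at t = 1 ms
+    g_syn = g_syn + dt * (-g_syn / tau_syn)  ## Euler step of the exponential decay term
+    g_syn = g_syn + g_bar * spike            ## instantaneous jump of g_bar per pulse (the delta-summation term)
+    j_syn = g_syn * (v - E_rev)              ## Ohm's-law conversion of conductance to synaptic current
+```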
+
+
+### Examining the Conductances of Dynamic Synapses
+
+We can track and visualize the conductance outputs of our different dynamic synapses by running a stream of controlled pre-synaptic pulses. Specifically, we will observe the output behavior of each in response to a sparse stream, eight milliseconds in length, where only a single spike is emitted at one millisecond.
+To create the simulation of a single input pulse stream, you can write the following code:
+
+```python
+time_span = []
+g = []
+ga = []
+gexp2 = []
+ctx.reset()
+Tsteps = int(T/dt) + 1
+for t in range(Tsteps):
+ s_t = jnp.zeros((1, 1))
+ if t * dt == 1.: ## pulse at 1 ms
+ s_t = jnp.ones((1, 1))
+ Wexp.inputs.set(s_t)
+ Walpha.inputs.set(s_t)
+ Wexp.v.set(Wexp.v.value * 0)
+ Wexp2.inputs.set(s_t)
+ Walpha.v.set(Walpha.v.value * 0)
+ Wexp2.v.set(Wexp2.v.value * 0)
+ ctx.run(t=t * dt, dt=dt)
+
+ print(f"\r g = {Wexp.g_syn.value} ga = {Walpha.g_syn.value} gexp2 = {Wexp2.g_syn.value}", end="")
+    g.append(Wexp.g_syn.value)
+    ga.append(Walpha.g_syn.value)
+    gexp2.append(Wexp2.g_syn.value) ## track the double-exponential conductance as well
+    time_span.append(t) #* dt)
+print()
+g = jnp.squeeze(jnp.concatenate(g, axis=1))
+g = g/jnp.amax(g)
+ga = jnp.squeeze(jnp.concatenate(ga, axis=1))
+ga = ga/jnp.amax(ga)
+gexp2 = jnp.squeeze(jnp.concatenate(gexp2, axis=1))
+gexp2 = gexp2/jnp.amax(gexp2)
+```
+
+Note that we further normalize the conductance trajectories of all synapses to lie within the range of $[0, 1]$,
+primarily for visualization purposes.
+Finally, to visualize the conductance time-course of the synapses, you can write the following:
+
+```python
+import matplotlib #.pyplot as plt
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+cmap = plt.cm.jet
+
+time_ticks = []
+time_labs = []
+for t in range(Tsteps):
+ if t % 10 == 0:
+ time_ticks.append(t)
+ time_labs.append(f"{t * dt:.1f}")
+
+## ---- plot the exponential synapse conductance time-course ----
+fig, ax = plt.subplots()
+
+gvals = ax.plot(time_span, g, '-', color='tab:red')
+#plt.xticks(time_span, time_labs)
+ax.set_xticks(time_ticks, time_labs)
+ax.set(xlabel='Time (ms)', ylabel='Conductance',
+ title='Exponential Synapse Conductance Time-Course')
+ax.grid(which="major")
+fig.savefig("exp_syn.jpg")
+plt.close()
+
+## ---- plot the alpha synapse conductance time-course ----
+fig, ax = plt.subplots()
+
+gvals = ax.plot(time_span, ga, '-', color='tab:blue')
+#plt.xticks(time_span, time_labs)
+ax.set_xticks(time_ticks, time_labs)
+ax.set(xlabel='Time (ms)', ylabel='Conductance',
+ title='Alpha Synapse Conductance Time-Course')
+ax.grid(which="major")
+fig.savefig("alpha_syn.jpg")
+plt.close()
+
+## ---- plot the double-exponential synapse conductance time-course ----
+fig, ax = plt.subplots()
+
+gvals = ax.plot(time_span, gexp2, '-', color='tab:blue')
+#plt.xticks(time_span, time_labs)
+ax.set_xticks(time_ticks, time_labs)
+#plt.vlines(x=[0, 10, 20, 30, 40, 50, 60, 70, 80], ymin=-0.2, ymax=1.2, colors='gray', linestyles='dashed') #, label='Vertical Lines')
+ax.set(xlabel='Time (ms)', ylabel='Conductance',
+ title='Double-Exponential Synapse Conductance Time-Course')
+ax.grid(which="major")
+fig.savefig("exp2_syn.jpg")
+plt.close()
+```
+
+which should produce and save three plots to disk. You can then compare and contrast the plots of the
+exponential, alpha, and double-exponential synapse conductance trajectories:
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +--------------------------------------------------------+----------------------------------------------------------+---------------------------------------------------------+
+ | .. image:: ../../images/tutorials/neurocog/expsyn.jpg | .. image:: ../../images/tutorials/neurocog/alphasyn.jpg | .. image:: ../../images/tutorials/neurocog/exp2syn.jpg |
+ | :width: 400px | :width: 400px | :width: 400px |
+ | :align: center | :align: center | :align: center |
+ +--------------------------------------------------------+----------------------------------------------------------+---------------------------------------------------------+
+```
+
+Note that the alpha synapse (middle figure) would produce a more realistic fit to recorded synaptic currents (as it attempts to model
+the rise and fall of current in a less simplified manner) at the cost of extra compute, given that it uses two ODEs to
+emulate conductance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure).
+
+## Excitatory-Inhibitory Driven Dynamics
+
+For this next part of the lesson, create a new Python script named `sim_ei_dynamics.py` for the next portions of code
+you will write.
+Let's next examine a more interesting use-case of the above dynamic synapses -- modeling excitatory (E) and inhibitory (I)
+pressures produced by different groups of pre-synaptic spike trains. This allows us to examine a very common
+and often-used conductance model that is paired with spiking cells such as the leaky integrate-and-fire (LIF). Specifically,
+we seek to simulate the following post-synaptic conductance dynamics for a single LIF unit:
+
+$$
+\tau_{m} \frac{\partial v(t)}{\partial t} = -\big( v(t) - E_{L} \big) - \frac{g_{E}(t)}{g_{L}} \big( v(t) - E_{E} \big) - \frac{g_{I}(t)}{g_{L}} \big( v(t) - E_{I} \big)
+$$
+
+where $g_{L}$ is the leak conductance value for the post-synaptic LIF, $g_{E}(t)$ is the post-synaptic conductance produced by excitatory pre-synaptic spike trains (with excitatory synaptic reverse potential $E_{E}$), and $g_{I}(t)$ is the post-synaptic conductance produced by inhibitory pre-synaptic spike trains (with inhibitory synaptic reverse potential $E_{I}$). Note that the first term of the above ODE is the normal leak portion of the LIF's standard dynamics (scaled by conductance factor $g_{L}$) and the last two terms of the above ODE can be modeled each separately with a dynamic synapse. To differentiate between excitatory and inhibitory conductance changes, we will just configure a different reverse potential for each to induce either excitatory (i.e., $E_{\text{syn}} = E_{E} = 0$ mV) or inhibitory (i.e., $E_{\text{syn}} = E_{I} = -80$ mV) pressure/drive.
+
+We will specifically model the excitatory and inhibitory conductance changes using exponential synapses and the input spike trains for each with Poisson encoding cells; in other words, two different groups of Poisson cells will be wired to a single LIF cell via exponential dynamic synapses. The code for doing this is as follows:
+
+```python
+from jax import numpy as jnp, random, jit
+from ngcsimlib.context import Context
+from ngclearn.components import ExponentialSynapse, PoissonCell, LIFCell
+from ngclearn.operations import summation
+
+from ngcsimlib.compilers.process import Process
+from ngcsimlib.context import Context
+import ngclearn.utils.weight_distribution as dist
+
+## create seeding keys
+dkey = random.PRNGKey(1234)
+dkey, *subkeys = random.split(dkey, 6)
+
+## simulation properties
+dt = 0.1 # ms
+T = 1000. # ms ## total duration time
+
+## post-syn LIF cell properties
+tau_m = 10.
+g_L = 10.
+v_rest = -75.
+v_thr = -55.
+
+## excitatory group properties
+exc_freq = 10. # Hz
+n_exc = 80
+g_e_bar = 2.4
+tau_syn_exc = 2.
+E_rest_exc = 0.
+
+## inhibitory group properties
+inh_freq = 10. # Hz
+n_inh = 20
+g_i_bar = 2.4
+tau_syn_inh = 5.
+E_rest_inh = -80.
+
+Tsteps = int(T/dt)
+
+## ---- build a simple E-I spiking circuit ----
+with Context("ei_snn") as ctx:
+ pre_exc = PoissonCell("pre_exc", n_units=n_exc, target_freq=exc_freq, key=subkeys[0]) ## pre-syn excitatory group
+ pre_inh = PoissonCell("pre_inh", n_units=n_inh, target_freq=inh_freq, key=subkeys[1]) ## pre-syn inhibitory group
+ Wexc = ExponentialSynapse( ## dynamic synapse between excitatory group and LIF
+ name="Wexc", shape=(n_exc,1), tau_decay=tau_syn_exc, g_syn_bar=g_e_bar, syn_rest=E_rest_exc, resist_scale=1./g_L,
+ weight_init=dist.constant(value=1.), key=subkeys[2]
+ )
+ Winh = ExponentialSynapse( ## dynamic synapse between inhibitory group and LIF
+ name="Winh", shape=(n_inh, 1), tau_decay=tau_syn_inh, g_syn_bar=g_i_bar, syn_rest=E_rest_inh, resist_scale=1./g_L,
+ weight_init=dist.constant(value=1.), key=subkeys[2]
+ )
+ post_exc = LIFCell( ## post-syn LIF cell
+ "post_exc", n_units=1, tau_m=tau_m, resist_m=1., thr=v_thr, v_rest=v_rest, conduct_leak=1., v_reset=-75.,
+ tau_theta=0., theta_plus=0., refract_time=2., key=subkeys[3]
+ )
+
+ Wexc.inputs << pre_exc.outputs
+ Winh.inputs << pre_inh.outputs
+ Wexc.v << post_exc.v ## couple voltage to exc synapse
+ Winh.v << post_exc.v ## couple voltage to inh synapse
+ post_exc.j << summation(Wexc.i_syn, Winh.i_syn) ## sum together excitatory & inhibitory pressures
+
+ advance_process = (Process("advance_proc")
+ >> pre_exc.advance_state
+ >> pre_inh.advance_state
+ >> Wexc.advance_state
+ >> Winh.advance_state
+ >> post_exc.advance_state)
+ # ctx.wrap_and_add_command(advance_process.pure, name="run")
+ ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+
+ reset_process = (Process("reset_proc")
+ >> pre_exc.reset
+ >> pre_inh.reset
+ >> Wexc.reset
+ >> Winh.reset
+ >> post_exc.reset)
+ ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+```
+
+### Examining the Simple Spiking Circuit's Behavior
+
+To run the above spiking circuit, we then write the next block of code (making sure to track/store the resulting membrane potential and pulse values emitted by the post-synaptic LIF):
+
+```python
+volts = []
+time_span = []
+spikes = []
+
+ctx.reset()
+pre_exc.inputs.set(jnp.ones((1, n_exc)))
+pre_inh.inputs.set(jnp.ones((1, n_inh)))
+post_exc.v.set(post_exc.v.value * 0 - 65.) ## initial condition for LIF is -65 mV
+volts.append(post_exc.v.value)
+time_span.append(0.)
+Tsteps = int(T/dt) + 1
+for t in range(1, Tsteps):
+ ctx.run(t=t * dt, dt=dt)
+ print(f"\r v {post_exc.v.value}", end="")
+ volts.append(post_exc.v.value)
+ spikes.append(post_exc.s.value)
+ time_span.append(t) #* dt)
+print()
+volts = jnp.squeeze(jnp.concatenate(volts, axis=1))
+spikes = jnp.squeeze(jnp.concatenate(spikes, axis=1))
+```
+
+from which we may then write the following plotting code to visualize the post-synaptic LIF unit's membrane potential time-course along with any spikes it might have produced in response to the pre-synaptic spike trains:
+
+```python
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+cmap = plt.cm.jet
+
+time_ticks = []
+time_labs = []
+time_ticks.append(0)
+time_labs.append(f"{0.}")
+tdiv = 1000
+for t in range(Tsteps):
+ if t % tdiv == 0:
+ time_ticks.append(t)
+ time_labs.append(f"{t * dt:.0f}")
+
+fig, ax = plt.subplots()
+
+volt_vals = ax.plot(time_span, volts, '-.', color='tab:red')
+stat = jnp.where(spikes > 0.)
+indx = (stat[0] * 1. - 1.).tolist()
+v_thr_below = -0.75
+v_thr_above = 2.
+spk = ax.vlines(x=indx, ymin=v_thr-v_thr_below, ymax=v_thr+v_thr_above, colors='black', ls='-', lw=2)
+_v_thr = v_thr
+ax.hlines(y=_v_thr, xmin=0., xmax=time_span[-1], colors='blue', ls='-.', lw=2)
+
+ax.set(xlabel='Time (ms)', ylabel='Voltage',
+ title='Exponential Synapse LIF')
+ax.grid()
+fig.savefig("ei_circuit_dynamics.jpg")
+```
+
+which should produce a figure depicting dynamics similar to the one below. Black tick
+marks indicate post-synaptic pulses whereas the horizontal dashed blue line shows the LIF unit's
+voltage threshold.
+
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +--------------------------------------------------------------------+
+ | .. image:: ../../images/tutorials/neurocog/ei_circuit_dynamics.jpg |
+ | :width: 400px |
+ | :align: center |
+ +--------------------------------------------------------------------+
+```
+
+Notice that the above shows the behavior of the post-synaptic LIF in response to the integration of pulses coming from two Poisson spike trains both at rates of $10$ Hz (since both `exc_freq` and `inh_freq` have been set to ten). Messing with the frequencies of the excitatory and inhibitory pulse trains can lead to sparser or denser post-synaptic spike outputs. For instance, if we increase the frequency of the excitatory train to $15$ Hz (keeping the inhibitory one at $10$ Hz), we get a denser post-synaptic output pulse pattern as in the left figure below. In contrast, if we instead increase the inhibitory frequency to $30$ Hz (keeping the excitatory at $10$ Hz), we obtain a sparser post-synaptic output pulse train as in the right figure below.
+
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +-----------------------------------------------------------------------+-----------------------------------------------------------------------+
+ | .. image:: ../../images/tutorials/neurocog/ei_circuit_dense_exc.jpg | .. image:: ../../images/tutorials/neurocog/ei_circuit_sparse_inh.jpg |
+ | :width: 400px | :width: 400px |
+ | :align: center | :align: center |
+ +-----------------------------------------------------------------------+-----------------------------------------------------------------------+
+```
+
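+If you would like to reproduce those two cases yourself, only the rate constants near the top of `sim_ei_dynamics.py` need to change before re-running the script; a quick sketch of the two settings described above (all other hyper-parameters left untouched):
+
+```python
+## denser post-synaptic output (left figure above)
+exc_freq = 15. # Hz
+inh_freq = 10. # Hz
+
+## sparser post-synaptic output (right figure above)
+# exc_freq = 10. # Hz
+# inh_freq = 30. # Hz
+```
+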
+## References
+
+[1] Sterratt, David, et al. Principles of computational modelling in neuroscience. Cambridge university
+press, 2023.
+
Error cells are a particularly useful component cell that offers a simple and
fast way of computing mismatch signals, i.e., error values that compare a
target value against a predicted value. In predictive coding literature, mismatch
@@ -13,8 +17,11 @@ key examples of where error neurons come into play). In this lesson, we will
briefly review one of the most commonly used ones -- the
[Gaussian error cell](ngclearn.components.neurons.graded.gaussianErrorCell).
+
## Calculating Mismatch Values with the Gaussian Error Cell
+
+
The Gaussian error cell, much like most error neurons, is in fact a derived
calculation when considering a cost function. Specifically, this error cell
component inherits its name from the fact that it is producing output values
diff --git a/docs/tutorials/neurocog/hebbian.md b/docs/tutorials/neurocog/hebbian.md
index e2668fc6..8e67754c 100644
--- a/docs/tutorials/neurocog/hebbian.md
+++ b/docs/tutorials/neurocog/hebbian.md
@@ -1,4 +1,4 @@
-# Lecture 4A: Hebbian Synaptic Plasticity
+# Lecture 4B: Hebbian Synaptic Plasticity
In ngc-learn, synaptic plasticity is a key concept at the forefront of its
design in order to promote research into novel ideas and framings of how
diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
index 1580ab1c..44e3b0a7 100755
--- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md
+++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
@@ -83,9 +83,9 @@ Formally, the core dynamics of the H-H cell can be written out as follows:
$$
\tau_v \frac{\partial \mathbf{v}_t}{\partial t} &= \mathbf{j}_t - g_Na * \mathbf{m}^3_t * \mathbf{h}_t * (\mathbf{v}_t - v_Na) - g_K * \mathbf{n}^4_t * (\mathbf{v}_t - v_K) - g_L * (\mathbf{v}_t - v_L) \\
-\frac{\partial \mathbf{n}_t}{\partial t} &= alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - beta_n(\mathbf{v}_t) * \mathbf{n}_t \\
-\frac{\partial \mathbf{m}_t}{\partial t} &= alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - beta_m(\mathbf{v}_t) * \mathbf{m}_t \\
-\frac{\partial \mathbf{h}_t}{\partial t} &= alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - beta_h(\mathbf{v}_t) * \mathbf{h}_t
+\frac{\partial \mathbf{n}_t}{\partial t} &= \alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - \beta_n(\mathbf{v}_t) * \mathbf{n}_t \\
+\frac{\partial \mathbf{m}_t}{\partial t} &= \alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - \beta_m(\mathbf{v}_t) * \mathbf{m}_t \\
+\frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t
$$
where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biopphysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to.
diff --git a/docs/tutorials/neurocog/index.rst b/docs/tutorials/neurocog/index.rst
index 1fb23384..326591c2 100644
--- a/docs/tutorials/neurocog/index.rst
+++ b/docs/tutorials/neurocog/index.rst
@@ -58,8 +58,9 @@ work towards more advanced concepts.
.. toctree::
:maxdepth: 1
- :caption: Forms of Plasticity
+ :caption: Synapses and Forms of Plasticity
+ dynamic_synapses
hebbian
stdp
mod_stdp
diff --git a/docs/tutorials/neurocog/mod_stdp.md b/docs/tutorials/neurocog/mod_stdp.md
index fbb0b7db..3a76de37 100755
--- a/docs/tutorials/neurocog/mod_stdp.md
+++ b/docs/tutorials/neurocog/mod_stdp.md
@@ -1,4 +1,4 @@
-# Lecture 4C: Reward-Modulated Spike-Timing-Dependent Plasticity
+# Lecture 4D: Reward-Modulated Spike-Timing-Dependent Plasticity
In this lesson, we will build on the notions of spike-timing-dependent
plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct
@@ -247,9 +247,8 @@ for i in range(T_max):
```
which will run all three models simultaneously for `200` simulated milliseconds
-and collect statistics of interest. We may then finally make several
-plots of what happens under each STDP mode. First, we will plot the resulting
-synaptic magnitude over time, like so:
+and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode
+(reproducing some key results in [1]). First, we will plot the resulting synaptic magnitude over time, like so:
```python
import matplotlib.pyplot as plt
@@ -374,3 +373,8 @@ modulated STDP updates will only occur when the signal is non-zero; this is
the advantage that MSTDP-ET offers over MSTDP as the synaptic change
dynamics persist (yet decay) in between reward presentation times and thus
MSTDP-ET will be more effective in cases when the reward signal is delayed.
+
+## References
+
+[1] Florian, Răzvan V. "Reinforcement learning through modulation of spike-timing-dependent synaptic plasticity."
+Neural Computation 19.6 (2007): 1468-1502.
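
To make the eligibility-trace behavior described above concrete, here is a hedged sketch of an MSTDP-ET-style update in the spirit of [1]; the function and argument names are hypothetical and do not correspond to ngc-learn's API.

```python
from jax import numpy as jnp

def mstdp_et_step(W, elig, raw_stdp_dw, reward, dt=1., tau_e=25., lr=1e-3):
    ## the eligibility trace accumulates raw STDP changes while decaying over time...
    elig = elig * jnp.exp(-dt / tau_e) + raw_stdp_dw
    ## ...but the synaptic weights only change when the (possibly delayed) reward signal is non-zero
    W = W + lr * reward * elig
    return W, elig
```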
diff --git a/docs/tutorials/neurocog/short_term_plasticity.md b/docs/tutorials/neurocog/short_term_plasticity.md
index 9bd74ed5..b225f3c5 100755
--- a/docs/tutorials/neurocog/short_term_plasticity.md
+++ b/docs/tutorials/neurocog/short_term_plasticity.md
@@ -1,4 +1,4 @@
-# Lecture 4D: Short-Term Plasticity
+# Lecture 4E: Short-Term Plasticity
In this lesson, we will study how short-term plasticity (STP) [1] dynamics
-- where synaptic efficacy is cast in terms of the history of presynaptic activity --
@@ -19,7 +19,7 @@ synapse that evolves according to STP. We will first write our
simulation of this dynamic synapse from the perspective of STF-dominated
dynamics, plotting out the results under two different Poisson spike trains
with different spiking frequencies. Then, we will modify our simulation
-to emulate dynamics from a STD-dominated perspective.
+to emulate dynamics from an STD-dominated perspective.
### Starting with Facilitation-Dominated Dynamics
@@ -39,10 +39,10 @@ some mixture of the two.
Ultimately, the above means that, in the context of spiking cells, when a
pre-synaptic neuron emits a pulse, this act will affect the relative magnitude
-of the synapse's efficacy;
-in some cases, this will result in an increase (facilitation) and, in others,
-this will result in a decrease (depression) that lasts over a short period
-of time (several hundreds to thousands of milliseconds in many instances).
+of the synapse's efficacy. In some cases, this will result in an increase
+(facilitation) and, in others, this will result in a decrease (depression)
+that lasts over a short period of time (several hundreds to thousands of
+milliseconds in many instances).
As a result of considering synapses to have a dynamic nature to them, both over
short and long time-scales, plasticity can now be thought of as a stimulus and
resource-dependent quantity, reflecting an important biophysical aspect that
@@ -87,13 +87,16 @@ tau_d = 50. # ms
plot_fname = "{}Hz_stp_{}.jpg".format(firing_rate_e, tag)
with Context("Model") as model:
- W = STPDenseSynapse("W", shape=(1, 1), weight_init=dist.constant(value=2.5),
- resources_init=dist.constant(value=Rval),
- tau_f=tau_f, tau_d=tau_d, key=subkeys[0])
+ W = STPDenseSynapse(
+ "W", shape=(1, 1), weight_init=dist.constant(value=2.5),
+ resources_init=dist.constant(value=Rval), tau_f=tau_f, tau_d=tau_d,
+ key=subkeys[0]
+ )
z0 = PoissonCell("z0", n_units=1, target_freq=firing_rate_e, key=subkeys[0])
- z1 = LIFCell("z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m,
- v_rest=-60., v_reset=-70., thr=-50.,
- tau_theta=0., theta_plus=0., refract_time=0.)
+ z1 = LIFCell(
+ "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60.,
+ v_reset=-70., thr=-50., tau_theta=0., theta_plus=0., refract_time=0.
+ )
W.inputs << z0.outputs ## z0 -> W
z1.j << W.outputs ## W -> z1
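
For intuition about what the facilitation (`tau_f`) and depression (`tau_d`) time constants in the snippet above control, below is a compact sketch of the classic facilitation/depression variable updates that short-term plasticity models of this kind are typically built around; the variable names (`u`, `x`, `U`), parameter values, and update ordering are illustrative assumptions rather than the exact STPDenseSynapse kernel.

```python
from jax import numpy as jnp

def stp_step(u, x, spike, dt=1., tau_f=750., tau_d=50., U=0.15):
    ## facilitation variable u relaxes back to its baseline U and jumps upward on each pre-synaptic spike
    u = u + (-(u - U) / tau_f) * dt + U * (1. - u) * spike
    ## resource variable x recovers toward 1 and is consumed (depressed) by each spike
    x = x + ((1. - x) / tau_d) * dt - u * x * spike
    efficacy = u * x  ## effective scaling of the synapse's strength at this instant
    return u, x, efficacy
```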
@@ -156,7 +159,7 @@ resources ready for the dynamic synapse's use.
### Simulating and Visualizing STF
-Now that we understand the basics of how an ngc-learn STP works, we can next
+Now that we understand the basics of how an ngc-learn STP synapse works, we can next
try it out on a simple pre-synaptic Poisson spike train. Writing out the
simulated input Poisson spike train and our STP model's processing of this
data can be done as follows:
diff --git a/docs/tutorials/neurocog/stdp.md b/docs/tutorials/neurocog/stdp.md
index 16423c8d..b8e889a0 100755
--- a/docs/tutorials/neurocog/stdp.md
+++ b/docs/tutorials/neurocog/stdp.md
@@ -1,4 +1,4 @@
-# Lecture 4B: Spike-Timing-Dependent Plasticity
+# Lecture 4C: Spike-Timing-Dependent Plasticity
In the context of spiking neuronal networks, one of the most important forms
of adaptation that is often simulated is that of spike-timing-dependent
diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py
index 38a829f8..af856c1a 100644
--- a/ngclearn/components/__init__.py
+++ b/ngclearn/components/__init__.py
@@ -38,6 +38,9 @@
from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
from .synapses.hebbian.BCMSynapse import BCMSynapse
from .synapses.STPDenseSynapse import STPDenseSynapse
+from .synapses.exponentialSynapse import ExponentialSynapse
+from .synapses.doubleExpSynapse import DoupleExpSynapse
+from .synapses.alphaSynapse import AlphaSynapse
## point to convolutional component types
from .synapses.convolution.convSynapse import ConvSynapse
diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py
index 114d3785..29b5f267 100755
--- a/ngclearn/components/neurons/graded/gaussianErrorCell.py
+++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py
@@ -65,6 +65,12 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)
+ @staticmethod
+ def eval_log_density(target, mu, Sigma):
+ _dmu = (target - mu)
+ log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ return log_density
+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
@staticmethod
def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
@@ -79,6 +85,7 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e
dtarget = -dmu # reverse of e
dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ #L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
dtarget = dtarget * modulator * mask
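
As a quick standalone illustration of the quantity that `eval_log_density` computes, the sketch below evaluates the (unnormalized) Gaussian log density and a corresponding mismatch signal for a fixed scalar variance; the helper name and the exact scaling of the mismatch are illustrative assumptions.

```python
from jax import numpy as jnp

def gaussian_mismatch(target, mu, sigma=1.0):
    dmu = (target - mu) / sigma  ## mismatch/error signal (one common scaling choice)
    L = -jnp.sum(jnp.square(target - mu)) * (0.5 / sigma)  ## unnormalized Gaussian log density
    return L, dmu

L, e = gaussian_mismatch(jnp.ones((1, 3)), jnp.zeros((1, 3)), sigma=1.0)  ## example usage
```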
diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py
index 20c9bc98..e55ce8ff 100755
--- a/ngclearn/components/neurons/graded/rateCell.py
+++ b/ngclearn/components/neurons/graded/rateCell.py
@@ -145,6 +145,8 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell
act_fx: string name of activation function/nonlinearity to use
+ output_scale: factor by which to scale the output of this cell's nonlinearity (Default: 1.)
+
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
@@ -157,12 +159,13 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell
"""
# Define Functions
- def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity",
- threshold=("none", 0.), integration_type="euler",
- batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
+ def __init__(
+ self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.),
+ integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
super().__init__(name, **kwargs)
## membrane parameter setup (affects ODE integration)
+ self.output_scale = output_scale
self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
self.is_stateful = is_stateful
if isinstance(tau_m, float):
@@ -211,8 +214,9 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
@transition(output_compartments=["j", "j_td", "z", "zF"])
@staticmethod
- def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
- resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
+ def advance_state(
+ dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful,
+ output_scale, j, j_td, z):
#if tau_m > 0.:
if is_stateful:
### run a step of integration over neuronal dynamics
@@ -231,12 +235,12 @@ def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
elif thresholdType == "cauchy_threshold":
tmp_z = threshold_cauchy(tmp_z, thr_lmbda)
z = tmp_z ## pre-activation function value(s)
- zF = fx(z) ## post-activation function value(s)
+ zF = fx(z) * output_scale ## post-activation function value(s)
else:
## run in "stateless" mode (when no membrane time constant provided)
j_total = j + j_td
z = _run_cell_stateless(j_total)
- zF = fx(z)
+ zF = fx(z) * output_scale
return j, j_td, z, zF
@transition(output_compartments=["j", "j_td", "z", "zF"])
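
Given the new `output_scale` argument documented above, a minimal construction sketch might look as follows; the context name and parameter values are illustrative, and `output_scale` simply rescales the post-activation compartment (`zF = f(z) * output_scale`).

```python
from ngcsimlib.context import Context
from ngclearn.components.neurons.graded.rateCell import RateCell

with Context("RateDemo") as ctx:
    ## with the identity activation, zF will simply be 2 * z
    z0 = RateCell("z0", n_units=9, tau_m=20., act_fx="identity", output_scale=2.)
```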
diff --git a/ngclearn/components/neurons/graded/rateCellOld.py b/ngclearn/components/neurons/graded/rateCellOld.py
deleted file mode 100644
index 6962810c..00000000
--- a/ngclearn/components/neurons/graded/rateCellOld.py
+++ /dev/null
@@ -1,350 +0,0 @@
-from jax import numpy as jnp, random, jit
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngclearn import resolver, Component, Compartment
-from ngclearn.components.jaxComponent import JaxComponent
-from ngclearn.utils.model_utils import create_function, threshold_soft, \
- threshold_cauchy
-from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
- step_euler, step_rk2, step_rk4
-
-def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma):
- z_leak = z # * 2 ## Default: assume Gaussian
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma):
- z_leak = jnp.sign(z) ## d/dx of Laplace is signum
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma):
- z_leak = (z * 2)/(1. + jnp.square(z))
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma):
- z_leak = jnp.exp(-jnp.square(z)) * z * 2
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-
-def _dfz_gaussian(t, z, params): ## diff-eq dynamics wrapper
- j, j_td, tau_m, leak_gamma = params
- dz_dt = _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma)
- return dz_dt
-
-def _dfz_laplacian(t, z, params): ## diff-eq dynamics wrapper
- j, j_td, tau_m, leak_gamma = params
- dz_dt = _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma)
- return dz_dt
-
-def _dfz_cauchy(t, z, params): ## diff-eq dynamics wrapper
- j, j_td, tau_m, leak_gamma = params
- dz_dt = _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma)
- return dz_dt
-
-def _dfz_exp(t, z, params): ## diff-eq dynamics wrapper
- j, j_td, tau_m, leak_gamma = params
- dz_dt = _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma)
- return dz_dt
-
-@jit
-def _modulate(j, dfx_val):
- """
- Apply a signal modulator to j (typically of the form of a derivative/dampening function)
-
- Args:
- j: current/stimulus value to modulate
-
- dfx_val: modulator signal
-
- Returns:
- modulated j value
- """
- return j * dfx_val
-
-def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0):
- """
- Runs leaky rate-coded state dynamics one step in time.
-
- Args:
- dt: integration time constant
-
- j: input (bottom-up) electrical/stimulus current
-
- j_td: modulatory (top-down) electrical/stimulus pressure
-
- z: current value of membrane/state
-
- tau_m: membrane/state time constant
-
- leak_gamma: strength of leak to apply to membrane/state
-
- integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4)
-
- priorType: scale-shift prior distribution to impose over neural dynamics
-
- Returns:
- New value of membrane/state for next time step
- """
- _dfz = {
- 0: _dfz_gaussian,
- 1: _dfz_laplacian,
- 2: _dfz_cauchy,
- 3: _dfz_exp
- }.get(priorType, _dfz_gaussian)
- if integType == 1:
- params = (j, j_td, tau_m, leak_gamma)
- _, _z = step_rk2(0., z, _dfz, dt, params)
- elif integType == 2:
- params = (j, j_td, tau_m, leak_gamma)
- _, _z = step_rk4(0., z, _dfz, dt, params)
- else:
- params = (j, j_td, tau_m, leak_gamma)
- _, _z = step_euler(0., z, _dfz, dt, params)
- return _z
-
-@jit
-def _run_cell_stateless(j):
- """
- A simplification of running a stateless set of dynamics over j (an identity
- functional form of dynamics).
-
- Args:
- j: stimulus to do nothing to
-
- Returns:
- the stimulus
- """
- return j + 0
-
-class RateCell(JaxComponent): ## Rate-coded/real-valued cell
- """
- A non-spiking cell driven by the gradient dynamics of neural generative
- coding-driven predictive processing.
-
- The specific differential equation that characterizes this cell
- is (for adjusting v, given current j, over time) is:
-
- | tau_m * dz/dt = lambda * prior(z) + (j + j_td)
- | where j is the set of general incoming input signals (e.g., message-passed signals)
- | and j_td is taken to be the set of top-down pressure signals
-
- | --- Cell Input Compartments: ---
- | j - input pressure (takes in external signals)
- | j_td - input/top-down pressure input (takes in external signals)
- | --- Cell State Compartments ---
- | z - rate activity
- | --- Cell Output Compartments: ---
- | zF - post-activation function activity, i.e., fx(z)
-
- Args:
- name: the string name of this cell
-
- n_units: number of cellular entities (neural population size)
-
- tau_m: membrane/state time constant (milliseconds)
-
- prior: a kernel for specifying the type of centered scale-shift distribution
- to impose over neuronal dynamics, applied to each neuron or
- dimension within this component (Default: ("gaussian", 0)); this is
- a tuple with 1st element containing a string name of the distribution
- one wants to use while the second value is a `leak rate` scalar
- that controls the influence/weighting that this distribution
- has on the dynamics; for example, ("laplacian, 0.001") means that a
- centered laplacian distribution scaled by `0.001` will be injected
- into this cell's dynamics ODE each step of simulated time
-
- :Note: supported scale-shift distributions include "laplacian",
- "cauchy", "exp", and "gaussian"
-
- act_fx: string name of activation function/nonlinearity to use
-
- integration_type: type of integration to use for this cell's dynamics;
- current supported forms include "euler" (Euler/RK-1 integration)
- and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
-
- :Note: setting the integration type to the midpoint method will
- increase the accuray of the estimate of the cell's evolution
- at an increase in computational cost (and simulation time)
-
- resist_scale: a scaling factor applied to incoming pressure `j` (default: 1)
- """
-
- # Define Functions
- def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity",
- threshold=("none", 0.), integration_type="euler",
- batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
- super().__init__(name, **kwargs)
-
- ## membrane parameter setup (affects ODE integration)
- self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
- self.is_stateful = is_stateful
- if isinstance(tau_m, float):
- if tau_m <= 0: ## trigger stateless mode
- self.is_stateful = False
- priorType, leakRate = prior
- self.priorType = {
- "gaussian": 0,
- "laplacian": 1,
- "cauchy": 2,
- "exp": 3
- }.get(priorType, 0) ## type of scale-shift prior to impose over the leak
- self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
- thresholdType, thr_lmbda = threshold
- self.thresholdType = thresholdType ## type of thresholding function to use
- self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
- self.resist_scale = resist_scale ## a "resistance" scaling factor
-
- ## integration properties
- self.integrationType = integration_type
- self.intgFlag = get_integrator_code(self.integrationType)
-
- ## Layer size setup
- _shape = (batch_size, n_units) ## default shape is 2D/matrix
- if shape is None:
- shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
- else:
- _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
- self.shape = shape
- self.n_units = n_units
- self.batch_size = batch_size
-
- omega_0 = None
- if act_fx == "sine":
- omega_0 = kwargs["omega_0"]
- self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0)
-
- # compartments (state of the cell & parameters will be updated through stateless calls)
- restVals = jnp.zeros(_shape)
- self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
- self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity
- self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
- self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
-
- @staticmethod
- def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType,
- resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z):
- #if tau_m > 0.:
- if is_stateful:
- ### run a step of integration over neuronal dynamics
- ## Notes:
- ## self.pressure <-- "top-down" expectation / contextual pressure
- ## self.current <-- "bottom-up" data-dependent signal
- dfx_val = dfx(z)
- j = _modulate(j, dfx_val)
- j = j * resist_scale
- tmp_z = _run_cell(dt, j, j_td, z,
- tau_m, leak_gamma=priorLeakRate,
- integType=intgFlag, priorType=priorType)
- ## apply optional thresholding sub-dynamics
- if thresholdType == "soft_threshold":
- tmp_z = threshold_soft(tmp_z, thr_lmbda)
- elif thresholdType == "cauchy_threshold":
- tmp_z = threshold_cauchy(tmp_z, thr_lmbda)
- z = tmp_z ## pre-activation function value(s)
- zF = fx(z) ## post-activation function value(s)
- else:
- ## run in "stateless" mode (when no membrane time constant provided)
- j_total = j + j_td
- z = _run_cell_stateless(j_total)
- zF = fx(z)
- return j, j_td, z, zF
-
- @resolver(_advance_state)
- def advance_state(self, j, j_td, z, zF):
- self.j.set(j)
- self.j_td.set(j_td)
- self.z.set(z)
- self.zF.set(zF)
-
- @staticmethod
- def _reset(batch_size, shape): #n_units
- _shape = (batch_size, shape[0])
- if len(shape) > 1:
- _shape = (batch_size, shape[0], shape[1], shape[2])
- restVals = jnp.zeros(_shape)
- return tuple([restVals for _ in range(4)])
-
- @resolver(_reset)
- def reset(self, j, zF, j_td, z):
- self.j.set(j) # electrical current
- self.zF.set(zF) # rate-coded output - activity
- self.j_td.set(j_td) # top-down electrical current - pressure
- self.z.set(z) # rate activity
-
- def save(self, directory, **kwargs):
- ## do a protected save of constants, depending on whether they are floats or arrays
- tau_m = (self.tau_m if isinstance(self.tau_m, float)
- else jnp.ones([[self.tau_m]]))
- priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
- else jnp.ones([[self.priorLeakRate]]))
- resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
- else jnp.ones([[self.resist_scale]]))
-
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- tau_m=tau_m, priorLeakRate=priorLeakRate,
- resist_scale=resist_scale) #, key=self.key.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- ## constants loaded in
- self.tau_m = data['tau_m']
- self.priorLeakRate = data['priorLeakRate']
- self.resist_scale = data['resist_scale']
- #if seeded:
- # self.key.set(data['key'])
-
- @classmethod
- def help(cls): ## component help function
- properties = {
- "cell_type": "RateCell - evolves neurons according to rate-coded/"
- "continuous dynamics "
- }
- compartment_props = {
- "inputs":
- {"j": "External input stimulus value(s)",
- "j_td": "External top-down input stimulus value(s); these get "
- "multiplied by the derivative of f(x), i.e., df(x)"},
- "states":
- {"z": "Update to rate-coded continuous dynamics; value at time t"},
- "outputs":
- {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"},
- }
- hyperparams = {
- "n_units": "Number of neuronal cells to model in this layer",
- "batch_size": "Batch size dimension of this component",
- "tau_m": "Cell state/membrane time constant",
- "prior": "What kind of kurtotic prior to place over neuronal dynamics?",
- "act_fx": "Elementwise activation function to apply over cell state `z`",
- "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?",
- "integration_type": "Type of numerical integration to use for the cell dynamics",
- }
- info = {cls.__name__: properties,
- "compartments": compartment_props,
- "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)",
- "hyperparameters": hyperparams}
- return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-if __name__ == '__main__':
- from ngcsimlib.context import Context
- with Context("Bar") as bar:
- X = RateCell("X", 9, 0.03)
- print(X)
\ No newline at end of file
diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py
index 9566f3c9..bf7f2c14 100644
--- a/ngclearn/components/neurons/spiking/LIFCell.py
+++ b/ngclearn/components/neurons/spiking/LIFCell.py
@@ -14,19 +14,29 @@
#from ngcsimlib.component import Component
from ngcsimlib.compartment import Compartment
-#@jit
-def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
- mask = (rfr >= refract_T) * 1. # get refractory mask
- ## update voltage / membrane potential
- dv_dt = (v_rest - v) * v_decay + (j * mask)
- dv_dt = dv_dt * (1./tau_m)
- return dv_dt
+# def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
+# mask = (rfr >= refract_T) * 1. # get refractory mask
+# ## update voltage / membrane potential
+# dv_dt = (v_rest - v) * v_decay + (j * mask)
+# dv_dt = dv_dt * (1./tau_m)
+# return dv_dt
+#
+# def _dfv(t, v, params): ## voltage dynamics wrapper
+# j, rfr, tau_m, refract_T, v_rest, v_decay = params
+# dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
+# return dv_dt
+
+
def _dfv(t, v, params): ## voltage dynamics wrapper
- j, rfr, tau_m, refract_T, v_rest, v_decay = params
- dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
+ j, rfr, tau_m, refract_T, v_rest, g_L = params
+ mask = (rfr >= refract_T) * 1. # get refractory mask
+ ## update voltage / membrane potential
+ dv_dt = (v_rest - v) * g_L + (j * mask)
+ dv_dt = dv_dt * (1. / tau_m)
return dv_dt
+
#@partial(jit, static_argnums=[3, 4])
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
### Runs homeostatic threshold update dynamics one step (via Euler integration).
@@ -38,6 +48,7 @@ def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
return _v_theta
+
class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
"""
A spiking cell based on leaky integrate-and-fire (LIF) neuronal dynamics.
@@ -73,14 +84,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
thr: base value for adaptive thresholds that govern short-term
plasticity (in milliVolts, or mV; default: -52. mV)
- v_rest: membrane resting potential (in mV; default: -65 mV)
+ v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV)
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
a neuronal cell's membrane potential will be set to this value;
(default: -60 mV)
- v_decay: decay factor applied to voltage leak (Default: 1.); setting this
- to 0 mV recovers pure integrate-and-fire (IF) dynamics
+ conduct_leak: leak conductance (g_L) value or decay factor applied to voltage leak
+ (Default: 1.); setting this to 0 recovers pure integrate-and-fire (IF) dynamics
tau_theta: homeostatic threshold time constant
@@ -112,16 +123,15 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
"arctan" (arc-tangent estimator), and "secant_lif" (the
LIF-specialized secant estimator)
- lower_clamp_voltage: if True, this will ensure voltage never is below
- the value of `v_rest` (default: True)
+ v_min: minimum voltage to clamp dynamics to (Default: None)
""" ## batch_size arg?
- @deprecate_args(thr_jitter=None)
- def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
- v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05,
- refract_time=5., one_spike=False, integration_type="euler",
- surrogate_type="straight_through", lower_clamp_voltage=True,
- **kwargs):
+ @deprecate_args(thr_jitter=None, v_decay="conduct_leak")
+ def __init__(
+ self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
+ theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
+ v_min=None, max_one_spike=False, **kwargs
+ ):
super().__init__(name, **kwargs)
## Integration properties
@@ -132,11 +142,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.tau_m = tau_m ## membrane time constant
self.resist_m = resist_m ## resistance value
self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
- self.lower_clamp_voltage = lower_clamp_voltage ## True ==> ensures voltage is never < v_rest
+ self.max_one_spike = max_one_spike
+ self.v_min = v_min ## ensures voltage is never < v_min
self.v_rest = v_rest #-65. # mV
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
- self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
+ self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF)
## basic asserts to prevent neuronal dynamics breaking...
#assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
assert self.resist_m > 0.
@@ -178,11 +189,11 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
@staticmethod
def advance_state(
- t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, theta_plus,
- one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
+ t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus, one_spike, max_one_spike,
+ v_min, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
):
skey = None ## this is an empty dkey if single_spike mode turned off
- if one_spike:
+ if one_spike and not max_one_spike:
key, skey = random.split(key, 2)
## run one integration step for neuronal dynamics
j = j * resist_m
@@ -191,13 +202,14 @@ def advance_state(
_v_thr = thr_theta + thr ## calc present voltage threshold
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
## update voltage / membrane potential
- v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay)
+ v_params = (j, rfr, tau_m, refract_T, v_rest, g_L)
if intgFlag == 1:
_, _v = step_rk2(0., v, _dfv, dt, v_params)
else:
_, _v = step_euler(0., v, _dfv, dt, v_params)
## obtain action potentials/spikes/pulses
s = (_v > _v_thr) * 1.
+ v_prespike = v
## update refractory variables
_rfr = (rfr + dt) * (1. - s)
## perform hyper-polarization of neuronal cells
@@ -212,6 +224,9 @@ def advance_state(
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
dtype=jnp.float32)
s = s * (1. - m_switch) + rS * m_switch
+ if max_one_spike:
+ rS = nn.one_hot(jnp.argmax(v_prespike, axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike
+ s = s * rS ## mask out non-max volt spikes
############################################################################
raw_spikes = raw_s
v = _v
@@ -223,8 +238,8 @@ def advance_state(
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
## update tols
tols = (1. - s) * tols + (s * t)
- if lower_clamp_voltage: ## ensure voltage never < v_rest
- v = jnp.maximum(v, v_rest)
+ if v_min is not None: ## ensures voltage never < v_min
+ v = jnp.maximum(v, v_min)
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
@transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
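
A minimal construction sketch using the renamed `conduct_leak` argument and the new `v_min` clamp is given below; the values are illustrative, and per the deprecation mapping above, code passing the old `v_decay` argument is redirected to `conduct_leak`.

```python
from ngcsimlib.context import Context
from ngclearn.components.neurons.spiking.LIFCell import LIFCell

with Context("LIFDemo") as ctx:
    ## conduct_leak plays the role of the old v_decay (0. recovers IF dynamics);
    ## v_min=None (the default) leaves the membrane voltage unclamped from below
    z1 = LIFCell("z1", n_units=1, tau_m=100., resist_m=1., thr=-52., conduct_leak=1., v_min=-80.)
```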
diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py
index 3889cace..94510e75 100644
--- a/ngclearn/components/other/varTrace.py
+++ b/ngclearn/components/other/varTrace.py
@@ -59,6 +59,8 @@ class VarTrace(JaxComponent): ## low-pass filter
a_delta: value to increment a trace by in presence of a spike; note if set
to a value <= 0, then a piecewise gated trace will be used instead
+ P_scale: if `a_delta` is set to a value <= 0, then this scales the value that the trace snaps to upon receiving a pulse
+
gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1)
decay_type: string indicating the decay type to be applied to ODE
@@ -69,19 +71,24 @@ class VarTrace(JaxComponent): ## low-pass filter
2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr;
3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value)
+ n_nearest_spikes: (k) if k > 0, this makes the trace act as a k-nearest-neighbor trace,
+ e.g., k = 1 yields the 1-nearest (neighbor) trace (Default: 0)
+
batch_size: batch size dimension of this cell (Default: 1)
"""
# Define Functions
- def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
- batch_size=1, **kwargs):
+ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp",
+ n_nearest_spikes=0, batch_size=1, **kwargs):
super().__init__(name, **kwargs)
## Trace control coefficients
self.tau_tr = tau_tr ## trace time constant
self.a_delta = a_delta ## trace increment (if spike occurred)
+ self.P_scale = P_scale ## trace scale if non-additive trace to be used
self.gamma_tr = gamma_tr
self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
+ self.n_nearest_spikes = n_nearest_spikes
## Layer Size Setup
self.batch_size = batch_size
@@ -94,17 +101,22 @@ def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
@transition(output_compartments=["outputs", "trace"])
@staticmethod
- def advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
+ def advance_state(
+ dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace, n_nearest_spikes
+ ):
decayFactor = 0.
if "exp" in decay_type:
decayFactor = jnp.exp(-dt/tau_tr)
elif "lin" in decay_type:
decayFactor = (1. - dt/tau_tr)
_x_tr = gamma_tr * trace * decayFactor
- if a_delta > 0.:
- _x_tr = _x_tr + inputs * a_delta
+ if n_nearest_spikes > 0: ## run k-nearest neighbor trace
+ _x_tr = _x_tr + inputs * (a_delta - (trace/n_nearest_spikes))
else:
- _x_tr = _x_tr * (1. - inputs) + inputs
+ if a_delta > 0.: ## run full convolution trace
+ _x_tr = _x_tr + inputs * a_delta
+ else: ## run simple max-clamped trace
+ _x_tr = _x_tr * (1. - inputs) + inputs * P_scale
trace = _x_tr
return trace, trace
@@ -135,12 +147,15 @@ def help(cls): ## component help function
"tau_tr": "Trace/filter time constant",
"a_delta": "Increment to apply to trace (if not set to 0); "
"otherwise, traces clamp to 1 and then decay",
+ "P_scale": "Max value to snap trace to if a max-clamp trace is triggered/configured",
"decay_type": "Indicator of what type of decay dynamics to use "
- "as filter is updated at time t"
+ "as filter is updated at time t",
+ "n_nearest_neighbors": "Number of nearest pulses to affect/increment trace (if > 0)"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
- "dynamics": "tau_tr * dz/dt ~ -z + inputs",
+ "dynamics": "tau_tr * dz/dt ~ -z + inputs * a_delta (full convolution trace); "
+ "tau_tr * dz/dt ~ -z + inputs * (a_delta - z/n_nearest_neighbors) (near trace)",
"hyperparameters": hyperparams}
return info
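
A small usage sketch of the trace modes configurable through the constructor shown above (the module path and parameter values are assumptions for illustration):

```python
from ngcsimlib.context import Context
from ngclearn.components.other.varTrace import VarTrace

with Context("TraceDemo") as ctx:
    tr_add = VarTrace("tr_add", n_units=1, tau_tr=20., a_delta=0.5)  ## full-convolution (additive) trace
    tr_nn = VarTrace("tr_nn", n_units=1, tau_tr=20., a_delta=0.5, n_nearest_spikes=1)  ## 1-nearest-neighbor trace
    tr_clamp = VarTrace("tr_clamp", n_units=1, tau_tr=20., a_delta=0., P_scale=1.)  ## max-clamped trace
```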
diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py
index fc5b8a60..2c21c231 100644
--- a/ngclearn/components/synapses/__init__.py
+++ b/ngclearn/components/synapses/__init__.py
@@ -4,7 +4,9 @@
## short-term plasticity components
from .STPDenseSynapse import STPDenseSynapse
-
+from .exponentialSynapse import ExponentialSynapse
+from .doubleExpSynapse import DoupleExpSynapse
+from .alphaSynapse import AlphaSynapse
## dense synaptic components
from .hebbian.hebbianSynapse import HebbianSynapse
diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py
new file mode 100644
index 00000000..cf5f9543
--- /dev/null
+++ b/ngclearn/components/synapses/alphaSynapse.py
@@ -0,0 +1,186 @@
+from jax import random, numpy as jnp, jit
+from ngcsimlib.compilers.process import transition
+from ngcsimlib.component import Component
+from ngcsimlib.compartment import Compartment
+
+from ngclearn.utils.weight_distribution import initialize_params
+from ngcsimlib.logger import info
+from ngclearn.components.synapses import DenseSynapse
+from ngclearn.utils import tensorstats
+
+class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
+ """
+ A dynamic alpha synaptic cable; this synapse evolves according to alpha synaptic conductance dynamics.
+ Specifically, the conductance dynamics are as follows:
+
+ | dh/dt = -h/tau_decay + g_syn_bar * sum_k delta(t - t_k) // h is an intermediate variable
+ | dg/dt = -g/tau_decay + h/tau_decay
+ | i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation
+ | where: syn_rest is the post-synaptic reversal potential for this synapse
+ | t_k marks the time of the k-th pre-synaptic pulse received by the post-synaptic unit
+
+
+ | --- Synapse Compartments: ---
+ | inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes)
+ | outputs - output signals (also equal to i_syn, total electrical current)
+ | v - coupled voltages from post-synaptic neurons this synaptic cable connects to
+ | weights - current value matrix of synaptic efficacies
+ | biases - current value vector of synaptic bias values
+ | --- Dynamic / Short-term Plasticity Compartments: ---
+ | g_syn - derived synaptic conductance variable (with `h_syn` as its intermediate filter state)
+ | i_syn - derived total electrical current variable
+
+ Args:
+ name: the string name of this synapse
+
+ shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
+ with number of inputs by number of outputs)
+
+ tau_decay: synaptic decay time constant (ms)
+
+ g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight")
+
+ syn_rest: synaptic reversal potential; note, if this is set to `None`, then this
+ synaptic conductance model will no longer be voltage-dependent (and will ignore
+ the voltage compartment provided by an external spiking cell)
+
+ weight_init: a kernel to drive initialization of this synaptic cable's values;
+ typically a tuple with 1st element as a string calling the name of
+ initialization to use
+
+ bias_init: a kernel to drive initialization of biases for this synaptic cable
+ (Default: None, which turns off/disables biases)