diff --git a/README.md b/README.md index 7e43e7fa..1e7412fd 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ $ python install -e . **Version:**
-2.0.0 +2.0.1 Author: Alexander G. Ororbia II
diff --git a/docs/images/museum/rat_accuracy.jpg b/docs/images/museum/rat_accuracy.jpg new file mode 100644 index 00000000..6236db35 Binary files /dev/null and b/docs/images/museum/rat_accuracy.jpg differ diff --git a/docs/images/museum/rat_rewards.jpg b/docs/images/museum/rat_rewards.jpg new file mode 100644 index 00000000..8c42e1b0 Binary files /dev/null and b/docs/images/museum/rat_rewards.jpg differ diff --git a/docs/images/museum/ratmaze.png b/docs/images/museum/ratmaze.png new file mode 100644 index 00000000..65c1006f Binary files /dev/null and b/docs/images/museum/ratmaze.png differ diff --git a/docs/images/museum/real_ratmaze.jpg b/docs/images/museum/real_ratmaze.jpg new file mode 100644 index 00000000..5925c09f Binary files /dev/null and b/docs/images/museum/real_ratmaze.jpg differ diff --git a/docs/images/tutorials/neurocog/GEC.png b/docs/images/tutorials/neurocog/GEC.png new file mode 100644 index 00000000..47ec531f Binary files /dev/null and b/docs/images/tutorials/neurocog/GEC.png differ diff --git a/docs/images/tutorials/neurocog/SingleGEC.png b/docs/images/tutorials/neurocog/SingleGEC.png new file mode 100644 index 00000000..8af5cc58 Binary files /dev/null and b/docs/images/tutorials/neurocog/SingleGEC.png differ diff --git a/docs/images/tutorials/neurocog/alphasyn.jpg b/docs/images/tutorials/neurocog/alphasyn.jpg new file mode 100644 index 00000000..037cf868 Binary files /dev/null and b/docs/images/tutorials/neurocog/alphasyn.jpg differ diff --git a/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg b/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg new file mode 100644 index 00000000..84c84023 Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_dense_exc.jpg differ diff --git a/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg b/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg new file mode 100644 index 00000000..ef29c41e Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_dynamics.jpg differ diff --git a/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg b/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg new file mode 100644 index 00000000..3a137c8b Binary files /dev/null and b/docs/images/tutorials/neurocog/ei_circuit_sparse_inh.jpg differ diff --git a/docs/images/tutorials/neurocog/exp2syn.jpg b/docs/images/tutorials/neurocog/exp2syn.jpg new file mode 100644 index 00000000..d32f3bfc Binary files /dev/null and b/docs/images/tutorials/neurocog/exp2syn.jpg differ diff --git a/docs/images/tutorials/neurocog/expsyn.jpg b/docs/images/tutorials/neurocog/expsyn.jpg new file mode 100755 index 00000000..1b9b2fc3 Binary files /dev/null and b/docs/images/tutorials/neurocog/expsyn.jpg differ diff --git a/docs/installation.md b/docs/installation.md index cda569a2..b70db6c6 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -6,13 +6,13 @@ without a GPU. Setup: ngc-learn, in its entirety (including its supporting utilities), requires that you ensure that you have installed the following base dependencies in -your system. Note that this library was developed and tested on Ubuntu 22.04 (and 18.04). +your system. Note that this library was developed and tested on Ubuntu 22.04 (and earlier versions on 18.04/20.04). 
Specifically, ngc-learn requires: * Python (>=3.10) -* ngcsimlib (>=0.3.b4), (official page) +* ngcsimlib (>=1.0.0), (official page) * NumPy (>=1.26.0) * SciPy (>=1.7.0) -* JAX (>= 0.4.18; and jaxlib>=0.4.18) +* JAX (>= 0.4.28; and jaxlib>=0.4.28) * Matplotlib (>=3.4.2), (for `ngclearn.utils.viz`) * Scikit-learn (>=1.3.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`) @@ -45,7 +45,7 @@ $ git clone https://github.com/NACLab/ngc-learn.git $ cd ngc-learn ``` -2. (Optional; only for GPU version) Install JAX for either CUDA 11 or 12 , depending +2. (Optional; only for GPU version) Install JAX for CUDA 12, depending on your system setup. Follow the installation instructions on the official JAX page to properly install the CUDA 12 version. diff --git a/docs/modeling/neurons.md b/docs/modeling/neurons.md index 1ec2600b..4babf8f7 100644 --- a/docs/modeling/neurons.md +++ b/docs/modeling/neurons.md @@ -234,7 +234,7 @@ fast spiking (FS), low-threshold spiking (LTS), and resonator (RZ) neurons. This cell models dynamics over voltage `v` and three channels/gates (related to potassium and sodium activation/inactivation). This sophisticated cell system is, as a result, a set of four coupled differential equations and is driven by an appropriately configured set of biophysical constants/coefficients (default values of which have been set according to relevant source work). -(Note that this cell supports either Euler or midpoint method / RK-2 integration.) +(Note that this cell supports Euler, midpoint / RK-2 integration, or RK-4 integration.) ```{eval-rst} .. autoclass:: ngclearn.components.HodgkinHuxleyCell diff --git a/docs/modeling/synapses.md b/docs/modeling/synapses.md index 6bf89394..470446e9 100644 --- a/docs/modeling/synapses.md +++ b/docs/modeling/synapses.md @@ -60,6 +60,34 @@ This synapse performs a deconvolutional transform of its input signals. Note tha ## Dynamic Synapse Types +### Exponential Synapse + +This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through an exponential kernel. + +```{eval-rst} +.. autoclass:: ngclearn.components.ExponentialSynapse + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + +### Alpha Synapse + +This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through a kernel that models more realistic rise and fall times of synaptic conductance. + +```{eval-rst} +.. autoclass:: ngclearn.components.AlphaSynapse + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + ### Short-Term Plasticity (Dense) Synapse This synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that it engages in short-term plasticity (STP), meaning that its efficacy values change as a function of its inputs/time (and simulated consumed resources), but it does not provide any long-term form of plasticity/adjustment.
diff --git a/docs/museum/index.rst b/docs/museum/index.rst index 6f534770..25f75dfc 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -18,3 +18,4 @@ relevant, referenced publicly available ngc-learn simulation code. snn_dc snn_bfa sindy + rl_snn diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md new file mode 100644 index 00000000..24d11412 --- /dev/null +++ b/docs/museum/rl_snn.md @@ -0,0 +1,132 @@ +# Reinforcement Learning through a Spiking Controller + +In this exhibit, we will see how to construct a simple biophysical model for +reinforcement learning with a spiking neural network and modulated +spike-timing-dependent plasticity. +This model incorporates mechanisms from several different models, including +the constrained RL-centric SNN of [1] as well as the simplifications +made with respect to the model of [2]. The model code for this +exhibit can be found +[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn). + +## Modeling Operant Conditioning through Modulation + +Operant conditioning refers to the idea that there are environmental stimuli that can either increase or decrease the occurrence of (voluntary) behaviors; in other words, positive stimuli can lead to future repeats of a certain behavior whereas negative stimuli can lead to (i.e., punish) a decrease in future occurrences. Ultimately, operant conditioning, through consequences, shapes voluntary behavior where actions followed by rewards are repeated and actions followed by punished/negative outcomes diminish. + +In this lesson, we will model a very simple case of operant conditioning for a neuronal motor circuit used to engage in the navigation of a simple maze. +The maze's design will be the rat T-maze and the "rat" will be allowed to move, at a particular point in the maze, in one of four directions (up/North, down/South, left/West, and right/East). A positive reward will be supplied to our rat neuronal circuit if it makes progress towards the direction of food (placed in the upper right corner of the T-maze) and a negative reward will be provided if it fails to make progress/gets stuck, i.e., a dense reward function will be employed. For the exhibit code that goes with this lesson, an implementation of this T-maze environment is provided, modeled in the same style/with the same agent API as the OpenAI gymnasium. + +### Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP) + +Although [spike-timing-dependent plasticity](../tutorials/neurocog/stdp.md) (STDP) and [reward-modulated STDP](../tutorials/neurocog/mod_stdp.md) (MSTDP) are covered and analyzed in detail in the ngc-learn set of tutorials, we will briefly review the evolution +of synaptic strengths as prescribed by modulated STDP with eligibility traces here. In effect, STDP prescribes changes +in synaptic strength according to the idea that neurons that fire together, wire together, except that timing matters +(a temporal interpretation of basic Hebbian learning).
This means that, assuming we are able to record the times of +pre-synaptic and post-synaptic neurons (that a synaptic cable connects), we can, at any time-step $t$, produce an +adjustment $\Delta W_{ij}(t)$ to a synapse via the following pair of correlational rules: + +$$ +\Delta W_{ij}(t) = A^+ \big(x_i s_j \big) - A^- \big(s_i x_j \big) +$$ + +where $s_j$ is the spike recorded at time $t$ of the post-synaptic neuron $j$ (and $x_j$ is an exponentially-decaying trace that tracks its spiking history) and $s_i$ is the spike recorded at time $t$ of the pre-synaptic neuron $i$ (and $x_i$ is an exponentially-decaying trace that tracks its pulse history). STDP, as shown in a very simple format above, effectively can be described as balancing two types of alterations to a synaptic efficacy -- long-term potentiation (the first term, which increases synaptic strength) and long-term depression (the second term, which decreases synaptic strength). + +Modulated STDP is a three-factor variant of STDP that multiplies the final synaptic update by a third signal, e.g., the modulatory signal is often a reward (dopamine) intensity value (resulting in reward-modulated STDP). However, given that reward signals might be delayed or not arrive/be available at every single time-step, it is common practice to extend a synapse to maintain a second value called an "eligibility trace", which is effectively another exponentially-decaying trace/filter (instantiated as an ODE that can be integrated via the Euler method or related tools) that is constructed to track a sequence of STDP updates applied across a window of time. Once a reward/modulator signal becomes available, the current trace is multiplied by the modulator to produce a change in synaptic efficacy. +In essence, this update becomes: + +$$ +\Delta W_{ij} = \nu E_{ij}(t) r(t), \; \text{where } \; \tau_e \frac{\partial E_{ij}(t)}{\partial t} = -E_{ij}(t) + \Delta W_{ij}(t) +$$ + +where $r(t)$ is the dopamine supplied at some time $t$ and $\nu$ is some non-negative global learning rate. Note that MSTDP with eligibility traces (MSTDP-ET) is agnostic to the choice of local STDP/Hebbian update used to produce $\Delta W_{ij}(t)$ (for example, one could replace the trace-based STDP rule we presented above with BCM or a variant of weight-dependent STDP). + +## The Spiking Neural Circuit Model + +In this exhibit, we build one of the simplest possible spiking neural networks (SNNs) one could design to tackle a simple maze navigation problem such as the rat T-maze; specifically, a three-layer SNN where the first layer is a Poisson encoder and the second and third layers contain sets of recurrent leaky integrate-and-fire (LIF) neurons. The recurrence in our model is to be non-plastic and constructed such that a form of lateral competition is induced among the LIF units, i.e., LIF neurons will be driven by a scaled Hollow-matrix initialized recurrent weight matrix (which will multiply spikes encountered at time $t - \Delta t$ by negative values), which will (quickly yet roughly) approximate the effect of inhibitory neurons. For the synapses that transmit pulses from the sensory/input layer to the second layer, we will opt for a non-plastic sparse mixture of excitatory and inhibitory strength values (much as in the model of [1]) to produce a reasonable encoding of the input Poisson spike trains. 
For the synapses that transmit pulses from the second layer to the third (control/action) layer, we will employ MSTDP-ET (as shown in the previous section) to adjust the non-negative efficacies in order to learn a basic reactive policy. We will call this very simple neuronal model the "reinforcement learning SNN" (RL-SNN). + +The SNN circuit will be provided raw pixels of the T-maze environment (however, this view is a global view of the +entire maze, as opposed to something more realistic such as an egocentric view of the sensory space), where a cross +"+" marks its current location and an "X" marks the location of the food substance/goal state. Shown below is an +image to the left depicting a real-world rat T-maze whereas to the right is our implementation/simulation of the +T-maze problem (and what our SNN circuit sees at the very start of an episode of the navigation problem). + +```{eval-rst} +.. table:: + :align: center + + +-------------------------------------------------+------------------------------------------------+ + | .. image:: ../images/museum/real_ratmaze.jpg | .. image:: ../images/museum/ratmaze.png | + | :width: 250px | :width: 200px | + | :align: center | :align: center | + +-------------------------------------------------+------------------------------------------------+ +``` + +## Running the RL-SNN Model + +To fit the RL-SNN model described above, go to the `exhibits/rl_snn` +sub-folder (this step assumes that you have git cloned the model museum repo +code), and execute the RL-SNN's simulation script from the command line as follows: + +```console +$ ./sim.sh +``` +which will execute a simulation of the MSTDP-adapted SNN on the T-maze problem, specifically executing four uniquely-seeded trial runs (i.e., four different "rat agents") and produce two plots, one containing a smoothened curve of episodic rewards over time and another containing a smoothened task accuracy curve (as in, did the rat reach the goal-state and obtain the food substance or not). You should obtain plots that look roughly like the two below. + +```{eval-rst} +.. table:: + :align: center + + +-----------------------------------------------+-----------------------------------------------+ + | .. image:: ../images/museum/rat_rewards.jpg | .. image:: ../images/museum/rat_accuracy.jpg | + | :width: 400px | :width: 400px | + | :align: center | :align: center | + +-----------------------------------------------+-----------------------------------------------+ +``` + +Notice that we have provided a random agent baseline (i.e., uniform random selection of one of the four possible +directions to move at each step in an episode) to contrast the SNN rat motor circuit's performance with. As you may +observe, the SNN circuit ultimately becomes conditioned to taking actions akin to the optimal policy -- go North/up +if it perceives itself (marked as a cross "+") at the bottom of the T-maze and then go East/right once it has reached the top +of the T of the T-maze and go right upon perception of the food item (goal state marked as an "X"). + +The code has been configured to also produce a small video/GIF of the final episode `episode200.gif`, where the MSTDP +weight changes have been disabled and the agent must solely rely on its memory of the uncovered policy to get to the +goal state. 
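Before looking at the model's limitations, it may help to see the R-STDP learning rule from the earlier section written out as plain array operations. The sketch below is purely illustrative -- it is not the exhibit's actual ngc-learn implementation, and every name, size, and constant in it (`mstdp_et_step`, `tau_tr`, `tau_e`, `nu`, and so on) is a made-up placeholder -- but it follows the trace-based STDP, eligibility-trace, and modulation equations given above.

```python
import numpy as np

## illustrative sizes and constants only (not the exhibit's configuration)
n_pre, n_post = 8, 4
dt, tau_tr, tau_e = 1., 20., 100.   ## ms
A_plus, A_minus, nu = 1., 1., 0.01

W = np.zeros((n_pre, n_post))       ## synaptic efficacies
x_pre = np.zeros(n_pre)             ## pre-synaptic spike traces
x_post = np.zeros(n_post)           ## post-synaptic spike traces
E = np.zeros((n_pre, n_post))       ## eligibility traces

def mstdp_et_step(s_pre, s_post, r):
    """One Euler step of trace-based STDP + eligibility trace + reward modulation."""
    global x_pre, x_post, E, W
    ## exponentially-decaying traces tracking each neuron's pulse history
    x_pre = x_pre - (x_pre / tau_tr) * dt + s_pre
    x_post = x_post - (x_post / tau_tr) * dt + s_post
    ## local STDP adjustment: LTP term (x_i s_j) minus LTD term (s_i x_j)
    dW = A_plus * np.outer(x_pre, s_post) - A_minus * np.outer(s_pre, x_post)
    ## eligibility trace integrates (and decays) the STDP adjustments over time
    E = E + ((-E + dW) / tau_e) * dt
    ## the dopamine/reward signal r(t) gates the actual change in efficacy
    W = W + nu * E * r
```

Calling such a routine once per simulation step with the binary spike vectors and the current reward value captures the behavior described earlier: when `r` is zero only the eligibility trace evolves, and when a reward arrives the accumulated trace is converted into a weight change.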
+ +### Some Important Limitations + +While the above MSTDP-ET-driven motor circuit model is useful and provides a simple model of operant conditioning in +the context of a very simple maze navigation task, it is important to identify the assumptions/limitations of the +above setup. Some important limitations or simplifications that have been made to obtain a consistently working +RL-SNN model: +1. As mentioned earlier, the sensory input contains a global view of the maze navigation problem, i.e., a 2D birds-eye + view of the agent, its goal (the food substance), and its environment. More realistic, but far more difficult + versions of this problem would need to consider an ego-centric view (making the problem a partially observable + Markov decision process), a more realistic 3D representation of the environment, as well as more complex maze + sizes and shapes for the agent/rat model to navigate. +2. The reward is only delayed with respect to the agent's stimulus processing window, meaning that the agent essentially + receives a dopamine signal after an action is taken. If we ignore the SNN's stimulus processing time between video + frames of the actual navigation problem, we can view our agent above as tackling what is known in reinforcement + learning as the dense reward problem. A far more complex, yet more cognitively realistic, version of the problem + is to administer a sparse reward, i.e., the rat motor circuit only receives a useful dopamine/reward stimulus at the + end of an episode as opposed to after each action. The above MSTDP-ET model would struggle to solve the sparse + reward problem and more sophisticated models would be required in order to achieve successful outcomes, i.e., + appealing to models of memory/cognitive maps, more intelligent forms of exploration, etc. +3. The SNN circuit itself only permits plastic synapses in its control layer (i.e., the synaptic connections between + the second layer and third output/control layer). The bottom layer is non-plastic and fixed, meaning that the + agent model is dependent on the quality of the random initialization of the input-to-hidden encoding layer. The + input-to-hidden synapses could be adapted with STDP (or MSTDP); however, the agent will not always successfully + and stably converge to a consistent policy as the encoding layer's effectiveness is highly dependent on how much + of the environment the agent initially sees/explores (if the agent gets "stuck" at any point, STDP will tend to + fill up the bottom layer receptive fields with redundant information and make it more difficult for the control + layer to learn the consequences of taking different actions). + + +## References +[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse +and delayed rewards with a multilayer spiking neural network." 2020 International +Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit +recognition using spike-timing-dependent plasticity." Frontiers in computational +neuroscience 9 (2015): 99. + diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md index 17ef4fee..04426d70 100644 --- a/docs/museum/sindy.md +++ b/docs/museum/sindy.md @@ -8,8 +8,7 @@ Flow diagrams lack clear directional indicators Inconsistent color schemes across visualizations --> - -# Sparse Identification of Non-linear Dynamical Systems (SINDy)[1] +# Sparse Identification of Non-linear Dynamical Systems (SINDy) In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will: @@ -28,19 +27,24 @@ SINDy is a data-driven algorithm that discovers the governing behavior of a dyna ### SINDy Dynamics -If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount. -$$$ +If $\mathbf{X}$ is a system that only depends on variable $t$, a very small change in the independent variable ($dt$) can cause a change in the system by $dX$ amount: + +$$ d\mathbf{X} = \mathbf{Ẋ}(t)~dt -$$$ +$$ + SINDy models the derivative[^1] (a linear operation) as linear transformations with: [^1]: The derivative is a linear operation that acts on $dt$ and gives a differential that is the linearized approximation of the Taylor series of the function. -$$$ + +$$ \frac{d\mathbf{X}(t)}{dt} = \mathbf{Ẋ}(t) = \mathbf{f}(\mathbf{X}(t)) -$$$ +$$ + SINDy assumes that this linear operation, i.e., $\mathbf{f}(\mathbf{X}(t))$, is a matrix multiplication that linearly combines the relevant predictors in order to describe the system's equation. -$$$ + +$$ \mathbf{f}(\mathbf{X}(t)) = \mathbf{\Theta}(\mathbf{X})~\mathbf{W} -$$$ +$$ Given a group of candidate functions within the library $\mathbf{\Theta}(\mathbf{X})$, the coefficients in $\mathbf{W}$ that choose the library terms are to be **sparse**. In other words, there are only a few functions that exist in the system's differential equation. Given these assumptions, SINDy solves a sparse regression problem in order to find the $\mathbf{W}$ that maps the library of selected terms to each feature of the system being identified. SINDy imposes parsimony constraints over the resulting symbolic regression (i.e., genetic programming) to describe a dynamical system's behavior with as few terms as possible. In order to select a sparse set of the given features, the model adds the LASSO regularization penalty (i.e., an L1 norm constraint) to the regression problem and solves the sparse regression problem, or it solves the regression problem via STLSQ. We will describe STLSQ in the third step of the SINDy dynamics/process. @@ -48,206 +52,101 @@ In essence, SINDy's dynamics can be presented in three main phases, visualized i ------------------------------------------------------------------------------------------ -

+ **Figure 1:** **The flow of the three phases in SINDy.** **Phase-1)** Data collection: capturing system states that are changing in time and creating the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 to compute its derivative with respect to time. **Phase-3)** Solving the sparse regression problem.

------------------------------------------------------------------------------------------
- ## Phase 1: Collecting Dataset → $\mathbf{X}_{(m \times n)}$ This phase involves gathering the raw data points representing the system's states across time. In this example, this means capturing the $x$, $y$, and $z$ coordinates of the system's states. Here, $m$ represents the number of data points (number of the snapshots/length of time) and $n$ is the system's dimensionality.
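As a concrete toy illustration of this stacking convention (not the exhibit's actual data-collection code; the trajectory below is made up), the snapshots are simply gathered row-by-row into a single matrix:

```python
import numpy as np

m = 1000                                   ## number of snapshots (length of time)
t = np.linspace(0., 10., m)
x, y, z = np.sin(t), np.cos(t), 0.1 * t    ## stand-in state coordinates over time
X = np.stack([x, y, z], axis=1)            ## state matrix X of shape (m, n), with n = 3
```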

+Library of Candidate Functions: +$\Theta(\mathbf{X}) = [\mathbf{1} \quad \mathbf{X} \quad \mathbf{X}^2 \quad \mathbf{X}^3 \quad \sin(\mathbf{X}) \quad \cos(\mathbf{X}) \quad ...]$
- ## Phase 2: Processing
+ ### 2.A: Making the Library → $\mathbf{\Theta}_{(m \times p)}$ In this step, using the dataset collected in phase 1, given pre-defined function terms, we construct a dictionary of candidate predictors for identifying the target system's differential equations. These functions form the columns of our library matrix $\mathbf{\Theta}(\mathbf{X})$ and $p$ is the number of candidate predictors. To identify the dynamical structure of the system, this library of candidate functions appears in the regression problem to propose the model's structure that will later serve as the coefficient matrix for weighting the functions according to the problem setup. We assume sparse models will be sufficient to identify the system and do this through sparsification (LASSO or thresholding weights) in order to decide which structure best describes the system's behavior using predictors. Given a set of time-series measurements of a dynamical system's state variables ($\mathbf{X}_{(m \times n)}$) we construct the following: -Library of Candidate Functions: $\Theta(\mathbf{X}) = [\mathbf{1} \quad \mathbf{X} \quad \mathbf{X}^2 \quad \mathbf{X}^3 \quad \sin(\mathbf{X}) \quad \cos(\mathbf{X}) \quad ...]$
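A minimal sketch of what such a library looks like as code is shown below; it only builds element-wise polynomial and trigonometric columns and is merely a stand-in for the `PolynomialLibrary` utility used in the full ngc-learn example later in this lesson (a complete polynomial library would also include cross terms such as $xy$ and $xz$):

```python
import numpy as np

def build_library(X):
    """Assemble a candidate-function library Theta(X) of shape (m, p) from X of shape (m, n)."""
    blocks = [np.ones((X.shape[0], 1)),   ## constant term
              X, X**2, X**3,              ## element-wise polynomial terms
              np.sin(X), np.cos(X)]       ## trigonometric terms
    return np.concatenate(blocks, axis=1)
```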

+ ### 2.B: Compute State Derivatives → $\mathbf{Ẋ}_{(m \times n)}$ Given a set of time-series measurements of a dynamical system's state variables $\mathbf{X}_{(m \times n)}$, we next construct the derivative matrix: $\mathbf{Ẋ}_{(m \times n)}$ (computed numerically). In this step, using the dataset collected in phase 1, we compute the derivatives of each state variable with respect to time. In this example, we compute $ẋ$, $ẏ$, and $ż$ in order to capture how the system evolves over time.
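Assuming the snapshots in $\mathbf{X}$ were sampled at a fixed interval, a simple finite-difference estimate of the derivative matrix (one of several possible numerical schemes; the values below are placeholders) is:

```python
import numpy as np

X = np.cumsum(np.random.randn(1000, 3), axis=0)   ## stand-in (m, n) state matrix
dt = 0.01                                          ## assumed sampling interval of the snapshots
dX = np.gradient(X, dt, axis=0)                    ## numerically estimated Ẋ, same shape as X
```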

- ## Phase 3: Solving Sparse Regression Problem → $\mathbf{W_s}_{(p \times n)}$ Solving the sparse regression (SR) problem that results from the phases/steps above can be done using various methods such as Lasso, STLSQ, Elastic Net, as well as many other schemes. Here, we describe the STLSQ approach to solve the SR problem according to the SINDy process.

+ ### Solving Sparse Regression by Sequential Thresholding Least Squares (STLSQ)

+ **Figure 1:** **The flow of three phases in SINDy.** **Phase-1)** Data collection: capturing system's states that are changing in time and making the state vector. **Phase-2A)** Library formation: manually creating the library of candidate predictors that could appear in the model. **Phase-2B)** Derivative computation: using the data collected in phase 1 and computing its derivative with respect to time. **Phase-3)** Solving the sparse regression problem via STLSQ.

------------------------------------------------------------------------------------------
- ### Sequential Thresholding Least Square (STLSQ)
#### 3.A: Least Squares method (LSQ) → $\mathbf{W}$ This step entails finding the library coefficients by solving the regression problem $\mathbf{Ẋ} = \mathbf{\Theta}\mathbf{W}$ analytically via $\mathbf{W} = (\mathbf{\Theta}^T \mathbf{\Theta})^{-1} \mathbf{\Theta}^T \mathbf{Ẋ}$

+ #### 3.B: Thresholding → $\mathbf{W_s}$ This step entails sparsifying $\mathbf{W}$ by keeping only some of the terms within $\mathbf{W}$, particularly those that correspond to the effective terms in the library.

+ #### 3.C: Masking → $\mathbf{\Theta_s}$ This step sparsifies $\mathbf{\Theta}$ by keeping only the columns corresponding to the terms in $\mathbf{W}$ that remain (from the prior step).

+ #### 3.D: Repeat A → B → C until convergence We continue to solve LSQ with the sparse matrices $\mathbf{\Theta_s}$ and $\mathbf{W_s}$ to find a new $\mathbf{W}$, repeating steps B and C until convergence.
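A compact sketch of the whole 3.A-3.D loop, written with plain NumPy least squares rather than ngc-learn's utilities (the threshold and iteration count below are illustrative choices, not values taken from this exhibit):

```python
import numpy as np

def stlsq(Theta, dX, threshold=0.1, max_iters=10):
    """Sequentially thresholded least squares: Theta (m x p), dX (m x n) -> sparse W (p x n)."""
    W, *_ = np.linalg.lstsq(Theta, dX, rcond=None)       ## 3.A: initial least-squares solve
    for _ in range(max_iters):
        small = np.abs(W) < threshold                    ## 3.B: find negligible coefficients
        W[small] = 0.                                    ##      and threshold them away
        for j in range(dX.shape[1]):                     ## 3.C/3.D: re-fit each state dimension
            keep = ~small[:, j]                          ## using only the surviving library columns
            if np.any(keep):
                W[keep, j], *_ = np.linalg.lstsq(Theta[:, keep], dX[:, j], rcond=None)
    return W
```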

## Code: Simulating SINDy We finally present ngc-learn code below for using and simulating the SINDy process to identify several dynamical systems. ```python import numpy as np import jax.numpy as jnp from ngclearn.utils.feature_dictionaries.polynomialLibrary import PolynomialLibrary @@ -335,31 +224,14 @@ for dim in range(dX.shape[1]): coeff = jnp.where(jnp.abs(coef) >= threshold, coef, 0.) print(f"coefficients for dimension {dim+1}: \n", coeff.T) ``` ## Results: System Identification Running the above code should produce results similar to the findings we present next. +which should produce the following results:
- ## Oscillator +## Oscillator True model's equation \ $\mathbf{ẋ} = \mu_1\mathbf{x} + \sigma \mathbf{xy}$ \ @@ -380,19 +252,13 @@ $\mathbf{ż} = \mu_2\mathbf{z} - (\omega + \alpha \mathbf{y} + \beta \mathbf{z}) [ 0. -0.009 0. -2.99 4.99 1.99 0. 0. 0. 0.]] ``` - -

+which should produce the following results: + + + + - ## Lorenz +## Lorenz True model's equation \ $\mathbf{ẋ} = 10(\mathbf{y} - \mathbf{x})$ \ @@ -400,7 +266,6 @@ $\mathbf{ẏ} = \mathbf{x}(28 - \mathbf{z}) - \mathbf{y}$ \ $\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$ - ```python --- SINDy results ---- ẋ = 9.969 𝑦 -9.966 𝑥 @@ -413,19 +278,12 @@ $\mathbf{ż} = \mathbf{xy} - \frac{8}{3}~\mathbf{z}$ [-2.656 0. 0. 0. 0. 0. 0. 0.996 0.]] ``` - -

- - ## Linear-2D +which should produce the following results: + + + + +## Linear-2D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x} + 2.0\mathbf{y}$ \ @@ -440,21 +298,14 @@ $\mathbf{ẏ} = -2.0\mathbf{x} - 0.1\mathbf{y}$ [[ 1.999 0. -0.100 0. 0.] [-0.099 0. -1.999 0. 0.]] ``` + +which should produce the following results: + + - -

- ## Linear-3D +## Linear-3D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x} + 2\mathbf{y}$ \ @@ -473,22 +324,13 @@ $\mathbf{ż} = -0.3\mathbf{z}$ [-0.299 0. 0. 0. 0. 0. 0. 0. 0.]] ``` - -

+ + + - ## Cubic-2D +## Cubic-2D True model's equation \ $\mathbf{ẋ} = -0.1\mathbf{x}^3 + 2.0\mathbf{y}^3$ \ @@ -504,16 +346,10 @@ $\mathbf{ẏ} = -2.0\mathbf{x}^3 - 0.1\mathbf{y}^3$ [ 0. 0. -0.099 0. 0. 0. 0. 0. -1.99]] ``` - -

+which should produce the following results: + + + ## References diff --git a/docs/source/ngclearn.components.neurons.graded.rst b/docs/source/ngclearn.components.neurons.graded.rst index ec39a795..d62a5b7e 100644 --- a/docs/source/ngclearn.components.neurons.graded.rst +++ b/docs/source/ngclearn.components.neurons.graded.rst @@ -36,14 +36,6 @@ ngclearn.components.neurons.graded.rateCell module :undoc-members: :show-inheritance: -ngclearn.components.neurons.graded.rateCellOld module ------------------------------------------------------ - -.. automodule:: ngclearn.components.neurons.graded.rateCellOld - :members: - :undoc-members: - :show-inheritance: - ngclearn.components.neurons.graded.rewardErrorCell module --------------------------------------------------------- diff --git a/docs/source/ngclearn.components.synapses.hebbian.rst b/docs/source/ngclearn.components.synapses.hebbian.rst index d892d868..778e59ec 100644 --- a/docs/source/ngclearn.components.synapses.hebbian.rst +++ b/docs/source/ngclearn.components.synapses.hebbian.rst @@ -36,14 +36,6 @@ ngclearn.components.synapses.hebbian.hebbianSynapse module :undoc-members: :show-inheritance: -ngclearn.components.synapses.hebbian.hebbianSynapseOld module -------------------------------------------------------------- - -.. automodule:: ngclearn.components.synapses.hebbian.hebbianSynapseOld - :members: - :undoc-members: - :show-inheritance: - ngclearn.components.synapses.hebbian.traceSTDPSynapse module ------------------------------------------------------------ diff --git a/docs/source/ngclearn.components.synapses.rst b/docs/source/ngclearn.components.synapses.rst index c1ccc6c4..43791098 100644 --- a/docs/source/ngclearn.components.synapses.rst +++ b/docs/source/ngclearn.components.synapses.rst @@ -23,6 +23,14 @@ ngclearn.components.synapses.STPDenseSynapse module :undoc-members: :show-inheritance: +ngclearn.components.synapses.alphaSynapse module +------------------------------------------------ + +.. automodule:: ngclearn.components.synapses.alphaSynapse + :members: + :undoc-members: + :show-inheritance: + ngclearn.components.synapses.denseSynapse module ------------------------------------------------ @@ -31,6 +39,22 @@ ngclearn.components.synapses.denseSynapse module :undoc-members: :show-inheritance: +ngclearn.components.synapses.doubleExpSynapse module +---------------------------------------------------- + +.. automodule:: ngclearn.components.synapses.doubleExpSynapse + :members: + :undoc-members: + :show-inheritance: + +ngclearn.components.synapses.exponentialSynapse module +------------------------------------------------------ + +.. automodule:: ngclearn.components.synapses.exponentialSynapse + :members: + :undoc-members: + :show-inheritance: + ngclearn.components.synapses.staticSynapse module ------------------------------------------------- diff --git a/docs/tutorials/neurocog/dynamic_synapses.md b/docs/tutorials/neurocog/dynamic_synapses.md new file mode 100644 index 00000000..bc708264 --- /dev/null +++ b/docs/tutorials/neurocog/dynamic_synapses.md @@ -0,0 +1,421 @@ +# Lecture 4A: Dynamic Synapses and Conductance + +In this lesson, we will study dynamic synapses, or synaptic cable components in +ngc-learn that evolve on fast time-scales in response to their pre-synaptic inputs. +These types of chemical synapse components are useful for modeling time-varying +conductance which ultimately drives eletrical current input into neuronal units +(such as spiking cells). 
Here, we will learn how to build three important types of dynamic synapses in +ngc-learn -- the exponential, the alpha, and the double-exponential synapse -- and visualize +the time-course of their resulting conductances. In addition, we will then +construct and study a small neuronal circuit involving a leaky integrator that +is driven by exponential synapses relaying pulses from an excitatory and an +inhibitory population of Poisson input encoding cells. + +## Synaptic Conductance Modeling + +Synapse models are typically used to model the post-synaptic response produced by +action potentials (or pulses) at a pre-synaptic terminal. Assuming an electrical +response (as opposed to a chemical one, e.g., an influx of calcium), such models seek +to emulate the time-course of what is known as post-synaptic receptor conductance. Note +that these dynamic synapse models will end up being a bit more sophisticated than the strength +value matrices we might initially employ (as in synapse components such as the +[DenseSynapse](ngclearn.components.synapses.denseSynapse)). + +Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse), +the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place +the code you will write within. +For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior. +This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics): + +```python +from jax import numpy as jnp, random, jit +from ngcsimlib.context import Context +from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoupleExpSynapse + +from ngcsimlib.compilers.process import Process +from ngcsimlib.context import Context +import ngclearn.utils.weight_distribution as dist + + +dkey = random.PRNGKey(1234) ## creating seeding keys for synapses +dkey, *subkeys = random.split(dkey, 6) +dt = 0.1 # ms ## integration time step +T = 8.
# ms ## total duration time + +## ---- build a dual-synapse system ---- +with Context("dual_syn_system") as ctx: + Wexp = ExponentialSynapse( ## exponential dynamic synapse + name="Wexp", shape=(1, 1), tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1., + weight_init=dist.constant(value=1.), key=subkeys[0] + ) + Walpha = AlphaSynapse( ## alpha dynamic synapse + name="Walpha", shape=(1, 1), tau_decay=1., g_syn_bar=1., syn_rest=0., resist_scale=1., + weight_init=dist.constant(value=1.), key=subkeys[0] + ) + Wexp2 = DoupleExpSynapse( + name="Wexp2", shape=(1, 1), tau_rise=1., tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1., + weight_init=dist.constant(value=1.), key=subkeys[0] + ) + + ## set up basic simulation process calls + advance_process = (Process("advance_proc") + >> Wexp.advance_state + >> Walpha.advance_state + >> Wexp2.advance_state) + ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + + reset_process = (Process("reset_proc") + >> Wexp.reset + >> Walpha.reset + >> Wexp2.reset) + ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") +``` + +where we notice in the above we have instantiated three different kinds of chemical synapse components +that we will run side-by-side in order to extract their produced conductance values in response to +the exact same input stream. For both the exponential and the alpha synapse, there are at least three important hyper-parameters to configure: +1. `tau_decay` ($\tau_{\text{decay}}$): the synaptic conductance decay time constant (for the double-exponential synapse, we also have `tau_rise`); +2. `g_syn_bar` ($\bar{g}_{\text{syn}}$): the maximal conductance value produced by each pulse transmitted + across this synapse; and, +3. `syn_rest` ($E_{rest}$): the (post-synaptic) reversal potential for this synapse -- note that this value + determines the direction of current flow through the synapse, yielding a synapse with an + excitatory nature for non-negative values of `syn_rest` or a synapse with an inhibitory + nature for negative values of `syn_rest`. + +The flow of electrical current from a pre-synaptic neuron to a post-synaptic one is often modeled under the assumption that pre-synaptic pulses result in impermanent (transient; lasting for a short period of time) changes in the conductance of a post-synaptic neuron. As a result, the resulting conductance dynamics $g_{\text{syn}}(t)$ -- or the effect (conductance changes in the post-synaptic membrane) of a transmitter binding to and opening post-synaptic receptors -- of each of the synapses that you have built above can be simulated in ngc-learn according to one or more ordinary differential equations (ODEs), which themselves iteratively model different waveform equations of the time-course of synaptic conductance. +For the exponential synapse, the dynamics adhere to the following ODE: + +$$ +\frac{\partial g_{\text{syn}}(t)}{\partial t} = -g_{\text{syn}}(t)/\tau_{\text{syn}} + \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k}) +$$ + +where the conductance (for a post-synaptic unit) output of this synapse is driven by a sum over all of its incoming +pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an expoential kernel (i.e., a low-pass filter). 
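As a quick standalone illustration of the exponential kernel just described (separate from the ngc-learn component; the constants below are made up), the ODE can be stepped directly with the Euler method:

```python
import numpy as np

dt, T = 0.1, 8.            ## ms
tau_syn, g_bar = 3., 1.    ## decay time constant and maximal conductance (illustrative values)
g = 0.
g_trace = []
for step in range(int(T / dt) + 1):
    spike = 1. if step == int(1. / dt) else 0.      ## a single pre-synaptic pulse at 1 ms
    g = g + dt * (-g / tau_syn) + g_bar * spike     ## dg/dt = -g/tau_syn + g_bar * sum_k delta(t - t_k)
    g_trace.append(g)
```

Each pulse produces an instantaneous jump of size `g_bar` followed by an exponential decay back toward zero, which is the characteristic shape you will see plotted for the exponential synapse shortly.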
+On the other hand, for the alpha synapse, the dynamics adhere to the following coupled set of ODEs: + +$$ +\frac{\partial h_{\text{syn}}(t)}{\partial t} &= -h_{\text{syn}}(t)/\tau_{\text{syn}} + \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k}) \\ +\frac{\partial g_{\text{syn}}(t)}{\partial t} &= -g_{\text{syn}}(t)/\tau_{\text{syn}} + h_{\text{syn}}(t)/\tau_{\text{syn}} +$$ + +where $h_{\text{syn}}(t)$ is an intermediate variable that operates in service of driving the conductance variable $g_{\text{syn}}(t)$ itself. +The double-exponential (or difference of exponentials) synapse model looks similar to the alpha synapse except that the +rise and fall/decay of its conductance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$, +as follows: + +$$ +\frac{\partial h_{\text{syn}}(t)}{\partial t} &= -h_{\text{syn}}(t)/\tau_{\text{rise}} + \big(\frac{1}{\tau_{\text{rise}}} - \frac{1}{\tau_{\text{decay}}} \big) \bar{g}_{\text{syn}} \sum_{k} \delta(t - t_{k}) \\ +\frac{\partial g_{\text{syn}}(t)}{\partial t} &= -g_{\text{syn}}(t)/\tau_{\text{decay}} + h_{\text{syn}}(t) . +$$ + +Finally, we seek to model the electrical current that results from some amount of neurotransmitter in previous time steps. +Thus, for any of these three synapses, the changes in conductance are finally converted (via Ohm's law) to electrical current to produce the final derived variable $j_{\text{syn}}(t)$: + +$$ +j_{\text{syn}}(t) = g_{\text{syn}}(t) (v(t) - E_{\text{rev}}) +$$ + +where $E_{\text{rev}}$ is the post-synaptic reverse potential of the ion channels that mediate the synaptic current; this is typically set to $E_{\text{rev}} = 0$ (millivolts; mV) for the case of excitatory changes and $E_{\text{rev}} = -75$ (mV) for the case of inhibitory changes. $v(t)$ is the voltage/membrane potential of the post-synaptic cell that the synaptic cable wires to, meaning that the conductance models above are voltage-dependent (in ngc-learn, if you want voltage-independent conductance, then set `syn_rest = None`). + + +### Examining the Conductances of Dynamic Synapses + +We can track and visualize the conductance outputs of our different dynamic synapses by running a stream of controlled pre-synaptic pulses. Specifically, we will observe the output behavior of each in response to a sparse stream, eight milliseconds in length, where only a single spike is emitted at one millisecond. +To create the simulation of a single input pulse stream, you can write the following code: + +```python +time_span = [] +g = [] +ga = [] +gexp2 = [] +ctx.reset() +Tsteps = int(T/dt) + 1 +for t in range(Tsteps): + s_t = jnp.zeros((1, 1)) + if t * dt == 1.: ## pulse at 1 ms + s_t = jnp.ones((1, 1)) + Wexp.inputs.set(s_t) + Walpha.inputs.set(s_t) + Wexp.v.set(Wexp.v.value * 0) + Wexp2.inputs.set(s_t) + Walpha.v.set(Walpha.v.value * 0) + Wexp2.v.set(Wexp2.v.value * 0) + ctx.run(t=t * dt, dt=dt) + + print(f"\r g = {Wexp.g_syn.value} ga = {Walpha.g_syn.value} gexp2 = {Wexp2.g_syn.value}", end="") + g.append(Wexp.g_syn.value) + ga.append(Walpha.g_syn.value) + gexp2.append(Wexp2.g_syn.value) ## track the double-exponential synapse's conductance as well + time_span.append(t) #* dt) +print() +g = jnp.squeeze(jnp.concatenate(g, axis=1)) +g = g/jnp.amax(g) +ga = jnp.squeeze(jnp.concatenate(ga, axis=1)) +ga = ga/jnp.amax(ga) +gexp2 = jnp.squeeze(jnp.concatenate(gexp2, axis=1)) +gexp2 = gexp2/jnp.amax(gexp2) +``` + +Note that we further normalize the conductance trajectories of all synapses to lie within the range of $[0, 1]$, +primarily for visualization purposes.
+Finally, to visualize the conductance time-course of the synapses, you can write the following: + +```python +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +time_ticks = [] +time_labs = [] +for t in range(Tsteps): + if t % 10 == 0: + time_ticks.append(t) + time_labs.append(f"{t * dt:.1f}") + +## ---- plot the exponential synapse conductance time-course ---- +fig, ax = plt.subplots() + +gvals = ax.plot(time_span, g, '-', color='tab:red') +#plt.xticks(time_span, time_labs) +ax.set_xticks(time_ticks, time_labs) +ax.set(xlabel='Time (ms)', ylabel='Conductance', + title='Exponential Synapse Conductance Time-Course') +ax.grid(which="major") +fig.savefig("exp_syn.jpg") +plt.close() + +## ---- plot the alpha synapse conductance time-course ---- +fig, ax = plt.subplots() + +gvals = ax.plot(time_span, ga, '-', color='tab:blue') +#plt.xticks(time_span, time_labs) +ax.set_xticks(time_ticks, time_labs) +ax.set(xlabel='Time (ms)', ylabel='Conductance', + title='Alpha Synapse Conductance Time-Course') +ax.grid(which="major") +fig.savefig("alpha_syn.jpg") +plt.close() + +## ---- plot the double-exponential synapse conductance time-course ---- +fig, ax = plt.subplots() + +gvals = ax.plot(time_span, gexp2, '-', color='tab:blue') +#plt.xticks(time_span, time_labs) +ax.set_xticks(time_ticks, time_labs) +#plt.vlines(x=[0, 10, 20, 30, 40, 50, 60, 70, 80], ymin=-0.2, ymax=1.2, colors='gray', linestyles='dashed') #, label='Vertical Lines') +ax.set(xlabel='Time (ms)', ylabel='Conductance', + title='Double-Exponential Synapse Conductance Time-Course') +ax.grid(which="major") +fig.savefig("exp2_syn.jpg") +plt.close() +``` + +which should produce and save three plots to disk. You can then compare and contrast the plots of the +exponential, alpha, and double-exponential conductance trajectories: + +```{eval-rst} +.. table:: + :align: center + + +--------------------------------------------------------+----------------------------------------------------------+---------------------------------------------------------+ + | .. image:: ../../images/tutorials/neurocog/expsyn.jpg | .. image:: ../../images/tutorials/neurocog/alphasyn.jpg | .. image:: ../../images/tutorials/neurocog/exp2syn.jpg | + | :width: 400px | :width: 400px | :width: 400px | + | :align: center | :align: center | :align: center | + +--------------------------------------------------------+----------------------------------------------------------+---------------------------------------------------------+ +``` + +Note that the alpha synapse (middle figure) would produce a more realistic fit to recorded synaptic currents (as it attempts to model +the rise and fall of current in a less simplified manner) at the cost of extra compute, given it uses two ODEs to +emulate conductance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure). + +## Excitatory-Inhibitory Driven Dynamics + +For this next part of the lesson, create a new Python script named `sim_ei_dynamics.py` for the next portions of code +you will write. +Let's next examine a more interesting use-case of the above dynamic synapses -- modeling excitatory (E) and inhibitory (I) +pressures produced by different groups of pre-synaptic spike trains. This allows us to examine a very common +and often-used conductance model that is paired with spiking cells such as the leaky integrate-and-fire (LIF).
Specifically, +we seek to simulate the following post-synaptic conductance dynamics for a single LIF unit: + +$$ +\tau_{m} \frac{\partial v(t)}{\partial t} = -\big( v(t) - E_{L} \big) - \frac{g_{E}(t)}{g_{L}} \big( v(t) - E_{E} \big) - \frac{g_{I}(t)}{g_{L}} \big( v(t) - E_{I} \big) +$$ + +where $g_{L}$ is the leak conductance value for the post-synaptic LIF, $g_{E}(t)$ is the post-synaptic conductance produced by excitatory pre-synaptic spike trains (with excitatory synaptic reverse potential $E_{E}$), and $g_{I}(t)$ is the post-synaptic conductance produced by inhibitory pre-synaptic spike trains (with inhibitory synaptic reverse potential $E_{I}$). Note that the first term of the above ODE is the normal leak portion of the LIF's standard dynamics (scaled by conductance factor $g_{L}$) and the last two terms of the above ODE can be modeled each separately with a dynamic synapse. To differentiate between excitatory and inhibitory conductance changes, we will just configure a different reverse potential for each to induce either excitatory (i.e., $E_{\text{syn}} = E_{E} = 0$ mV) or inhibitory (i.e., $E_{\text{syn}} = E_{I} = -80$ mV) pressure/drive. + +We will specifically model the excitatory and inhibitory conductance changes using exponential synapses and the input spike trains for each with Poisson encoding cells; in other words, two different groups of Poisson cells will be wired to a single LIF cell via exponential dynamic synapses. The code for doing this is as follows: + +```python +from jax import numpy as jnp, random, jit +from ngcsimlib.context import Context +from ngclearn.components import ExponentialSynapse, PoissonCell, LIFCell +from ngclearn.operations import summation + +from ngcsimlib.compilers.process import Process +from ngcsimlib.context import Context +import ngclearn.utils.weight_distribution as dist + +## create seeding keys +dkey = random.PRNGKey(1234) +dkey, *subkeys = random.split(dkey, 6) + +## simulation properties +dt = 0.1 # ms +T = 1000. # ms ## total duration time + +## post-syn LIF cell properties +tau_m = 10. +g_L = 10. +v_rest = -75. +v_thr = -55. + +## excitatory group properties +exc_freq = 10. # Hz +n_exc = 80 +g_e_bar = 2.4 +tau_syn_exc = 2. +E_rest_exc = 0. + +## inhibitory group properties +inh_freq = 10. # Hz +n_inh = 20 +g_i_bar = 2.4 +tau_syn_inh = 5. +E_rest_inh = -80. 
+ +Tsteps = int(T/dt) + +## ---- build a simple E-I spiking circuit ---- +with Context("ei_snn") as ctx: + pre_exc = PoissonCell("pre_exc", n_units=n_exc, target_freq=exc_freq, key=subkeys[0]) ## pre-syn excitatory group + pre_inh = PoissonCell("pre_inh", n_units=n_inh, target_freq=inh_freq, key=subkeys[1]) ## pre-syn inhibitory group + Wexc = ExponentialSynapse( ## dynamic synapse between excitatory group and LIF + name="Wexc", shape=(n_exc,1), tau_decay=tau_syn_exc, g_syn_bar=g_e_bar, syn_rest=E_rest_exc, resist_scale=1./g_L, + weight_init=dist.constant(value=1.), key=subkeys[2] + ) + Winh = ExponentialSynapse( ## dynamic synapse between inhibitory group and LIF + name="Winh", shape=(n_inh, 1), tau_decay=tau_syn_inh, g_syn_bar=g_i_bar, syn_rest=E_rest_inh, resist_scale=1./g_L, + weight_init=dist.constant(value=1.), key=subkeys[2] + ) + post_exc = LIFCell( ## post-syn LIF cell + "post_exc", n_units=1, tau_m=tau_m, resist_m=1., thr=v_thr, v_rest=v_rest, conduct_leak=1., v_reset=-75., + tau_theta=0., theta_plus=0., refract_time=2., key=subkeys[3] + ) + + Wexc.inputs << pre_exc.outputs + Winh.inputs << pre_inh.outputs + Wexc.v << post_exc.v ## couple voltage to exc synapse + Winh.v << post_exc.v ## couple voltage to inh synapse + post_exc.j << summation(Wexc.i_syn, Winh.i_syn) ## sum together excitatory & inhibitory pressures + + advance_process = (Process("advance_proc") + >> pre_exc.advance_state + >> pre_inh.advance_state + >> Wexc.advance_state + >> Winh.advance_state + >> post_exc.advance_state) + # ctx.wrap_and_add_command(advance_process.pure, name="run") + ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + + reset_process = (Process("reset_proc") + >> pre_exc.reset + >> pre_inh.reset + >> Wexc.reset + >> Winh.reset + >> post_exc.reset) + ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") +``` + +### Examining the Simple Spiking Circuit's Behavior + +To run the above spiking circuit, we then write the next block of code (making sure to track/store the resulting membrane potential and pulse values emitted by the post-synaptic LIF): + +```python +volts = [] +time_span = [] +spikes = [] + +ctx.reset() +pre_exc.inputs.set(jnp.ones((1, n_exc))) +pre_inh.inputs.set(jnp.ones((1, n_inh))) +post_exc.v.set(post_exc.v.value * 0 - 65.) ## initial condition for LIF is -65 mV +volts.append(post_exc.v.value) +time_span.append(0.) +Tsteps = int(T/dt) + 1 +for t in range(1, Tsteps): + ctx.run(t=t * dt, dt=dt) + print(f"\r v {post_exc.v.value}", end="") + volts.append(post_exc.v.value) + spikes.append(post_exc.s.value) + time_span.append(t) #* dt) +print() +volts = jnp.squeeze(jnp.concatenate(volts, axis=1)) +spikes = jnp.squeeze(jnp.concatenate(spikes, axis=1)) +``` + +from which we may then write the following plotting code to visualize the post-synaptic LIF unit's membrane potential time-course along with any spikes it might have produced in response to the pre-synaptic spike trains: + +```python +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +time_ticks = [] +time_labs = [] +time_ticks.append(0) +time_labs.append(f"{0.}") +tdiv = 1000 +for t in range(Tsteps): + if t % tdiv == 0: + time_ticks.append(t) + time_labs.append(f"{t * dt:.0f}") + +fig, ax = plt.subplots() + +volt_vals = ax.plot(time_span, volts, '-.', color='tab:red') +stat = jnp.where(spikes > 0.) +indx = (stat[0] * 1. - 1.).tolist() +v_thr_below = -0.75 +v_thr_above = 2. 
+spk = ax.vlines(x=indx, ymin=v_thr-v_thr_below, ymax=v_thr+v_thr_above, colors='black', ls='-', lw=2) +_v_thr = v_thr +ax.hlines(y=_v_thr, xmin=0., xmax=time_span[-1], colors='blue', ls='-.', lw=2) + +ax.set(xlabel='Time (ms)', ylabel='Voltage', + title='Exponential Synapse LIF') +ax.grid() +fig.savefig("ei_circuit_dynamics.jpg") +``` + +which should produce a figure depicting dynamics similar to the one below. Black tick +marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's +voltage threshold. + + +```{eval-rst} +.. table:: + :align: center + + +--------------------------------------------------------------------+ + | .. image:: ../../images/tutorials/neurocog/ei_circuit_dynamics.jpg | + | :width: 400px | + | :align: center | + +--------------------------------------------------------------------+ +``` + +Notice that the above shows the behavior of the post-synaptic LIF in response to the integration of pulses coming from two Poisson spike trains both at rates of $10$ Hz (since both `exc_freq` and `inh_freq` have been set to ten). Messing with the frequencies of the excitatory and inhibitory pulse trains can lead to sparser or denser post-synaptic spike outputs. For instance, if we increase the frequency of the excitatory train to $15$ Hz (keeping the inhibitory one at $10$ Hz), we get a denser post-synaptic output pulse pattern as in the left figure below. In contrast, if we instead increase the inhibitory frequency to $30$ Hz (keeping the excitatory at $10$ Hz), we obtain a sparser post-synaptic output pulse train as in the right figure below. + + +```{eval-rst} +.. table:: + :align: center + + +-----------------------------------------------------------------------+-----------------------------------------------------------------------+ + | .. image:: ../../images/tutorials/neurocog/ei_circuit_dense_exc.jpg | .. image:: ../../images/tutorials/neurocog/ei_circuit_sparse_inh.jpg | + | :width: 400px | :width: 400px | + | :align: center | :align: center | + +-----------------------------------------------------------------------+-----------------------------------------------------------------------+ +``` + +## References + +[1] Sterratt, David, et al. Principles of computational modelling in neuroscience. Cambridge university +press, 2023.
+[2] Roth, Arnd, and Mark CW van Rossum. "Modeling synapses." Computational modeling methods for neuroscientists 6.139 (2009): 700. diff --git a/docs/tutorials/neurocog/error_cell.md b/docs/tutorials/neurocog/error_cell.md index a8b5e28b..04368d5d 100644 --- a/docs/tutorials/neurocog/error_cell.md +++ b/docs/tutorials/neurocog/error_cell.md @@ -1,5 +1,9 @@ + + # Lecture 3B: Error Cell Models + + Error cells are a particularly useful component cell that offers a simple and fast way of computing mismatch signals, i.e., error values that compare a target value against a predicted value. In predictive coding literature, mismatch @@ -13,8 +17,11 @@ key examples of where error neurons come into play). In this lesson, we will briefly review one of the most commonly used ones -- the [Gaussian error cell](ngclearn.components.neurons.graded.gaussianErrorCell). + ## Calculating Mismatch Values with the Gaussian Error Cell + + The Gaussian error cell, much like most error neurons, is in fact a derived calculation when considering a cost function. Specifically, this error cell component inherits its name from the fact that it is producing output values diff --git a/docs/tutorials/neurocog/hebbian.md b/docs/tutorials/neurocog/hebbian.md index e2668fc6..8e67754c 100644 --- a/docs/tutorials/neurocog/hebbian.md +++ b/docs/tutorials/neurocog/hebbian.md @@ -1,4 +1,4 @@ -# Lecture 4A: Hebbian Synaptic Plasticity +# Lecture 4B: Hebbian Synaptic Plasticity In ngc-learn, synaptic plasticity is a key concept at the forefront of its design in order to promote research into novel ideas and framings of how diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md index 1580ab1c..44e3b0a7 100755 --- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md +++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md @@ -83,9 +83,9 @@ Formally, the core dynamics of the H-H cell can be written out as follows: $$ \tau_v \frac{\partial \mathbf{v}_t}{\partial t} &= \mathbf{j}_t - g_Na * \mathbf{m}^3_t * \mathbf{h}_t * (\mathbf{v}_t - v_Na) - g_K * \mathbf{n}^4_t * (\mathbf{v}_t - v_K) - g_L * (\mathbf{v}_t - v_L) \\ -\frac{\partial \mathbf{n}_t}{\partial t} &= alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - beta_n(\mathbf{v}_t) * \mathbf{n}_t \\ -\frac{\partial \mathbf{m}_t}{\partial t} &= alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - beta_m(\mathbf{v}_t) * \mathbf{m}_t \\ -\frac{\partial \mathbf{h}_t}{\partial t} &= alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - beta_h(\mathbf{v}_t) * \mathbf{h}_t +\frac{\partial \mathbf{n}_t}{\partial t} &= \alpha_n(\mathbf{v}_t) * (1 - \mathbf{n}_t) - \beta_n(\mathbf{v}_t) * \mathbf{n}_t \\ +\frac{\partial \mathbf{m}_t}{\partial t} &= \alpha_m(\mathbf{v}_t) * (1 - \mathbf{m}_t) - \beta_m(\mathbf{v}_t) * \mathbf{m}_t \\ +\frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t $$ where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biopphysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to. 
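To make the role of these generator functions concrete, here is a small sketch for the potassium activation gate $n$ using the classic textbook parameterization of $\alpha_n$ and $\beta_n$ (voltages in mV); these constants are the standard Hodgkin-Huxley values and are not necessarily the defaults used by ngc-learn's component:

```python
import numpy as np

def alpha_n(v):  ## opening rate of the potassium activation gate (textbook H-H form)
    return 0.01 * (v + 55.) / (1. - np.exp(-(v + 55.) / 10.))

def beta_n(v):   ## closing rate of the potassium activation gate
    return 0.125 * np.exp(-(v + 65.) / 80.)

def step_n(n, v, dt=0.01):
    """One Euler step of dn/dt = alpha_n(v) * (1 - n) - beta_n(v) * n."""
    return n + dt * (alpha_n(v) * (1. - n) - beta_n(v) * n)
```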
diff --git a/docs/tutorials/neurocog/index.rst b/docs/tutorials/neurocog/index.rst index 1fb23384..326591c2 100644 --- a/docs/tutorials/neurocog/index.rst +++ b/docs/tutorials/neurocog/index.rst @@ -58,8 +58,9 @@ work towards more advanced concepts. .. toctree:: :maxdepth: 1 - :caption: Forms of Plasticity + :caption: Synapses and Forms of Plasticity + dynamic_synapses hebbian stdp mod_stdp diff --git a/docs/tutorials/neurocog/mod_stdp.md b/docs/tutorials/neurocog/mod_stdp.md index fbb0b7db..3a76de37 100755 --- a/docs/tutorials/neurocog/mod_stdp.md +++ b/docs/tutorials/neurocog/mod_stdp.md @@ -1,4 +1,4 @@ -# Lecture 4C: Reward-Modulated Spike-Timing-Dependent Plasticity +# Lecture 4D: Reward-Modulated Spike-Timing-Dependent Plasticity In this lesson, we will build on the notions of spike-timing-dependent plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct @@ -247,9 +247,8 @@ for i in range(T_max): ``` which will run all three models simultaneously for `200` simulated milliseconds -and collect statistics of interest. We may then finally make several -plots of what happens under each STDP mode. First, we will plot the resulting -synaptic magnitude over time, like so: +and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode +(reproducing some key results in [1]. First, we will plot the resulting synaptic magnitude over time, like so: ```python import matplotlib.pyplot as plt @@ -374,3 +373,8 @@ modulated STDP updates will only occur when the signal is non-zero; this is the advantage that MSTDP-ET offers over MSTDP as the synaptic change dynamics persist (yet decay) in between reward presentation times and thus MSTDP-ET will be more effective in cases when the reward signal is delayed. + +## References + +[1] Florian, Răzvan V. "Reinforcement learning through modulation of spike-timing-dependent synaptic plasticity." +Neural computation 19.6 (2007): 1468-1502. diff --git a/docs/tutorials/neurocog/short_term_plasticity.md b/docs/tutorials/neurocog/short_term_plasticity.md index 9bd74ed5..b225f3c5 100755 --- a/docs/tutorials/neurocog/short_term_plasticity.md +++ b/docs/tutorials/neurocog/short_term_plasticity.md @@ -1,4 +1,4 @@ -# Lecture 4D: Short-Term Plasticity +# Lecture 4E: Short-Term Plasticity In this lesson, we will study how short-term plasticity (STP) [1] dynamics -- where synaptic efficacy is cast in terms of the history of presynaptic activity -- @@ -19,7 +19,7 @@ synapse that evolves according to STP. We will first write our simulation of this dynamic synapse from the perspective of STF-dominated dynamics, plotting out the results under two different Poisson spike trains with different spiking frequencies. Then, we will modify our simulation -to emulate dynamics from a STD-dominated perspective. +to emulate dynamics from an STD-dominated perspective. ### Starting with Facilitation-Dominated Dynamics @@ -39,10 +39,10 @@ some mixture of the two. Ultimately, the above means that, in the context of spiking cells, when a pre-synaptic neuron emits a pulse, this act will affect the relative magnitude -of the synapse's efficacy; -in some cases, this will result in an increase (facilitation) and, in others, -this will result in a decrease (depression) that lasts over a short period -of time (several hundreds to thousands of milliseconds in many instances). +of the synapse's efficacy. 
In some cases, this will result in an increase +(facilitation) and, in others, this will result in a decrease (depression) +that lasts over a short period of time (several hundreds to thousands of +milliseconds in many instances). As a result of considering synapses to have a dynamic nature to them, both over short and long time-scales, plasticity can now be thought of as a stimulus and resource-dependent quantity, reflecting an important biophysical aspect that @@ -87,13 +87,16 @@ tau_d = 50. # ms plot_fname = "{}Hz_stp_{}.jpg".format(firing_rate_e, tag) with Context("Model") as model: - W = STPDenseSynapse("W", shape=(1, 1), weight_init=dist.constant(value=2.5), - resources_init=dist.constant(value=Rval), - tau_f=tau_f, tau_d=tau_d, key=subkeys[0]) + W = STPDenseSynapse( + "W", shape=(1, 1), weight_init=dist.constant(value=2.5), + resources_init=dist.constant(value=Rval), tau_f=tau_f, tau_d=tau_d, + key=subkeys[0] + ) z0 = PoissonCell("z0", n_units=1, target_freq=firing_rate_e, key=subkeys[0]) - z1 = LIFCell("z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, - v_rest=-60., v_reset=-70., thr=-50., - tau_theta=0., theta_plus=0., refract_time=0.) + z1 = LIFCell( + "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60., + v_reset=-70., thr=-50., tau_theta=0., theta_plus=0., refract_time=0. + ) W.inputs << z0.outputs ## z0 -> W z1.j << W.outputs ## W -> z1 @@ -156,7 +159,7 @@ resources ready for the dynamic synapse's use. ### Simulating and Visualizing STF -Now that we understand the basics of how an ngc-learn STP works, we can next +Now that we understand the basics of how an ngc-learn STP synapse works, we can next try it out on a simple pre-synaptic Poisson spike train. Writing out the simulated input Poisson spike train and our STP model's processing of this data can be done as follows: diff --git a/docs/tutorials/neurocog/stdp.md b/docs/tutorials/neurocog/stdp.md index 16423c8d..b8e889a0 100755 --- a/docs/tutorials/neurocog/stdp.md +++ b/docs/tutorials/neurocog/stdp.md @@ -1,4 +1,4 @@ -# Lecture 4B: Spike-Timing-Dependent Plasticity +# Lecture 4C: Spike-Timing-Dependent Plasticity In the context of spiking neuronal networks, one of the most important forms of adaptation that is often simulated is that of spike-timing-dependent diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 38a829f8..af856c1a 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -38,6 +38,9 @@ from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse from .synapses.hebbian.BCMSynapse import BCMSynapse from .synapses.STPDenseSynapse import STPDenseSynapse +from .synapses.exponentialSynapse import ExponentialSynapse +from .synapses.doubleExpSynapse import DoupleExpSynapse +from .synapses.alphaSynapse import AlphaSynapse ## point to convolutional component types from .synapses.convolution.convSynapse import ConvSynapse diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 114d3785..29b5f267 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -65,6 +65,12 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): self.modulator = Compartment(restVals + 1.0) # to be set/consumed self.mask = Compartment(restVals + 1.0) + @staticmethod + def eval_log_density(target, mu, Sigma): + _dmu = (target - mu) + log_density = -jnp.sum(jnp.square(_dmu)) * 
(0.5 / Sigma) + return log_density + @transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"]) @staticmethod def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output @@ -79,6 +85,7 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e dtarget = -dmu # reverse of e dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + #L = GaussianErrorCell.eval_log_density(target, mu, Sigma) dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... dtarget = dtarget * modulator * mask diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 20c9bc98..e55ce8ff 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -145,6 +145,8 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell act_fx: string name of activation function/nonlinearity to use + output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) + integration_type: type of integration to use for this cell's dynamics; current supported forms include "euler" (Euler/RK-1 integration) and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") @@ -157,12 +159,13 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell """ # Define Functions - def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", - threshold=("none", 0.), integration_type="euler", - batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): + def __init__( + self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), + integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): super().__init__(name, **kwargs) ## membrane parameter setup (affects ODE integration) + self.output_scale = output_scale self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode self.is_stateful = is_stateful if isinstance(tau_m, float): @@ -211,8 +214,9 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit @transition(output_compartments=["j", "j_td", "z", "zF"]) @staticmethod - def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, - resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z): + def advance_state( + dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful, + output_scale, j, j_td, z): #if tau_m > 0.: if is_stateful: ### run a step of integration over neuronal dynamics @@ -231,12 +235,12 @@ def advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, elif thresholdType == "cauchy_threshold": tmp_z = threshold_cauchy(tmp_z, thr_lmbda) z = tmp_z ## pre-activation function value(s) - zF = fx(z) ## post-activation function value(s) + zF = fx(z) * output_scale ## post-activation function value(s) else: ## run in "stateless" mode (when no membrane time constant provided) j_total = j + j_td z = _run_cell_stateless(j_total) - zF = fx(z) + zF = fx(z) * output_scale return j, j_td, z, zF @transition(output_compartments=["j", "j_td", "z", "zF"]) diff --git a/ngclearn/components/neurons/graded/rateCellOld.py b/ngclearn/components/neurons/graded/rateCellOld.py deleted file mode 100644 index 6962810c..00000000 --- a/ngclearn/components/neurons/graded/rateCellOld.py +++ 
/dev/null @@ -1,350 +0,0 @@ -from jax import numpy as jnp, random, jit -from functools import partial -from ngclearn.utils import tensorstats -from ngclearn import resolver, Component, Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils.model_utils import create_function, threshold_soft, \ - threshold_cauchy -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 - -def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): - z_leak = z # * 2 ## Default: assume Gaussian - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.sign(z) ## d/dx of Laplace is signum - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): - z_leak = (z * 2)/(1. + jnp.square(z)) - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): - z_leak = jnp.exp(-jnp.square(z)) * z * 2 - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - - -def _dfz_gaussian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_laplacian(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_laplacian(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_cauchy(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -def _dfz_exp(t, z, params): ## diff-eq dynamics wrapper - j, j_td, tau_m, leak_gamma = params - dz_dt = _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma) - return dz_dt - -@jit -def _modulate(j, dfx_val): - """ - Apply a signal modulator to j (typically of the form of a derivative/dampening function) - - Args: - j: current/stimulus value to modulate - - dfx_val: modulator signal - - Returns: - modulated j value - """ - return j * dfx_val - -def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): - """ - Runs leaky rate-coded state dynamics one step in time. - - Args: - dt: integration time constant - - j: input (bottom-up) electrical/stimulus current - - j_td: modulatory (top-down) electrical/stimulus pressure - - z: current value of membrane/state - - tau_m: membrane/state time constant - - leak_gamma: strength of leak to apply to membrane/state - - integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) - - priorType: scale-shift prior distribution to impose over neural dynamics - - Returns: - New value of membrane/state for next time step - """ - _dfz = { - 0: _dfz_gaussian, - 1: _dfz_laplacian, - 2: _dfz_cauchy, - 3: _dfz_exp - }.get(priorType, _dfz_gaussian) - if integType == 1: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk2(0., z, _dfz, dt, params) - elif integType == 2: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_rk4(0., z, _dfz, dt, params) - else: - params = (j, j_td, tau_m, leak_gamma) - _, _z = step_euler(0., z, _dfz, dt, params) - return _z - -@jit -def _run_cell_stateless(j): - """ - A simplification of running a stateless set of dynamics over j (an identity - functional form of dynamics). 
- - Args: - j: stimulus to do nothing to - - Returns: - the stimulus - """ - return j + 0 - -class RateCell(JaxComponent): ## Rate-coded/real-valued cell - """ - A non-spiking cell driven by the gradient dynamics of neural generative - coding-driven predictive processing. - - The specific differential equation that characterizes this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dz/dt = lambda * prior(z) + (j + j_td) - | where j is the set of general incoming input signals (e.g., message-passed signals) - | and j_td is taken to be the set of top-down pressure signals - - | --- Cell Input Compartments: --- - | j - input pressure (takes in external signals) - | j_td - input/top-down pressure input (takes in external signals) - | --- Cell State Compartments --- - | z - rate activity - | --- Cell Output Compartments: --- - | zF - post-activation function activity, i.e., fx(z) - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - tau_m: membrane/state time constant (milliseconds) - - prior: a kernel for specifying the type of centered scale-shift distribution - to impose over neuronal dynamics, applied to each neuron or - dimension within this component (Default: ("gaussian", 0)); this is - a tuple with 1st element containing a string name of the distribution - one wants to use while the second value is a `leak rate` scalar - that controls the influence/weighting that this distribution - has on the dynamics; for example, ("laplacian, 0.001") means that a - centered laplacian distribution scaled by `0.001` will be injected - into this cell's dynamics ODE each step of simulated time - - :Note: supported scale-shift distributions include "laplacian", - "cauchy", "exp", and "gaussian" - - act_fx: string name of activation function/nonlinearity to use - - integration_type: type of integration to use for this cell's dynamics; - current supported forms include "euler" (Euler/RK-1 integration) - and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") - - :Note: setting the integration type to the midpoint method will - increase the accuray of the estimate of the cell's evolution - at an increase in computational cost (and simulation time) - - resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) - """ - - # Define Functions - def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", - threshold=("none", 0.), integration_type="euler", - batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): - super().__init__(name, **kwargs) - - ## membrane parameter setup (affects ODE integration) - self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode - self.is_stateful = is_stateful - if isinstance(tau_m, float): - if tau_m <= 0: ## trigger stateless mode - self.is_stateful = False - priorType, leakRate = prior - self.priorType = { - "gaussian": 0, - "laplacian": 1, - "cauchy": 2, - "exp": 3 - }.get(priorType, 0) ## type of scale-shift prior to impose over the leak - self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) - thresholdType, thr_lmbda = threshold - self.thresholdType = thresholdType ## type of thresholding function to use - self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics - self.resist_scale = resist_scale ## a "resistance" scaling factor - - ## integration properties - self.integrationType = integration_type - self.intgFlag = 
get_integrator_code(self.integrationType) - - ## Layer size setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - self.shape = shape - self.n_units = n_units - self.batch_size = batch_size - - omega_0 = None - if act_fx == "sine": - omega_0 = kwargs["omega_0"] - self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) - - # compartments (state of the cell & parameters will be updated through stateless calls) - restVals = jnp.zeros(_shape) - self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current - self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity - self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure - self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - - @staticmethod - def _advance_state(dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, - resist_scale, thresholdType, thr_lmbda, is_stateful, j, j_td, z): - #if tau_m > 0.: - if is_stateful: - ### run a step of integration over neuronal dynamics - ## Notes: - ## self.pressure <-- "top-down" expectation / contextual pressure - ## self.current <-- "bottom-up" data-dependent signal - dfx_val = dfx(z) - j = _modulate(j, dfx_val) - j = j * resist_scale - tmp_z = _run_cell(dt, j, j_td, z, - tau_m, leak_gamma=priorLeakRate, - integType=intgFlag, priorType=priorType) - ## apply optional thresholding sub-dynamics - if thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, thr_lmbda) - elif thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, thr_lmbda) - z = tmp_z ## pre-activation function value(s) - zF = fx(z) ## post-activation function value(s) - else: - ## run in "stateless" mode (when no membrane time constant provided) - j_total = j + j_td - z = _run_cell_stateless(j_total) - zF = fx(z) - return j, j_td, z, zF - - @resolver(_advance_state) - def advance_state(self, j, j_td, z, zF): - self.j.set(j) - self.j_td.set(j_td) - self.z.set(z) - self.zF.set(zF) - - @staticmethod - def _reset(batch_size, shape): #n_units - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) - restVals = jnp.zeros(_shape) - return tuple([restVals for _ in range(4)]) - - @resolver(_reset) - def reset(self, j, zF, j_td, z): - self.j.set(j) # electrical current - self.zF.set(zF) # rate-coded output - activity - self.j_td.set(j_td) # top-down electrical current - pressure - self.z.set(z) # rate activity - - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) - priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - else jnp.ones([[self.priorLeakRate]])) - resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - else jnp.ones([[self.resist_scale]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - tau_m=tau_m, priorLeakRate=priorLeakRate, - resist_scale=resist_scale) #, key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - ## constants 
loaded in - self.tau_m = data['tau_m'] - self.priorLeakRate = data['priorLeakRate'] - self.resist_scale = data['resist_scale'] - #if seeded: - # self.key.set(data['key']) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "RateCell - evolves neurons according to rate-coded/" - "continuous dynamics " - } - compartment_props = { - "inputs": - {"j": "External input stimulus value(s)", - "j_td": "External top-down input stimulus value(s); these get " - "multiplied by the derivative of f(x), i.e., df(x)"}, - "states": - {"z": "Update to rate-coded continuous dynamics; value at time t"}, - "outputs": - {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "tau_m": "Cell state/membrane time constant", - "prior": "What kind of kurtotic prior to place over neuronal dynamics?", - "act_fx": "Elementwise activation function to apply over cell state `z`", - "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", - "integration_type": "Type of numerical integration to use for the cell dynamics", - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = RateCell("X", 9, 0.03) - print(X) \ No newline at end of file diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 9566f3c9..bf7f2c14 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -14,19 +14,29 @@ #from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment -#@jit -def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics - mask = (rfr >= refract_T) * 1. # get refractory mask - ## update voltage / membrane potential - dv_dt = (v_rest - v) * v_decay + (j * mask) - dv_dt = dv_dt * (1./tau_m) - return dv_dt +# def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics +# mask = (rfr >= refract_T) * 1. # get refractory mask +# ## update voltage / membrane potential +# dv_dt = (v_rest - v) * v_decay + (j * mask) +# dv_dt = dv_dt * (1./tau_m) +# return dv_dt +# +# def _dfv(t, v, params): ## voltage dynamics wrapper +# j, rfr, tau_m, refract_T, v_rest, v_decay = params +# dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay) +# return dv_dt + + def _dfv(t, v, params): ## voltage dynamics wrapper - j, rfr, tau_m, refract_T, v_rest, v_decay = params - dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay) + j, rfr, tau_m, refract_T, v_rest, g_L = params + mask = (rfr >= refract_T) * 1. 
# get refractory mask + ## update voltage / membrane potential + dv_dt = (v_rest - v) * g_L + (j * mask) + dv_dt = dv_dt * (1. / tau_m) return dv_dt + #@partial(jit, static_argnums=[3, 4]) def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05): ### Runs homeostatic threshold update dynamics one step (via Euler integration). @@ -38,6 +48,7 @@ def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05): #_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha return _v_theta + class LIFCell(JaxComponent): ## leaky integrate-and-fire cell """ A spiking cell based on leaky integrate-and-fire (LIF) neuronal dynamics. @@ -73,14 +84,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell thr: base value for adaptive thresholds that govern short-term plasticity (in milliVolts, or mV; default: -52. mV) - v_rest: membrane resting potential (in mV; default: -65 mV) + v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV) v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, a neuronal cell's membrane potential will be set to this value; (default: -60 mV) - v_decay: decay factor applied to voltage leak (Default: 1.); setting this - to 0 mV recovers pure integrate-and-fire (IF) dynamics + conduct_leak: leak conductance (g_L) value or decay factor applied to voltage leak + (Default: 1.); setting this to 0 mV recovers pure integrate-and-fire (IF) dynamics tau_theta: homeostatic threshold time constant @@ -112,16 +123,15 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell "arctan" (arc-tangent estimator), and "secant_lif" (the LIF-specialized secant estimator) - lower_clamp_voltage: if True, this will ensure voltage never is below - the value of `v_rest` (default: True) + v_min: minimum voltage to clamp dynamics to (Default: None) """ ## batch_size arg? - @deprecate_args(thr_jitter=None) - def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., - v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05, - refract_time=5., one_spike=False, integration_type="euler", - surrogate_type="straight_through", lower_clamp_voltage=True, - **kwargs): + @deprecate_args(thr_jitter=None, v_decay="conduct_leak") + def __init__( + self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7, + theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through", + v_min=None, max_one_spike=False, **kwargs + ): super().__init__(name, **kwargs) ## Integration properties @@ -132,11 +142,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.tau_m = tau_m ## membrane time constant self.resist_m = resist_m ## resistance value self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step - self.lower_clamp_voltage = lower_clamp_voltage ## True ==> ensures voltage is never < v_rest + self.max_one_spike = max_one_spike + self.v_min = v_min ## ensures voltage is never < v_min self.v_rest = v_rest #-65. # mV self.v_reset = v_reset # -60. # -65. # mV (milli-volts) - self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF) + self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF) ## basic asserts to prevent neuronal dynamics breaking... #assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify... assert self.resist_m > 0. 
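To make the renamed constructor arguments concrete, here is a minimal construction sketch that assumes only the keyword names shown in the updated signature above (the numeric values are illustrative):

```python
from ngcsimlib.context import Context
from ngclearn.components import LIFCell

with Context("Model") as model:
    ## `conduct_leak` replaces the deprecated `v_decay` keyword, while `v_min` replaces
    ## `lower_clamp_voltage` (passing None leaves the voltage unclamped from below)
    z1 = LIFCell(
        "z1", n_units=1, tau_m=20., resist_m=1., thr=-52., v_rest=-65., v_reset=-60.,
        conduct_leak=1., refract_time=5., v_min=None, max_one_spike=False
    )
```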
@@ -178,11 +189,11 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., @transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"]) @staticmethod def advance_state( - t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, theta_plus, - one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols + t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus, one_spike, max_one_spike, + v_min, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols ): skey = None ## this is an empty dkey if single_spike mode turned off - if one_spike: + if one_spike and not max_one_spike: key, skey = random.split(key, 2) ## run one integration step for neuronal dynamics j = j * resist_m @@ -191,13 +202,14 @@ def advance_state( _v_thr = thr_theta + thr ## calc present voltage threshold #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask ## update voltage / membrane potential - v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay) + v_params = (j, rfr, tau_m, refract_T, v_rest, g_L) if intgFlag == 1: _, _v = step_rk2(0., v, _dfv, dt, v_params) else: _, _v = step_euler(0., v, _dfv, dt, v_params) ## obtain action potentials/spikes/pulses s = (_v > _v_thr) * 1. + v_prespike = v ## update refractory variables _rfr = (rfr + dt) * (1. - s) ## perform hyper-polarization of neuronal cells @@ -212,6 +224,9 @@ def advance_state( rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32) s = s * (1. - m_switch) + rS * m_switch + if max_one_spike: + rS = nn.one_hot(jnp.argmax(v_prespike, axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike + s = s * rS ## mask out non-max volt spikes ############################################################################ raw_spikes = raw_s v = _v @@ -223,8 +238,8 @@ def advance_state( thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) ## update tols tols = (1. 
- s) * tols + (s * t) - if lower_clamp_voltage: ## ensure voltage never < v_rest - v = jnp.maximum(v, v_rest) + if v_min is not None: ## ensures voltage never < v_rest + v = jnp.maximum(v, v_min) return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate @transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"]) diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index 3889cace..94510e75 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -59,6 +59,8 @@ class VarTrace(JaxComponent): ## low-pass filter a_delta: value to increment a trace by in presence of a spike; note if set to a value <= 0, then a piecewise gated trace will be used instead + P_scale: if `a_delta=0`, then this scales the value that the trace snaps to upon receiving a pulse value + gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1) decay_type: string indicating the decay type to be applied to ODE @@ -69,19 +71,24 @@ class VarTrace(JaxComponent): ## low-pass filter 2) `'exp'` = exponential trace filter, i.e., decay = exp(-dt/tau_tr) * x_tr; 3) `'step'` = step trace, i.e., decay = 0 (a pulse applied upon input value) + n_nearest_spikes: (k) if k > 0, this makes the trace act like a nearest-neighbor trace, + i.e., k = 1 yields the 1-nearest (neighbor) trace (Default: 0) + batch_size: batch size dimension of this cell (Default: 1) """ # Define Functions - def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp", - batch_size=1, **kwargs): + def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp", + n_nearest_spikes=0, batch_size=1, **kwargs): super().__init__(name, **kwargs) ## Trace control coefficients self.tau_tr = tau_tr ## trace time constant self.a_delta = a_delta ## trace increment (if spike occurred) + self.P_scale = P_scale ## trace scale if non-additive trace to be used self.gamma_tr = gamma_tr self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay + self.n_nearest_spikes = n_nearest_spikes ## Layer Size Setup self.batch_size = batch_size @@ -94,17 +101,22 @@ def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp", @transition(output_compartments=["outputs", "trace"]) @staticmethod - def advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace): + def advance_state( + dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace, n_nearest_spikes + ): decayFactor = 0. if "exp" in decay_type: decayFactor = jnp.exp(-dt/tau_tr) elif "lin" in decay_type: decayFactor = (1. - dt/tau_tr) _x_tr = gamma_tr * trace * decayFactor - if a_delta > 0.: - _x_tr = _x_tr + inputs * a_delta + if n_nearest_spikes > 0: ## run k-nearest neighbor trace + _x_tr = _x_tr + inputs * (a_delta - (trace/n_nearest_spikes)) else: - _x_tr = _x_tr * (1. - inputs) + inputs + if a_delta > 0.: ## run full convolution trace + _x_tr = _x_tr + inputs * a_delta + else: ## run simple max-clamped trace + _x_tr = _x_tr * (1. 
- inputs) + inputs * P_scale trace = _x_tr return trace, trace @@ -135,12 +147,15 @@ def help(cls): ## component help function "tau_tr": "Trace/filter time constant", "a_delta": "Increment to apply to trace (if not set to 0); " "otherwise, traces clamp to 1 and then decay", + "P_scale": "Max value to snap trace to if a max-clamp trace is triggered/configured", "decay_type": "Indicator of what type of decay dynamics to use " - "as filter is updated at time t" + "as filter is updated at time t", + "n_nearest_neighbors": "Number of nearest pulses to affect/increment trace (if > 0)" } info = {cls.__name__: properties, "compartments": compartment_props, - "dynamics": "tau_tr * dz/dt ~ -z + inputs", + "dynamics": "tau_tr * dz/dt ~ -z + inputs * a_delta (full convolution trace); " + "tau_tr * dz/dt ~ -z + inputs * (a_delta - z/n_nearest_neighbors) (near trace)", "hyperparameters": hyperparams} return info diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index fc5b8a60..2c21c231 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -4,7 +4,9 @@ ## short-term plasticity components from .STPDenseSynapse import STPDenseSynapse - +from .exponentialSynapse import ExponentialSynapse +from .doubleExpSynapse import DoupleExpSynapse +from .alphaSynapse import AlphaSynapse ## dense synaptic components from .hebbian.hebbianSynapse import HebbianSynapse diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py new file mode 100644 index 00000000..cf5f9543 --- /dev/null +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -0,0 +1,186 @@ +from jax import random, numpy as jnp, jit +from ngcsimlib.compilers.process import transition +from ngcsimlib.component import Component +from ngcsimlib.compartment import Compartment + +from ngclearn.utils.weight_distribution import initialize_params +from ngcsimlib.logger import info +from ngclearn.components.synapses import DenseSynapse +from ngclearn.utils import tensorstats + +class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable + """ + A dynamic alpha synaptic cable; this synapse evolves according to alpha synaptic conductance dynamics. 
+ Specifically, the conductance dynamics are as follows: + + | dh/dt = -h/tau_decay + gBar sum_k (t - t_k) // h is an intermediate variable + | dg/dt = -g/tau_decay + h/tau_decay + | i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation + | where: syn_rest is the post-synaptic reverse potential for this synapse + | t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit + + + | --- Synapse Compartments: --- + | inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes) + | outputs - output signals (also equal to i_syn, total electrical current) + | v - coupled voltages from post-synaptic neurons this synaptic cable connects to + | weights - current value matrix of synaptic efficacies + | biases - current value vector of synaptic bias values + | --- Dynamic / Short-term Plasticity Compartments: --- + | g_syn - fixed value matrix of synaptic resources (U) + | i_syn - derived total electrical current variable + + Args: + name: the string name of this synapse + + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple + with number of inputs by number of outputs) + + tau_decay: synaptic decay time constant (ms) + + g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight") + + syn_rest: synaptic reversal potential; note, if this is set to `None`, then this + synaptic conductance model will no longer be voltage-dependent (and will ignore + the voltage compartment provided by an external spiking cell) + + weight_init: a kernel to drive initialization of this synaptic cable's values; + typically a tuple with 1st element as a string calling the name of + initialization to use + + bias_init: a kernel to drive initialization of biases for this synaptic cable + (Default: None, which turns off/disables biases) + + resist_scale: a fixed (resistance) scaling factor to apply to synaptic + transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + + p_conn: probability of a connection existing (default: 1.); setting + this to < 1 and > 0. will result in a sparser synaptic structure + (lower values yield sparse structure) + + is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True) + + """ + + # Define Functions + def __init__( + self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., + is_nonplastic=True, **kwargs + ): + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) + ## dynamic synapse meta-parameters + self.tau_decay = tau_decay + self.g_syn_bar = g_syn_bar + self.syn_rest = syn_rest ## synaptic resting potential + + ## Set up short-term plasticity / dynamic synapse compartment values + #tmp_key, *subkeys = random.split(self.key.value, 4) + #preVals = jnp.zeros((self.batch_size, shape[0])) + postVals = jnp.zeros((self.batch_size, shape[1])) + self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron) + self.i_syn = Compartment(postVals) ## electrical current output + self.g_syn = Compartment(postVals) ## conductance variable + self.h_syn = Compartment(postVals) ## intermediate conductance variable + if is_nonplastic: + self.weights.set(self.weights.value * 0 + 1.) 
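+
+        ## NOTE: `advance_state` (below) integrates the alpha-kernel conductance with explicit
+        ## Euler steps: incoming weight-summed spikes scaled by `g_syn_bar` drive the
+        ## intermediate variable `h_syn`, which in turn drives the conductance `g_syn`; when
+        ## `syn_rest` is given, the emitted current is voltage-dependent,
+        ## i.e., i_syn = -(g_syn * Rscale) * (v - syn_rest).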
+ + @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"]) + @staticmethod + def advance_state( + dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v + ): + s = inputs + ## advance conductance variable(s) + _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) + dhsyn_dt = -h_syn/tau_decay + (_out * g_syn_bar) * (1./dt) + h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h + + dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay + g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g + + ## compute derive electrical current variable + i_syn = -g_syn * Rscale + if syn_rest is not None: + i_syn = -(g_syn * Rscale) * (v - syn_rest) + outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases + return outputs, i_syn, g_syn, h_syn + + @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"]) + @staticmethod + def reset(batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + i_syn = postVals + g_syn = postVals + h_syn = postVals + v = postVals + return inputs, outputs, i_syn, g_syn, h_syn, v + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + if self.bias_init != None: + jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + else: + jnp.savez(file_name, weights=self.weights.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.weights.set(data['weights']) + if "biases" in data.keys(): + self.biases.set(data['biases']) + + @classmethod + def help(cls): ## component help function + properties = { + "synapse_type": "AlphaSynapse - performs a synaptic transformation of inputs to produce " + "output signals (e.g., a scaled linear multivariate transformation); " + "this synapse is dynamic, changing according to an alpha function" + } + compartment_props = { + "inputs": + {"inputs": "Takes in external input signal values", + "v" : "Post-synaptic voltage dependence (comes from a wired-to spiking cell)"}, + "states": + {"weights": "Synapse efficacy/strength parameter values", + "biases": "Base-rate/bias parameter values", + "g_syn" : "Synaptic conductnace", + "h_syn" : "Intermediate synaptic conductance", + "i_syn" : "Total electrical current", + "key": "JAX PRNG key"}, + "outputs": + {"outputs": "Output of synaptic transformation"}, + } + hyperparams = { + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", + "weight_init": "Initialization conditions for synaptic weight (W) values", + "bias_init": "Initialization conditions for bias/base-rate (b) values", + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", + "tau_decay": "Conductance decay time constant (ms)", + "g_bar_syn": "Maximum conductance value", + "syn_rest": "Synaptic reversal potential" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "outputs = g_syn * (v - syn_rest); " + "dhsyn_dt = (W * inputs) * g_syn_bar - h_syn/tau_decay " + "dgsyn_dt = -g_syn/tau_decay + h_syn", + "hyperparameters": hyperparams} + return info + + def __repr__(self): + comps = [varname for varname in dir(self) if 
Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 0a63171d..fc4e7ea0 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -41,8 +41,10 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable """ # Define Functions - def __init__(self, name, shape, weight_init=None, bias_init=None, - resist_scale=1., p_conn=1., batch_size=1, **kwargs): + def __init__( + self, name, shape, weight_init=None, bias_init=None, resist_scale=1., + p_conn=1., batch_size=1, **kwargs + ): super().__init__(name, **kwargs) self.batch_size = batch_size @@ -60,10 +62,10 @@ def __init__(self, name, shape, weight_init=None, bias_init=None, self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} weights = initialize_params(subkeys[0], self.weight_init, shape) if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed - mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape) - weights = weights * mask ## sparsify matrix + p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape) + weights = weights * p_mask ## sparsify matrix - self.batch_size = 1 + self.batch_size = batch_size #1 ## Compartment setup preVals = jnp.zeros((self.batch_size, shape[0])) postVals = jnp.zeros((self.batch_size, shape[1])) diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py new file mode 100644 index 00000000..86225a68 --- /dev/null +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -0,0 +1,192 @@ +from jax import random, numpy as jnp, jit +from ngcsimlib.compilers.process import transition +from ngcsimlib.component import Component +from ngcsimlib.compartment import Compartment + +from ngclearn.utils.weight_distribution import initialize_params +from ngcsimlib.logger import info +from ngclearn.components.synapses import DenseSynapse +from ngclearn.utils import tensorstats + +class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable + """ + A dynamic double-exponential synaptic cable; this synapse evolves according to difference of two exponentials + synaptic conductance dynamics. 
+ Specifically, the conductance dynamics are as follows: + + | dh/dt = -h/tau_rise + gBar sum_k (t - t_k) * (1/tau_rise - 1/tau_decay) // h is an intermediate variable + | dg/dt = -g/tau_decay + h + | i_syn = g * (syn_rest - v) // g is `g_syn` and h is `h_syn` in this synapse implementation + | where: syn_rest is the post-synaptic reverse potential for this synapse + | t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit + + | --- Synapse Compartments: --- + | inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes) + | outputs - output signals (also equal to i_syn, total electrical current) + | v - coupled voltages from post-synaptic neurons this synaptic cable connects to + | weights - current value matrix of synaptic efficacies + | biases - current value vector of synaptic bias values + | --- Dynamic / Short-term Plasticity Compartments: --- + | g_syn - fixed value matrix of synaptic resources (U) + | i_syn - derived total electrical current variable + + Args: + name: the string name of this synapse + + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple + with number of inputs by number of outputs) + + tau_decay: synaptic decay time constant (ms) + + tau_rise: synaptic increase/rise time constant (ms) + + g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight") + + syn_rest: synaptic reversal potential; note, if this is set to `None`, then this + synaptic conductance model will no longer be voltage-dependent (and will ignore + the voltage compartment provided by an external spiking cell) + + weight_init: a kernel to drive initialization of this synaptic cable's values; + typically a tuple with 1st element as a string calling the name of + initialization to use + + bias_init: a kernel to drive initialization of biases for this synaptic cable + (Default: None, which turns off/disables biases) + + resist_scale: a fixed (resistance) scaling factor to apply to synaptic + transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + + p_conn: probability of a connection existing (default: 1.); setting + this to < 1 and > 0. will result in a sparser synaptic structure + (lower values yield sparse structure) + + is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True) + + """ + + # Define Functions + def __init__( + self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., + is_nonplastic=True, **kwargs + ): + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) + ## dynamic synapse meta-parameters + self.tau_decay = tau_decay + self.tau_rise = tau_rise + self.g_syn_bar = g_syn_bar + self.syn_rest = syn_rest ## synaptic resting potential + + ## Set up short-term plasticity / dynamic synapse compartment values + #tmp_key, *subkeys = random.split(self.key.value, 4) + #preVals = jnp.zeros((self.batch_size, shape[0])) + postVals = jnp.zeros((self.batch_size, shape[1])) + self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron) + self.i_syn = Compartment(postVals) ## electrical current output + self.g_syn = Compartment(postVals) ## conductance variable + self.h_syn = Compartment(postVals) ## intermediate conductance variable + if is_nonplastic: + self.weights.set(self.weights.value * 0 + 1.) 
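+
+        ## NOTE: `advance_state` (below) integrates the difference-of-exponentials kernel with
+        ## explicit Euler steps: spikes scaled by `g_syn_bar` and (1/tau_rise - 1/tau_decay)
+        ## drive the rise variable `h_syn`, which feeds the decaying conductance `g_syn`; when
+        ## `syn_rest` is given, i_syn = -(g_syn * Rscale) * (v - syn_rest).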
+ + @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"]) + @staticmethod + def advance_state( + dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v + ): + s = inputs + #A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay)) + A = 1. + ## advance conductance variable(s) + _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) + dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt) + h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h + + dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) + g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g + + ## compute derive electrical current variable + i_syn = -g_syn * Rscale + if syn_rest is not None: + i_syn = -(g_syn * Rscale) * (v - syn_rest) + outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases + return outputs, i_syn, g_syn, h_syn + + @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"]) + @staticmethod + def reset(batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + i_syn = postVals + g_syn = postVals + h_syn = postVals + v = postVals + return inputs, outputs, i_syn, g_syn, h_syn, v + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + if self.bias_init != None: + jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + else: + jnp.savez(file_name, weights=self.weights.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.weights.set(data['weights']) + if "biases" in data.keys(): + self.biases.set(data['biases']) + + @classmethod + def help(cls): ## component help function + properties = { + "synapse_type": "DoubleExpSynapse - performs a synaptic transformation of inputs to produce " + "output signals (e.g., a scaled linear multivariate transformation); " + "this synapse is dynamic, changing according to a difference of exponentials kernel" + } + compartment_props = { + "inputs": + {"inputs": "Takes in external input signal values", + "v" : "Post-synaptic voltage dependence (comes from a wired-to spiking cell)"}, + "states": + {"weights": "Synapse efficacy/strength parameter values", + "biases": "Base-rate/bias parameter values", + "g_syn" : "Synaptic conductnace", + "h_syn" : "Intermediate synaptic conductance", + "i_syn" : "Total electrical current", + "key": "JAX PRNG key"}, + "outputs": + {"outputs": "Output of synaptic transformation"}, + } + hyperparams = { + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", + "weight_init": "Initialization conditions for synaptic weight (W) values", + "bias_init": "Initialization conditions for bias/base-rate (b) values", + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", + "tau_decay": "Conductance decay time constant (ms)", + "tau_rise": "Conductance rise/increase time constant (ms)", + "g_bar_syn": "Maximum conductance value", + "syn_rest": "Synaptic reversal potential" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "outputs = g_syn * (v - syn_rest); " + "dhsyn_dt = (1/tau_rise - 1/tau_decay) * (W * 
inputs) * g_syn_bar - h_syn/tau_rise " + "dgsyn_dt = -g_syn/tau_decay + h_syn", + "hyperparameters": hyperparams} + return info + + def __repr__(self): + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py new file mode 100644 index 00000000..a873baf9 --- /dev/null +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -0,0 +1,178 @@ +from jax import random, numpy as jnp, jit +from ngcsimlib.compilers.process import transition +from ngcsimlib.component import Component +from ngcsimlib.compartment import Compartment + +from ngclearn.utils.weight_distribution import initialize_params +from ngcsimlib.logger import info +from ngclearn.components.synapses import DenseSynapse +from ngclearn.utils import tensorstats + +class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable + """ + A dynamic exponential synaptic cable; this synapse evolves according to exponential synaptic conductance dynamics. + Specifically, the conductance dynamics are as follows: + + | dg/dt = -g/tau_decay + gBar sum_k (t - t_k) + | i_syn = g * (syn_rest - v) // g is `g_syn` in this synapse implementation + | where: syn_rest is the post-synaptic reverse potential for this synapse + | t_k marks time of -pre-synaptic k-th pulse received by post-synaptic unit + + + | --- Synapse Compartments: --- + | inputs - input (takes in external signals, e.g., pre-synaptic pulses/spikes) + | outputs - output signals (also equal to i_syn, total electrical current) + | v - coupled voltages from post-synaptic neurons this synaptic cable connects to + | weights - current value matrix of synaptic efficacies + | biases - current value vector of synaptic bias values + | --- Dynamic / Short-term Plasticity Compartments: --- + | g_syn - fixed value matrix of synaptic resources (U) + | i_syn - derived total electrical current variable + + Args: + name: the string name of this synapse + + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple + with number of inputs by number of outputs) + + tau_decay: synaptic decay time constant (ms) + + g_syn_bar: maximum conductance elicited by each incoming spike ("synaptic weight") + + syn_rest: synaptic reversal potential; note, if this is set to `None`, then this + synaptic conductance model will no longer be voltage-dependent (and will ignore + the voltage compartment provided by an external spiking cell) + + weight_init: a kernel to drive initialization of this synaptic cable's values; + typically a tuple with 1st element as a string calling the name of + initialization to use + + bias_init: a kernel to drive initialization of biases for this synaptic cable + (Default: None, which turns off/disables biases) + + resist_scale: a fixed (resistance) scaling factor to apply to synaptic + transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + + p_conn: probability of a connection existing (default: 1.); setting + this to < 1 and > 0. 
will result in a sparser synaptic structure + (lower values yield sparse structure) + + is_nonplastic: boolean indicating if this synapse permits plasticity adjustments (Default: True) + + """ + + # Define Functions + def __init__( + self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., + is_nonplastic=True, **kwargs + ): + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) + ## dynamic synapse meta-parameters + self.tau_decay = tau_decay + self.g_syn_bar = g_syn_bar + self.syn_rest = syn_rest ## synaptic resting potential + + ## Set up short-term plasticity / dynamic synapse compartment values + #tmp_key, *subkeys = random.split(self.key.value, 4) + #preVals = jnp.zeros((self.batch_size, shape[0])) + postVals = jnp.zeros((self.batch_size, shape[1])) + self.v = Compartment(postVals) ## coupled voltage (from a post-synaptic neuron) + self.i_syn = Compartment(postVals) ## electrical current output + self.g_syn = Compartment(postVals) ## conductance variable + if is_nonplastic: + self.weights.set(self.weights.value * 0 + 1.) + + @transition(output_compartments=["outputs", "i_syn", "g_syn"]) + @staticmethod + def advance_state( + dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v + ): + s = inputs + ## advance conductance variable + _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) + dgsyn_dt = -g_syn/tau_decay + (_out * g_syn_bar) * (1./dt) + g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance + ## compute derive electrical current variable + i_syn = -g_syn * Rscale + if syn_rest is not None: + i_syn = -(g_syn * Rscale) * (v - syn_rest) + outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases + return outputs, i_syn, g_syn + + @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "v"]) + @staticmethod + def reset(batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + i_syn = postVals + g_syn = postVals + v = postVals + return inputs, outputs, i_syn, g_syn, v + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + if self.bias_init != None: + jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + else: + jnp.savez(file_name, weights=self.weights.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.weights.set(data['weights']) + if "biases" in data.keys(): + self.biases.set(data['biases']) + + @classmethod + def help(cls): ## component help function + properties = { + "synapse_type": "ExponentialSynapse - performs a synaptic transformation of inputs to produce " + "output signals (e.g., a scaled linear multivariate transformation); " + "this synapse is dynamic, evolving according to an exponential kernel" + } + compartment_props = { + "inputs": + {"inputs": "Takes in external input signal values", + "v" : "Post-synaptic voltage dependence (comes from a wired-to spiking cell)"}, + "states": + {"weights": "Synapse efficacy/strength parameter values", + "biases": "Base-rate/bias parameter values", + "g_syn" : "Synaptic conductnace", + "h_syn" : "Intermediate synaptic conductance", + "i_syn" : "Total electrical current", + "key": "JAX PRNG key"}, + "outputs": + {"outputs": "Output of synaptic transformation"}, + } + hyperparams = { + "shape": "Shape of synaptic weight 
value matrix; number inputs x number outputs", + "weight_init": "Initialization conditions for synaptic weight (W) values", + "bias_init": "Initialization conditions for bias/base-rate (b) values", + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", + "tau_decay": "Conductance decay time constant (ms)", + "g_syn_bar": "Maximum conductance value", + "syn_rest": "Synaptic reversal potential" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "outputs = g_syn * (syn_rest - v); " + "dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_decay ", + "hyperparameters": hyperparams} + return info + + def __repr__(self): + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 43a41799..faaee5c9 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -162,10 +162,11 @@ class HebbianSynapse(DenseSynapse): # Define Functions @deprecate_args(_rebind=False, w_decay='prior') - def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, - w_bound=1., is_nonnegative=False, prior=("constant", 0.), w_decay=0., sign_value=1., - optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs): + def __init__( + self, name, shape, eta=0., weight_init=None, bias_init=None, w_bound=1., is_nonnegative=False, + prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., + resist_scale=1., batch_size=1, **kwargs + ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py b/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py deleted file mode 100644 index 04ebd4cb..00000000 --- a/ngclearn/components/synapses/hebbian/hebbianSynapseOld.py +++ /dev/null @@ -1,326 +0,0 @@ -from jax import random, numpy as jnp, jit -from functools import partial -from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args - -@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) -def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., - prior_type=None, prior_lmbda=0., - pre_wght=1., post_wght=1.): - """ - Compute a tensor of adjustments to be applied to a synaptic value matrix.
- - Args: - pre: pre-synaptic statistic to drive Hebbian update - - post: post-synaptic statistic to drive Hebbian update - - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: (Unused) - - signVal: multiplicative factor to modulate final update by (good for - flipping the signs of a computed synaptic change matrix) - - prior_type: prior type or name (Default: None) - - prior_lmbda: prior parameter (Default: 0.0) - - pre_wght: pre-synaptic weighting term (Default: 1.) - - post_wght: post-synaptic weighting term (Default: 1.) - - Returns: - an update/adjustment matrix, an update adjustment vector (for biases) - """ - _pre = pre * pre_wght - _post = post * post_wght - dW = jnp.matmul(_pre.T, _post) - db = jnp.sum(_post, axis=0, keepdims=True) - dW_reg = 0. - - if w_bound > 0.: - dW = dW * (w_bound - jnp.abs(W)) - - if prior_type == "l2" or prior_type == "ridge": - dW_reg = W - if prior_type == "l1" or prior_type == "lasso": - dW_reg = jnp.sign(W) - if prior_type == "l1l2" or prior_type == "elastic_net": - l1_ratio = prior_lmbda[1] - prior_lmbda = prior_lmbda[0] - dW_reg = jnp.sign(W) * l1_ratio + W * (1-l1_ratio)/2 - - dW = dW + prior_lmbda * dW_reg - return dW * signVal, db * signVal - -@partial(jit, static_argnums=[1,2]) -def _enforce_constraints(W, w_bound, is_nonnegative=True): - """ - Enforces constraints that the (synaptic) efficacies/values within matrix - `W` must adhere to. - - Args: - W: synaptic weight values (at time t) - - w_bound: maximum value to enforce over newly computed efficacies - - is_nonnegative: ensure updated value matrix is strictly non-negative - - Returns: - the newly evolved synaptic weight value matrix - """ - _W = W - if w_bound > 0.: - if is_nonnegative == True: - _W = jnp.clip(_W, 0., w_bound) - else: - _W = jnp.clip(_W, -w_bound, w_bound) - return _W - - -class HebbianSynapse(DenseSynapse): - """ - A synaptic cable that adjusts its efficacies via a two-factor Hebbian - adjustment rule. 
- - | --- Synapse Compartments: --- - | inputs - input (takes in external signals) - | outputs - output signals (transformation induced by synapses) - | weights - current value matrix of synaptic efficacies - | biases - current value vector of synaptic bias values - | key - JAX PRNG key - | --- Synaptic Plasticity Compartments: --- - | pre - pre-synaptic signal to drive first term of Hebbian update (takes in external signals) - | post - post-synaptic signal to drive 2nd term of Hebbian update (takes in external signals) - | dWeights - current delta matrix containing changes to be applied to synaptic efficacies - | dBiases - current delta vector containing changes to be applied to bias values - | opt_params - locally-embedded optimizer statisticis (e.g., Adam 1st/2nd moments if adam is used) - - Args: - name: the string name of this cell - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - eta: global learning rate - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - bias_init: a kernel to drive initialization of biases for this synaptic cable - (Default: None, which turns off/disables biases) - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - is_nonnegative: enforce that synaptic efficacies are always non-negative - after each synaptic update (if False, no constraint will be applied) - - prior: a kernel to drive prior of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - prior to use and 2nd element as a floating point number - calling the prior parameter lambda (Default: (None, 0.)) - currently it supports "l1" or "lasso" or "l2" or "ridge" or "l1l2" or "elastic_net". - usage guide: - prior = ('l1', 0.01) or prior = ('lasso', lmbda) - prior = ('l2', 0.01) or prior = ('ridge', lmbda) - prior = ('l1l2', (0.01, 0.01)) or prior = ('elastic_net', (lmbda, l1_ratio)) - - - - sign_value: multiplicative factor to apply to final synaptic update before - it is applied to synapses; this is useful if gradient descent style - optimization is required (as Hebbian rules typically yield - adjustments for ascent) - - optim_type: optimization scheme to physically alter synaptic values - once an update is computed (Default: "sgd"); supported schemes - include "sgd" and "adam" - - :Note: technically, if "sgd" or "adam" is used but `signVal = 1`, - then the ascent form of each rule is employed (signVal = -1) or - a negative learning rate will mean a descent form of the - `optim_scheme` is being employed - - pre_wght: pre-synaptic weighting factor (Default: 1.) - - post_wght: post-synaptic weighting factor (Default: 1.) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - p_conn: probability of a connection existing (default: 1.); setting - this to < 1. 
will result in a sparser synaptic structure - """ - - # Define Functions - @deprecate_args(_rebind=False, w_decay='prior') - def __init__(self, name, shape, eta=0., weight_init=None, bias_init=None, - w_bound=1., is_nonnegative=False, prior=(None, 0.), w_decay=0., sign_value=1., - optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs): - super().__init__(name, shape, weight_init, bias_init, resist_scale, - p_conn, batch_size=batch_size, **kwargs) - - if w_decay > 0.: - prior = ('l2', w_decay) - - prior_type, prior_lmbda = prior - ## synaptic plasticity properties and characteristics - self.shape = shape - self.Rscale = resist_scale - self.prior_type = prior_type - self.prior_lmbda = prior_lmbda - self.w_bound = w_bound - self.pre_wght = pre_wght - self.post_wght = post_wght - self.eta = eta - self.is_nonnegative = is_nonnegative - self.sign_value = sign_value - - ## optimization / adjustment properties (given learning dynamics above) - self.opt = get_opt_step_fn(optim_type, eta=self.eta) - - # compartments (state of the cell, parameters, will be updated through stateless calls) - self.preVals = jnp.zeros((self.batch_size, shape[0])) - self.postVals = jnp.zeros((self.batch_size, shape[1])) - self.pre = Compartment(self.preVals) - self.post = Compartment(self.postVals) - self.dWeights = Compartment(jnp.zeros(shape)) - self.dBiases = Compartment(jnp.zeros(shape[1])) - - #key, subkey = random.split(self.key.value) - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) - - @staticmethod - def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, pre, post, weights): - ## calculate synaptic update values - dW, db = _calc_update( - pre, post, weights, w_bound, is_nonnegative=is_nonnegative, - signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght, - post_wght=post_wght) - return dW, db - - @staticmethod - def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): - ## calculate synaptic update values - dWeights, dBiases = HebbianSynapse._compute_update( - w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, - pre, post, weights - ) - ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) - else: - # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) - ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @resolver(_evolve) - def evolve(self, opt_params, weights, biases, dWeights, dBiases): - self.opt_params.set(opt_params) - self.weights.set(weights) - self.biases.set(biases) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @staticmethod - def _reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) - - @resolver(_reset) - def reset(self, inputs, outputs, pre, post, dWeights, dBiases): - self.inputs.set(inputs) - 
self.outputs.set(outputs) - self.pre.set(pre) - self.post.set(post) - self.dWeights.set(dWeights) - self.dBiases.set(dBiases) - - @classmethod - def help(cls): ## component help function - properties = { - "synapse_type": "HebbianSynapse - performs an adaptable synaptic " - "transformation of inputs to produce output signals; " - "synapses are adjusted via two-term/factor Hebbian adjustment" - } - compartment_props = { - "inputs": - {"inputs": "Takes in external input signal values", - "pre": "Pre-synaptic statistic for Hebb rule (z_j)", - "post": "Post-synaptic statistic for Hebb rule (z_i)"}, - "states": - {"weights": "Synapse efficacy/strength parameter values", - "biases": "Base-rate/bias parameter values", - "key": "JAX PRNG key"}, - "analytics": - {"dWeights": "Synaptic weight value adjustment matrix produced at time t", - "dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"}, - "outputs": - {"outputs": "Output of synaptic transformation"}, - } - hyperparams = { - "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", - "batch_size": "Batch size dimension of this component", - "weight_init": "Initialization conditions for synaptic weight (W) values", - "bias_init": "Initialization conditions for bias/base-rate (b) values", - "resist_scale": "Resistance level scaling factor (applied to output of transformation)", - "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", - "is_nonnegative": "Should synapses be constrained to be non-negative post-updates?", - "sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0", - "eta": "Global (fixed) learning rate", - "pre_wght": "Pre-synaptic weighting coefficient (q_pre)", - "post_wght": "Post-synaptic weighting coefficient (q_post)", - "w_bound": "Soft synaptic bound applied to synapses post-update", - "prior": "prior name and value for synaptic updating prior", - "optim_type": "Choice of optimizer to adjust synaptic weights" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "outputs = [(W * Rscale) * inputs] + b ;" - "dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - g(W_{ij}) * prior_lmbda", - "hyperparameters": hyperparams} - return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam', - sign_value=-1.0, prior=("l1l2", 0.001)) - print(Wab) \ No newline at end of file diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index d1cfce3e..777c26cc 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -68,13 +68,14 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP # Define Functions def __init__( self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1., - p_conn=1., w_bound=1., 
batch_size=1, **kwargs + p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs ): super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) ## Synaptic hyper-parameters self.shape = shape ## shape of synaptic efficacy matrix + self.tau_w = tau_w self.mu = mu ## controls power-scaling of STDP rule self.preTrace_target = pretrace_target ## target (pre-synaptic) trace activity value # 0.7 self.Aplus = A_plus ## LTP strength @@ -82,6 +83,10 @@ def __init__( self.Rscale = resist_scale ## post-transformation scale factor self.w_bound = w_bound #1. ## soft weight constraint self.w_eps = 0. ## w_eps = 0.01 + self.weight_mask = weight_mask + if self.weight_mask is None: + self.weight_mask = jnp.ones((1, 1)) + self.weights.set(self.weights.value * self.weight_mask) ## Compartment setup preVals = jnp.zeros((self.batch_size, shape[0])) @@ -93,6 +98,12 @@ def __init__( self.dWeights = Compartment(self.weights.value * 0) self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate + #@transition(output_compartments=["outputs"]) + #@staticmethod + #def advance_state(Rscale, inputs, weights, biases, weight_mask): + # outputs = (jnp.matmul(inputs, weights * weight_mask) * Rscale) + biases + # return outputs + @staticmethod def _compute_update( dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights @@ -126,16 +137,23 @@ def _compute_update( @transition(output_compartments=["weights", "dWeights"]) @staticmethod def evolve( - dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta + dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_w, preSpike, postSpike, preTrace, + postTrace, weights, eta, weight_mask ): + #_wm = weight_mask # + _wm = (weight_mask != 0.) dWeights = TraceSTDPSynapse._compute_update( dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights ) ## do a gradient ascent update/shift - weights = weights + dWeights * eta + decayTerm = 0. + if tau_w > 0.: + decayTerm = weights / tau_w + weights = weights + (dWeights * eta) - decayTerm #weight_mask * eta) ## enforce non-negativity #w_eps = 0. 
# 0.01 # 0.001 weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound)) + weights = weights * _wm # weight_mask return weights, dWeights @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"]) diff --git a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py index 149108de..6e5dd8c4 100755 --- a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py +++ b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py @@ -61,6 +61,8 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit elg_decay: eligibility decay constant (default: 1) + tau_w: amount of synaptic decay to augment each MSTDP/MSTDP-ET update with + weight_init: a kernel to drive initialization of this synaptic cable's values; typically a tuple with 1st element as a string calling the name of initialization to use @@ -74,7 +76,7 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit # Define Functions def __init__( - self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., + self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs ): super().__init__( @@ -82,18 +84,20 @@ def __init__( resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs ) self.w_eps = 0. + self.tau_w = tau_w ## MSTDP/MSTDP-ET meta-parameters self.tau_elg = tau_elg self.elg_decay = elg_decay ## MSTDP/MSTDP-ET compartments self.modulator = Compartment(jnp.zeros((self.batch_size, 1))) self.eligibility = Compartment(jnp.zeros(shape)) + self.outmask = Compartment(jnp.zeros((1, shape[1]))) @transition(output_compartments=["weights", "dWeights", "eligibility"]) @staticmethod def evolve( - dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, preSpike, postSpike, preTrace, - postTrace, weights, dWeights, eta, modulator, eligibility + dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike, + preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask ): # dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update # dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights @@ -105,21 +109,25 @@ def evolve( else: ## otherwise, just do M-STDP eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing ## do a gradient ascent update/shift - weights = weights + eligibility * modulator * eta ## do modulated update - #''' + decayTerm = 0. + if tau_w > 0.: + decayTerm = weights * (1. 
/ tau_w) + weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update + dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights ) dWeights = dW_dt ## can think of this as eligibility at time t - #''' - + #w_eps = 0.01 weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound)) return weights, dWeights, eligibility @transition( - output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility"] + output_compartments=[ + "inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask" + ] ) @staticmethod def reset(batch_size, shape): @@ -134,7 +142,8 @@ def reset(batch_size, shape): postTrace = postVals dWeights = synVals eligibility = synVals - return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility + outmask = postVals + 1. + return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index 00a014b0..1415f51a 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -8,7 +8,7 @@ from ngcsimlib.compilers.process import transition @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) -def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., +def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1., post_wght=1.): """ @@ -21,7 +21,7 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., W: synaptic weight values (at time t) - w_mask: synaptic weight masking matrix (same shape as W) + mask: synaptic weight masking matrix (same shape as W) w_bound: maximum value to enforce over newly computed efficacies @@ -64,13 +64,13 @@ def _calc_update(pre, post, W, w_mask, w_bound, is_nonnegative=True, signVal=1., dW = dW + prior_lmbda * dW_reg - if w_mask!=None: - dW = dW * w_mask + if mask!=None: + dW = dW * mask return dW * signVal, db * signVal @partial(jit, static_argnums=[1,2, 3]) -def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True): +def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True): """ Enforces constraints that the (synaptic) efficacies/values within matrix `W` must adhere to. 
@@ -78,7 +78,7 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True): Args: W: synaptic weight values (at time t) - w_mask: weight mask matrix + block_mask: weight mask matrix w_bound: maximum value to enforce over newly computed efficacies @@ -94,8 +94,8 @@ def _enforce_constraints(W, w_mask, w_bound, is_nonnegative=True): else: _W = jnp.clip(_W, -w_bound, w_bound) - if w_mask!=None: - _W = _W * w_mask + if block_mask!=None: + _W = _W * block_mask return _W @@ -124,7 +124,7 @@ class HebbianPatchedSynapse(PatchedSynapse): shape: tuple specifying shape of this synaptic cable (usually a 2-tuple with number of inputs by number of outputs) - n_sub_models: The number of submodels in each layer + n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse) stride_shape: Stride shape of overlapping synaptic weight value matrix (Default: (0, 0)) @@ -138,7 +138,7 @@ class HebbianPatchedSynapse(PatchedSynapse): bias_init: a kernel to drive initialization of biases for this synaptic cable (Default: None, which turns off/disables biases) - w_mask: weight mask matrix + block_mask: weight mask matrix w_bound: maximum weight to softly bound this cable's value matrix to; if set to 0, then no synaptic value bounding will be applied @@ -185,11 +185,11 @@ class HebbianPatchedSynapse(PatchedSynapse): batch_size: the size of each mini batch """ - def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None, - w_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1., + def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None, + block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., resist_scale=1., batch_size=1, **kwargs): - super().__init__(name, shape, n_sub_models, stride_shape, w_mask, weight_init, bias_init, resist_scale, + super().__init__(name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs) prior_type, prior_lmbda = prior @@ -221,7 +221,7 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight self.postVals = jnp.zeros((self.batch_size, self.shape[1])) self.pre = Compartment(self.preVals) self.post = Compartment(self.postVals) - self.w_mask = w_mask + self.block_mask = block_mask self.dWeights = Compartment(jnp.zeros(self.shape)) self.dBiases = Compartment(jnp.zeros(self.shape[1])) @@ -231,11 +231,11 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), eta=0., weight if bias_init else [self.weights.value])) @staticmethod - def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, + def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights): ## calculate synaptic update values dW, db = _calc_update( - pre, post, weights, w_mask, w_bound, is_nonnegative=is_nonnegative, + pre, post, weights, block_mask, w_bound, is_nonnegative=is_nonnegative, signVal=sign_value, prior_type=prior_type, prior_lmbda=prior_lmbda, pre_wght=pre_wght, post_wght=post_wght) @@ -243,11 +243,11 @@ def _compute_update(w_mask, w_bound, is_nonnegative, sign_value, prior_type, pri @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"]) @staticmethod - def evolve(w_mask, opt, w_bound, is_nonnegative, 
sign_value, prior_type, prior_lmbda, pre_wght, + def evolve(block_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, bias_init, pre, post, weights, biases, opt_params): ## calculate synaptic update values dWeights, dBiases = HebbianPatchedSynapse._compute_update( - w_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, + block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights ) ## conduct a step of optimization - get newly evolved synaptic weight value matrix @@ -257,7 +257,7 @@ def evolve(w_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_l # ignore db since no biases configured opt_params, [weights] = opt(opt_params, [weights], [dWeights]) ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_mask, w_bound, is_nonnegative=is_nonnegative) + weights = _enforce_constraints(weights, block_mask, w_bound, is_nonnegative=is_nonnegative) return opt_params, weights, biases, dWeights, dBiases @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"]) @@ -313,7 +313,7 @@ def help(cls): ## component help function "post_wght": "Post-synaptic weighting coefficient (q_post)", "w_bound": "Soft synaptic bound applied to synapses post-update", "prior": "prior name and value for synaptic updating prior", - "w_mask": "weight mask matrix", + "block_mask": "weight mask matrix", "optim_type": "Choice of optimizer to adjust synaptic weights" } info = {cls.__name__: properties, diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index d0fb2c52..43d1dc16 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -65,7 +65,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable shape: tuple specifying shape of this synaptic cable (usually a 2-tuple with number of inputs by number of outputs) - n_sub_models: The number of submodels in each layer + n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse) stride_shape: Stride shape of overlapping synaptic weight value matrix (Default: (0, 0)) @@ -79,7 +79,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable bias_init: a kernel to drive initialization of biases for this synaptic cable (Default: None, which turns off/disables biases) - w_mask: weight mask matrix + block_mask: weight mask matrix pre_wght: pre-synaptic weighting factor (Default: 1.) @@ -92,7 +92,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable this to < 1. 
will result in a sparser synaptic structure """ - def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, weight_init=None, bias_init=None, + def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs): super().__init__(name, **kwargs) @@ -112,7 +112,7 @@ def __init__(self, name, shape, n_sub_models, stride_shape=(0,0), w_mask=None, w weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride, weight_init=self.weight_init) - self.w_mask = jnp.where(weights!=0, 1, 0) + self.block_mask = jnp.where(weights!=0, 1, 0) self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models) self.shape = weights.shape @@ -192,7 +192,7 @@ def help(cls): ## component help function "weight_init": "Initialization conditions for synaptic weight (W) values", "bias_init": "Initialization conditions for bias/base-rate (b) values", "resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation", - "w_mask": "weight mask matrix", + "block_mask": "weight mask matrix", "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)" } info = {cls.__name__: properties, diff --git a/ngclearn/components/synapses/staticSynapse.py b/ngclearn/components/synapses/staticSynapse.py index dadc1608..0cf06757 100755 --- a/ngclearn/components/synapses/staticSynapse.py +++ b/ngclearn/components/synapses/staticSynapse.py @@ -27,4 +27,4 @@ class StaticSynapse(DenseSynapse): this to < 1 and > 0. will result in a sparser synaptic structure (lower values yield sparse structure) """ - pass \ No newline at end of file + pass diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py index 52211130..30ddb2d4 100755 --- a/ngclearn/utils/diffeq/ode_utils.py +++ b/ngclearn/utils/diffeq/ode_utils.py @@ -112,6 +112,30 @@ def _euler(carry, dfx, dt, params, x_scale=1.): new_carry = (_t, _x) return new_carry, (new_carry, carry) +@partial(jit, static_argnums=(1)) +def _leapfrog(carry, dfq, dt, params): + t, q, p = carry + dq_dt = dfq(t, q, params) + + _p = p + dq_dt * (dt/2.) + _q = q + p * dt + dq_dtpdt = dfq(t+dt, _q, params) + _p = _p + dq_dtpdt * (dt/2.) + _t = t + dt + new_carry = (_t, _q, _p) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(3, 4)) +def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params): + t = t_curr + 0. + q = q_curr + 0. + p = p_curr + 0. + def scanner(carry, _): + return _leapfrog(carry, dfq, step_size, params) + new_values, (xs_next, xs_carry) = _scan(scanner, init=(t, q, p), xs=jnp.arange(L)) + t, q, p = new_values + return t, q, p + @partial(jit, static_argnums=(2)) def step_heun(t, x, dfx, dt, params, x_scale=1.): """ diff --git a/ngclearn/utils/jaxProcess.py b/ngclearn/utils/jaxProcess.py index e8c057de..dd1dabc3 100644 --- a/ngclearn/utils/jaxProcess.py +++ b/ngclearn/utils/jaxProcess.py @@ -1,27 +1,82 @@ +from ngcsimlib.compartment import Compartment from ngcsimlib.compilers.process import Process from jax.lax import scan as _scan from ngcsimlib.logger import warn from jax import numpy as jnp + class JaxProcess(Process): """ - The JaxProcess is a subclass of the ngcsimlib Process class. The - functionality added by this subclass is the use of the jax scanner to run a - process quickly through the use of jax's JIT compiler. + The JaxProcess is a subclass of the ngcsimlib Process class. 
The + functionality added by this subclass is the use of the jax scanner to run a + process quickly through the use of jax's JIT compiler. """ - def scan(self, compartments_to_monitor=None, - save_state=True, scan_length=None, **kwargs): + + def __init__(self, name): + super().__init__(name) + self._process_scan_method = None + self._monitoring = [] + + def _make_scanner(self): + arg_order = self.get_required_args() + + def _pure(current_state, x): + v = self.pure(current_state, + **{key: value for key, value in zip(arg_order, x)}) + return v, [v[m] for m in self._monitoring] + + return _pure + + def watch(self, compartment): + """ + Adds a compartment to the process to watch during a scan + + Args: + compartment: the compartment to watch + """ + if not isinstance(compartment, Compartment): + warn( + "JaxProcess is trying to watch a value that is not a compartment") + + self._monitoring.append(compartment.path) + self._process_scan_method = self._make_scanner() + + def clear_watch_list(self): + """ + Clears the watch list so that no values are watched + """ + self._monitoring = [] + self._process_scan_method = self._make_scanner() + + def transition(self, transition_call): + """ + Extends the base transition call to rebuild the pure method used by its + scanner + + Args: + transition_call: the transition being passed into the default process + + Returns: + this JaxProcess instance for chaining + """ + super().transition(transition_call) + self._process_scan_method = self._make_scanner() + return self + + def scan(self, save_state=True, scan_length=None, **kwargs): """ There a quite a few ways to initialize the scan method for the - jaxProcess. To start the straight forward arguments are - "compartments_to_monitor" and "save_state". Monitoring compartments - means at the end of each process cycle record the value of each - compartment in the list and then at the end a tuple of concatenated - values will be returned that correspond to each compartment in the - original list. The save_state flag is simply there to note if the state + jaxProcess. To start, the most straightforward argument is "save_state". + The save_state flag is simply there to note if the state of the model should reflect the final state of the model after the scan is complete. + This scan method can also record and report intermediate compartment + values registered through calls to the JaxProcess.watch() method. Watching a + compartment means that, at the end of each process cycle, its value is + recorded; when the scan finishes, a tuple of stacked values is returned, + one entry for each compartment the process is watching. + Where there are options for the arguments is when defining the keyword arguments for the process. The process will do its best to broadcast all the inputs to the largest size, so they can be scanned over.
This means @@ -39,18 +94,16 @@ def scan(self, compartments_to_monitor=None, Args: - compartments_to_monitor: A list of compartments to monitor - save_state: A boolean flag to indicate if the model state should be - saved - scan_length: a value to be used to denote the number of iterations - of the scanner if all keyword arguments are passed as ints or floats + save_state: A boolean flag to indicate if the model state should be saved + + scan_length: a value to be used to denote the number of iterations of the scanner if all keyword + arguments are passed as ints or floats + **kwargs: the required keyword arguments for the process to run Returns: the final state of the model, the stacked output of the scan method """ - if compartments_to_monitor is None: - compartments_to_monitor = [] arg_order = list(self.get_required_args()) args = [] @@ -91,7 +144,7 @@ def scan(self, compartments_to_monitor=None, max_next_axis = 0 new_args = [] for a in args: - if len(a.shape) >= axis+1: + if len(a.shape) >= axis + 1: if a.shape[axis] == current_axis: new_args.append(a) else: @@ -99,20 +152,20 @@ def scan(self, compartments_to_monitor=None, "broadcasted to the largest shape") return else: - new_args.append(jnp.zeros(list(a.shape) + [current_axis], dtype=a.dtype) + a.reshape(*a.shape, 1)) + new_args.append(jnp.zeros(list(a.shape) + [current_axis], + dtype=a.dtype) + a.reshape( + *a.shape, 1)) - if len(a.shape) > axis+1: - max_next_axis = max(max_next_axis, a.shape[axis+1]) + if len(a.shape) > axis + 1: + max_next_axis = max(max_next_axis, a.shape[axis + 1]) args = new_args - args = jnp.array(args).transpose([1, 0] + [i for i in range(2, max_axis+1)]) - - def _pure(current_state, x): - v = self.pure(current_state, **{key: value for key, value in zip(arg_order, x)}) - return v, [v[c.path] for c in compartments_to_monitor] - - vals, stacked = _scan(_pure, init=self.get_required_state(include_special_compartments=True), xs=args) + args = jnp.array(args).transpose( + [1, 0] + [i for i in range(2, max_axis + 1)]) + state, stacked = _scan(self._process_scan_method, + init=self.get_required_state( + include_special_compartments=True), xs=args) if save_state: - self.updated_modified_state(vals) - return vals, stacked + self.updated_modified_state(state) + return state, stacked diff --git a/ngclearn/utils/patch_utils.py b/ngclearn/utils/patch_utils.py index 74cf5641..f3116e84 100755 --- a/ngclearn/utils/patch_utils.py +++ b/ngclearn/utils/patch_utils.py @@ -118,7 +118,7 @@ def create_patches(self, add_frame=False, center=True): -def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234): ## scikit +def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False): ## scikit """ Generates a set of patches from an array/list of image arrays (via random sampling with replacement). This uses scikit-learn's patch creation @@ -151,10 +151,16 @@ def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, p_batch = np.concatenate((p_batch,patches),axis=0) else: p_batch = patches + + mu = 0 if center: ## center patches by subtracting out their means mu = np.mean(p_batch, axis=1, keepdims=True) p_batch = p_batch - mu - return jnp.array(p_batch) + if vis_mode: + return jnp.array(p_batch), mu + else: + return jnp.array(p_batch) + def generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True): ## patchify ## this is a patchify-specific function (only use if you have patchify installed...) 
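(Editor's note) Since `compartments_to_monitor` has been removed from `scan()`, monitoring now goes through `watch()`. Below is a minimal, hypothetical usage sketch of the new JaxProcess API introduced above; the component `cell`, its compartment `v`, and the keyword arguments `t`/`dt` are illustrative assumptions, not taken from this changeset.

    from jax import numpy as jnp
    from ngclearn.utils.jaxProcess import JaxProcess  # class defined in ngclearn/utils/jaxProcess.py

    proc = (JaxProcess("advance_proc")
            >> cell.advance_state)        # build the process as usual (assumes `cell` is an ngc-learn component)
    proc.watch(cell.v)                    # record compartment `v` at the end of every cycle
    T, dt = 100, 1.                       # number of steps and step size (ms); illustrative values
    state, recorded = proc.scan(save_state=True, t=jnp.arange(T) * dt, dt=dt)
    # `recorded` holds one stacked array per watched compartment, in the order watch() was called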
diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py index 4f9095dd..4fd8c244 100755 --- a/ngclearn/utils/viz/dim_reduce.py +++ b/ngclearn/utils/viz/dim_reduce.py @@ -29,7 +29,7 @@ def extract_pca_latents(vectors): ## PCA mapping routine z_2D = vectors return z_2D -def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping routine +def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine """ Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the t-distributed stochastic neighbor embedding @@ -42,10 +42,13 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32): ## tSNE mapping perplexity: the perplexity control factor for t-SNE (Default: 30) + batch_size: number of sampled embedding vectors to use per iteration + of online internal PCA + Returns: a matrix (K x 2) of projected vectors (to 2D space) """ - batch_size = 500 #50 + #batch_size = 500 #50 z_dim = vectors.shape[1] z_2D = None if z_dim != 2: diff --git a/pyproject.toml b/pyproject.toml index 885b1940..34249491 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ngclearn" -version = "2.0.0" +version = "2.0.1" description = "Simulation software for building and analyzing arbitrary predictive coding, spiking network, and biomimetic neural systems." authors = [ {name = "Alexander Ororbia", email = "ago@cs.rit.edu"}, @@ -14,13 +14,13 @@ readme = "README.md" keywords = ['python', 'ngc-learn', 'predictive-processing', 'predictive-coding', 'neuro-ai', 'jax', 'spiking-neural-networks', 'biomimetics', 'bionics', 'computational-neuroscience'] requires-python = ">=3.10" #3.8 -license = {text = "BSD-3-Clause License"} +license = "BSD-3-Clause" # {text = "BSD-3-Clause License"} classifiers=[ "Development Status :: 4 - Beta", #3 - Alpha", # 5 - Production/Stable "Intended Audience :: Education", "Intended Audience :: Science/Research", "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", + #"License :: OSI Approved :: BSD License", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Artificial Intelligence", diff --git a/requirements.txt b/requirements.txt index 06df6260..26351a9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy>=1.24.0 -scikit-learn>=0.24.2 +scikit-learn>=1.6.1 scipy>=1.7.0 matplotlib>=3.8.0 patchify diff --git a/tests/components/synapses/test_exponentialSynapse.py b/tests/components/synapses/test_exponentialSynapse.py new file mode 100644 index 00000000..83ad19ee --- /dev/null +++ b/tests/components/synapses/test_exponentialSynapse.py @@ -0,0 +1,55 @@ +from jax import numpy as jnp, random, jit +import numpy as np +np.random.seed(42) +from ngclearn.components import ExponentialSynapse + +from ngcsimlib.compilers.process import Process +from ngcsimlib.context import Context +import ngclearn.utils.weight_distribution as dist + +def test_exponentialSynapse1(): + name = "expsyn_ctx" + ## create seeding keys + dkey = random.PRNGKey(1234) + dkey, *subkeys = random.split(dkey, 6) + dt = 1. # ms + ## excitatory properties + tau_syn = 2. + E_rest = 0. 
+ # ---- build a single exp-synapse system ---- + with Context(name) as ctx: + a = ExponentialSynapse( + name="a", shape=(1,1), tau_decay=tau_syn, g_syn_bar=2.4, syn_rest=E_rest, weight_init=dist.constant(value=1.), + key=subkeys[0] + ) + + advance_process = (Process("advance_proc") + >> a.advance_state) + # ctx.wrap_and_add_command(advance_process.pure, name="run") + ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + + reset_process = (Process("reset_proc") + >> a.reset) + ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + + sp_train = jnp.array([1., 0., 1.], dtype=jnp.float32) + post_syn_neuron_volt = jnp.ones((1, 1)) * -65. ## post-syn neuron is at rest + + outs_truth = jnp.array([[156., 78., 195.]]) + + outs = [] + ctx.reset() + for t in range(3): + in_pulse = jnp.expand_dims(sp_train[t], axis=0) + a.inputs.set(in_pulse) + a.v.set(post_syn_neuron_volt) + ctx.run(t=t * dt, dt=dt) + #print("g: ",a.g_syn.value) + #print("i: ", a.i_syn.value) + outs.append(a.outputs.value) + outs = jnp.concatenate(outs, axis=1) + #print(outs) + + np.testing.assert_allclose(outs, outs_truth, atol=1e-8) + +#test_exponentialSynapse1()
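(Editor's note) For readers checking the `outs_truth` constants in the test above: they follow directly from the Euler conductance update in `ExponentialSynapse.advance_state` shown earlier in this changeset. A hand-worked trace with tau_decay=2, g_syn_bar=2.4, dt=1, W=1, Rscale=1, syn_rest=0, v=-65 and the spike train s=[1, 0, 1]:

    # g <- g + dt * (-g/tau_decay + (s @ W) * g_syn_bar / dt);  i_syn = -g * Rscale * (v - syn_rest)
    # t=0: g = 0.0 + (-0.0/2 + 2.4) = 2.4  ->  i_syn = -2.4 * (-65 - 0) = 156.
    # t=1: g = 2.4 + (-2.4/2 + 0.0) = 1.2  ->  i_syn = -1.2 * (-65 - 0) =  78.
    # t=2: g = 1.2 + (-1.2/2 + 2.4) = 3.0  ->  i_syn = -3.0 * (-65 - 0) = 195.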