diff --git a/README.md b/README.md
index 7e43e7fa..1e7412fd 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,7 @@ $ python install -e .
**Version:**
-2.0.0
+2.0.1
Author:
Alexander G. Ororbia II
diff --git a/docs/museum/index.rst b/docs/museum/index.rst
index 6f534770..25f75dfc 100644
--- a/docs/museum/index.rst
+++ b/docs/museum/index.rst
@@ -18,3 +18,4 @@ relevant, referenced publicly available ngc-learn simulation code.
snn_dc
snn_bfa
sindy
+ rl_snn
diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md
new file mode 100644
index 00000000..f0e2cb0d
--- /dev/null
+++ b/docs/museum/rl_snn.md
@@ -0,0 +1,35 @@
+# Reinforcement Learning through a Spiking Controller
+
+In this exhibit, we will see how to construct a simple biophysical model for
+reinforcement learning with a spiking neural network and modulated
+spike-timing-dependent plasticity.
+This model incorporates mechanisms from several different models, including
+the constrained RL-centric SNN of [1] as well as the simplifications
+made with respect to the model of [2]. The model code for this
+exhibit can be found
+[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn).
+
+## Modeling Operant Conditioning through Modulation
+
+
+### Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP)
+
+
+## The Spiking Neural Circuit Model
+
+
+### Neuronal Dynamics
+
+
+## Running the RL-SNN Model
+
+
+
+## References
+[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse
+and delayed rewards with a multilayer spiking neural network." 2020 International
+Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit
+recognition using spike-timing-dependent plasticity." Frontiers in computational
+neuroscience 9 (2015): 99.
+
diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py
index 114d3785..29b5f267 100755
--- a/ngclearn/components/neurons/graded/gaussianErrorCell.py
+++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py
@@ -65,6 +65,12 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)
+ @staticmethod
+ def eval_log_density(target, mu, Sigma):
+ _dmu = (target - mu)
+ log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ return log_density
+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
@staticmethod
def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
@@ -79,6 +85,7 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e
dtarget = -dmu # reverse of e
dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ #L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
dtarget = dtarget * modulator * mask
diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py
index 52211130..30ddb2d4 100755
--- a/ngclearn/utils/diffeq/ode_utils.py
+++ b/ngclearn/utils/diffeq/ode_utils.py
@@ -112,6 +112,30 @@ def _euler(carry, dfx, dt, params, x_scale=1.):
new_carry = (_t, _x)
return new_carry, (new_carry, carry)
+@partial(jit, static_argnums=(1))
+def _leapfrog(carry, dfq, dt, params):
+ t, q, p = carry
+ dq_dt = dfq(t, q, params)
+
+ _p = p + dq_dt * (dt/2.)
+    _q = q + _p * dt  ## drift must use the half-step momentum (kick-drift-kick)
+ dq_dtpdt = dfq(t+dt, _q, params)
+ _p = _p + dq_dtpdt * (dt/2.)
+ _t = t + dt
+ new_carry = (_t, _q, _p)
+ return new_carry, (new_carry, carry)
+
+@partial(jit, static_argnums=(3, 4))
+def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params):
+ t = t_curr + 0.
+ q = q_curr + 0.
+ p = p_curr + 0.
+ def scanner(carry, _):
+ return _leapfrog(carry, dfq, step_size, params)
+ new_values, (xs_next, xs_carry) = _scan(scanner, init=(t, q, p), xs=jnp.arange(L))
+ t, q, p = new_values
+ return t, q, p
+
@partial(jit, static_argnums=(2))
def step_heun(t, x, dfx, dt, params, x_scale=1.):
"""