diff --git a/README.md b/README.md index 7e43e7fa..1e7412fd 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ $ python install -e . **Version:**
-2.0.0 +2.0.1 Author: Alexander G. Ororbia II
diff --git a/docs/museum/index.rst b/docs/museum/index.rst index 6f534770..25f75dfc 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -18,3 +18,4 @@ relevant, referenced publicly available ngc-learn simulation code. snn_dc snn_bfa sindy + rl_snn diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md new file mode 100644 index 00000000..f0e2cb0d --- /dev/null +++ b/docs/museum/rl_snn.md @@ -0,0 +1,35 @@ +# Reinforcement Learning through a Spiking Controller + +In this exhibit, we will see how to construct a simple biophysical model for +reinforcement learning with a spiking neural network and modulated +spike-timing-dependent plasticity. +This model incorporates mechanisms from several different models, including +the constrained RL-centric SNN of [1] as well as the simplifications +made with respect to the model of [2]. The model code for this +exhibit can be found +[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn). + +## Modeling Operant Conditioning through Modulation + + +### Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP) + + +## The Spiking Neural Circuit Model + + +### Neuronal Dynamics + + +## Running the RL-SNN Model + + + +## References +[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse +and delayed rewards with a multilayer spiking neural network." 2020 International
Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit +recognition using spike-timing-dependent plasticity." Frontiers in computational +neuroscience 9 (2015): 99. + diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 114d3785..29b5f267 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -65,6 +65,12 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): self.modulator = Compartment(restVals + 1.0) # to be set/consumed self.mask = Compartment(restVals + 1.0) + @staticmethod + def eval_log_density(target, mu, Sigma): + _dmu = (target - mu) + log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + return log_density + @transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"]) @staticmethod def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output @@ -79,6 +85,7 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e dtarget = -dmu # reverse of e dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + #L = GaussianErrorCell.eval_log_density(target, mu, Sigma) dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... dtarget = dtarget * modulator * mask diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py index 52211130..30ddb2d4 100755 --- a/ngclearn/utils/diffeq/ode_utils.py +++ b/ngclearn/utils/diffeq/ode_utils.py @@ -112,6 +112,30 @@ def _euler(carry, dfx, dt, params, x_scale=1.): new_carry = (_t, _x) return new_carry, (new_carry, carry) +@partial(jit, static_argnums=(1)) +def _leapfrog(carry, dfq, dt, params): + t, q, p = carry + dq_dt = dfq(t, q, params) + + _p = p + dq_dt * (dt/2.) 
+ _q = q + _p * dt + dq_dtpdt = dfq(t+dt, _q, params) + _p = _p + dq_dtpdt * (dt/2.) + _t = t + dt + new_carry = (_t, _q, _p) + return new_carry, (new_carry, carry) + +@partial(jit, static_argnums=(3, 4)) +def leapfrog(t_curr, q_curr, p_curr, dfq, L, step_size, params): + t = t_curr + 0. + q = q_curr + 0. + p = p_curr + 0. + def scanner(carry, _): + return _leapfrog(carry, dfq, step_size, params) + new_values, (xs_next, xs_carry) = _scan(scanner, init=(t, q, p), xs=jnp.arange(L)) + t, q, p = new_values + return t, q, p + @partial(jit, static_argnums=(2)) def step_heun(t, x, dfx, dt, params, x_scale=1.): """