Skip to content

Commit 7c56b47

Browse files
author
Alexander Ororbia
committed
revised dyn/chem-syn neurocog doc, cleaned up dynamic syn
1 parent 588e3f5 commit 7c56b47

File tree

5 files changed

+52
-107
lines changed

5 files changed

+52
-107
lines changed

docs/tutorials/neurocog/dynamic_synapses.md

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
In this lesson, we will study dynamic synapses, or synaptic cable components in
44
ngc-learn that evolve on fast time-scales in response to their pre-synaptic inputs.
55
These types of chemical synapse components are useful for modeling time-varying
6-
conductance which ultimately drives eletrical current input into neuronal units
6+
conductance which ultimately drives electrical current input into neuronal units
77
(such as spiking cells). Here, we will learn how to build three important types of dynamic synapses in
88
ngc-learn -- the exponential, the alpha, and the double-exponential synapse -- and visualize
99
the time-course of their resulting conductances. In addition, we will then
@@ -24,17 +24,14 @@ value matrices we might initially employ (as in synapse components such as the
2424
Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse),
2525
the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place
2626
the code you will write within.
27-
For the first part of this lesson, we will import all three dynamic synpapse models and compare their behavior.
27+
For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior.
2828
This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics):
2929

3030
```python
3131
from jax import numpy as jnp, random, jit
32-
from ngcsimlib.context import Context
33-
from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoupleExpSynapse
34-
35-
from ngcsimlib.compilers.process import Process
36-
from ngcsimlib.context import Context
37-
import ngclearn.utils.weight_distribution as dist
32+
from ngclearn import Context, MethodProcess
33+
from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoubleExpSynapse
34+
from ngclearn.utils.distribution_generator import DistributionGenerator
3835

3936

4037
dkey = random.PRNGKey(1234) ## creating seeding keys for synapses
@@ -46,29 +43,27 @@ T = 8. # ms ## total duration time
4643
with Context("dual_syn_system") as ctx:
4744
Wexp = ExponentialSynapse( ## exponential dynamic synapse
4845
name="Wexp", shape=(1, 1), tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
49-
weight_init=dist.constant(value=1.), key=subkeys[0]
46+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
5047
)
5148
Walpha = AlphaSynapse( ## alpha dynamic synapse
5249
name="Walpha", shape=(1, 1), tau_decay=1., g_syn_bar=1., syn_rest=0., resist_scale=1.,
53-
weight_init=dist.constant(value=1.), key=subkeys[0]
50+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
5451
)
55-
Wexp2 = DoupleExpSynapse(
52+
Wexp2 = DoubleExpSynapse(
5653
name="Wexp2", shape=(1, 1), tau_rise=1., tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
57-
weight_init=dist.constant(value=1.), key=subkeys[0]
54+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
5855
)
5956

6057
## set up basic simulation process calls
61-
advance_process = (Process("advance_proc")
58+
advance_process = (MethodProcess("advance_proc")
6259
>> Wexp.advance_state
6360
>> Walpha.advance_state
6461
>> Wexp2.advance_state)
65-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
6662

67-
reset_process = (Process("reset_proc")
63+
reset_process = (MethodProcess("reset_proc")
6864
>> Wexp.reset
6965
>> Walpha.reset
7066
>> Wexp2.reset)
71-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
7267
```
7368

7469
where we notice in the above we have instantiated three different kinds of chemical synapse components
@@ -90,7 +85,7 @@ $$
9085
$$
9186

9287
where the conductance (for a post-synaptic unit) output of this synapse is driven by a sum over all of its incoming
93-
pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an expoential kernel (i.e., a low-pass filter).
88+
pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an exponential kernel (i.e., a low-pass filter).
9489
On the other hand, for the alpha synapse, the dynamics adhere to the following coupled set of ODEs:
9590

9691
$$
@@ -100,7 +95,7 @@ $$
10095

10196
where $h_{\text{syn}}(t)$ is an intermediate variable that operates in service of driving the conductance variable $g_{\text{syn}}(t)$ itself.
10297
The double-exponential (or difference of exponentials) synapse model looks similar to the alpha synapse except that the
103-
rise and fall/decay of its condutance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$,
98+
rise and fall/decay of its conductance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$,
10499
as follows:
105100

106101
$$
@@ -128,29 +123,31 @@ time_span = []
128123
g = []
129124
ga = []
130125
gexp2 = []
131-
ctx.reset()
126+
reset_process.run()
132127
Tsteps = int(T/dt) + 1
133128
for t in range(Tsteps):
134129
s_t = jnp.zeros((1, 1))
135130
if t * dt == 1.: ## pulse at 1 ms
136131
s_t = jnp.ones((1, 1))
137132
Wexp.inputs.set(s_t)
138133
Walpha.inputs.set(s_t)
139-
Wexp.v.set(Wexp.v.value * 0)
134+
Wexp.v.set(Wexp.v.get() * 0)
140135
Wexp2.inputs.set(s_t)
141-
Walpha.v.set(Walpha.v.value * 0)
142-
Wexp2.v.set(Wexp2.v.value * 0)
143-
ctx.run(t=t * dt, dt=dt)
144-
145-
print(f"\r g = {Wexp.g_syn.value} ga = {Walpha.g_syn.value} gexp2 = {Wexp2.g_syn.value}", end="")
146-
g.append(Wexp.g_syn.value)
147-
ga.append(Walpha.g_syn.value)
136+
Walpha.v.set(Walpha.v.get() * 0)
137+
Wexp2.v.set(Wexp2.v.get() * 0)
138+
advance_process.run(t=t * dt, dt=dt)
139+
140+
print(f"\r g = {Wexp.g_syn.get()} ga = {Walpha.g_syn.get()} gexp2 = {Wexp2.g_syn.get()}", end="")
141+
g.append(Wexp.g_syn.get())
142+
ga.append(Walpha.g_syn.get())
143+
gexp2.append(Wexp2.g_syn.get())
148144
time_span.append(t) #* dt)
149145
print()
150146
g = jnp.squeeze(jnp.concatenate(g, axis=1))
151147
g = g/jnp.amax(g)
152148
ga = jnp.squeeze(jnp.concatenate(ga, axis=1))
153149
ga = ga/jnp.amax(ga)
150+
gexp2 = jnp.squeeze(jnp.concatenate(gexp2, axis=1))
154151
gexp2 = gexp2/jnp.amax(gexp2)
155152
```
156153

@@ -195,6 +192,9 @@ ax.grid(which="major")
195192
fig.savefig("alpha_syn.jpg")
196193
plt.close()
197194

195+
## ---- plot the double-exponential synapse conductance time-course ----
196+
fig, ax = plt.subplots()
197+
198198
gvals = ax.plot(time_span, gexp2, '-', color='tab:blue')
199199
#plt.xticks(time_span, time_labs)
200200
ax.set_xticks(time_ticks, time_labs)
@@ -207,7 +207,7 @@ plt.close()
207207
```
208208

209209
which should produce and save three plots to disk. You can then compare and contrast the plots of the
210-
expoential, alpha synapse, and double-exponential conductance trajectories:
210+
exponential, alpha synapse, and double-exponential conductance trajectories:
211211

212212
```{eval-rst}
213213
.. table::
@@ -222,7 +222,7 @@ expoential, alpha synapse, and double-exponential conductance trajectories:
222222

223223
Note that the alpha synapse (right figure) would produce a more realistic fit to recorded synaptic currents (as it attempts to model
224224
the rise and fall of current in a less simplified manner) at the cost of extra compute, given it uses two ODEs to
225-
emulate condutance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure).
225+
emulate conductance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure).
226226

227227
## Excitatory-Inhibitory Driven Dynamics
228228

@@ -243,13 +243,10 @@ We will specifically model the excitatory and inhibitory conductance changes usi
243243

244244
```python
245245
from jax import numpy as jnp, random, jit
246-
from ngcsimlib.context import Context
246+
from ngclearn import Context, MethodProcess
247+
from ngclearn.operations import Summation
247248
from ngclearn.components import ExponentialSynapse, PoissonCell, LIFCell
248-
from ngclearn.operations import summation
249-
250-
from ngcsimlib.compilers.process import Process
251-
from ngcsimlib.context import Context
252-
import ngclearn.utils.weight_distribution as dist
249+
from ngclearn.utils.distribution_generator import DistributionGenerator
253250

254251
## create seeding keys
255252
dkey = random.PRNGKey(1234)
@@ -287,39 +284,36 @@ with Context("ei_snn") as ctx:
287284
pre_inh = PoissonCell("pre_inh", n_units=n_inh, target_freq=inh_freq, key=subkeys[1]) ## pre-syn inhibitory group
288285
Wexc = ExponentialSynapse( ## dynamic synapse between excitatory group and LIF
289286
name="Wexc", shape=(n_exc,1), tau_decay=tau_syn_exc, g_syn_bar=g_e_bar, syn_rest=E_rest_exc, resist_scale=1./g_L,
290-
weight_init=dist.constant(value=1.), key=subkeys[2]
287+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2]
291288
)
292289
Winh = ExponentialSynapse( ## dynamic synapse between inhibitory group and LIF
293290
name="Winh", shape=(n_inh, 1), tau_decay=tau_syn_inh, g_syn_bar=g_i_bar, syn_rest=E_rest_inh, resist_scale=1./g_L,
294-
weight_init=dist.constant(value=1.), key=subkeys[2]
291+
weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2]
295292
)
296293
post_exc = LIFCell( ## post-syn LIF cell
297294
"post_exc", n_units=1, tau_m=tau_m, resist_m=1., thr=v_thr, v_rest=v_rest, conduct_leak=1., v_reset=-75.,
298295
tau_theta=0., theta_plus=0., refract_time=2., key=subkeys[3]
299296
)
300297

301-
Wexc.inputs << pre_exc.outputs
302-
Winh.inputs << pre_inh.outputs
303-
Wexc.v << post_exc.v ## couple voltage to exc synapse
304-
Winh.v << post_exc.v ## couple voltage to inh synapse
305-
post_exc.j << summation(Wexc.i_syn, Winh.i_syn) ## sum together excitatory & inhibitory pressures
298+
pre_exc.outputs >> Wexc.inputs
299+
pre_inh.outputs >> Winh.inputs
300+
post_exc.v >> Wexc.v ## couple voltage to exc synapse
301+
post_exc.v >> Winh.v ## couple voltage to inh synapse
302+
Summation(Wexc.i_syn, Winh.i_syn) >> post_exc.j ## sum together excitatory & inhibitory pressures
306303

307-
advance_process = (Process("advance_proc")
304+
advance_process = (MethodProcess("advance_proc")
308305
>> pre_exc.advance_state
309306
>> pre_inh.advance_state
310307
>> Wexc.advance_state
311308
>> Winh.advance_state
312309
>> post_exc.advance_state)
313-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
314-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
315310

316-
reset_process = (Process("reset_proc")
311+
reset_process = (MethodProcess("reset_proc")
317312
>> pre_exc.reset
318313
>> pre_inh.reset
319314
>> Wexc.reset
320315
>> Winh.reset
321316
>> post_exc.reset)
322-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
323317
```
324318

325319
### Examining the Simple Spiking Circuit's Behavior
@@ -331,18 +325,18 @@ volts = []
331325
time_span = []
332326
spikes = []
333327

334-
ctx.reset()
328+
reset_process.run()
335329
pre_exc.inputs.set(jnp.ones((1, n_exc)))
336330
pre_inh.inputs.set(jnp.ones((1, n_inh)))
337-
post_exc.v.set(post_exc.v.value * 0 - 65.) ## initial condition for LIF is -65 mV
338-
volts.append(post_exc.v.value)
331+
post_exc.v.set(post_exc.v.get() * 0 - 65.) ## initial condition for LIF is -65 mV
332+
volts.append(post_exc.v.get())
339333
time_span.append(0.)
340334
Tsteps = int(T/dt) + 1
341335
for t in range(1, Tsteps):
342-
ctx.run(t=t * dt, dt=dt)
343-
print(f"\r v {post_exc.v.value}", end="")
344-
volts.append(post_exc.v.value)
345-
spikes.append(post_exc.s.value)
336+
advance_process.run(t=t * dt, dt=dt)
337+
print(f"\r v {post_exc.v.get()}", end="")
338+
volts.append(post_exc.v.get())
339+
spikes.append(post_exc.s.get())
346340
time_span.append(t) #* dt)
347341
print()
348342
volts = jnp.squeeze(jnp.concatenate(volts, axis=1))
@@ -384,9 +378,7 @@ ax.grid()
384378
fig.savefig("ei_circuit_dynamics.jpg")
385379
```
386380

387-
which should produce a figure depicting dynamics similar to the one below. Black tick
388-
marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's
389-
voltage threshold.
381+
which should produce a figure depicting dynamics similar to the one below. Black tick marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's voltage threshold.
390382

391383

392384
```{eval-rst}

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class PhasorCell(JaxComponent):
3232
"""
3333

3434
def __init__(
35-
self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs):
35+
self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs
36+
):
3637
super().__init__(name, **kwargs)
3738

3839
## Phasor meta-parameters

ngclearn/components/synapses/alphaSynapse.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.utils.weight_distribution import initialize_params
3-
from ngcsimlib.logger import info
42

53
from ngclearn.components.synapses import DenseSynapse
64
from ngcsimlib.compartment import Compartment
@@ -115,20 +113,6 @@ def reset(self):
115113
self.h_syn.set(postVals)
116114
self.v.set(postVals)
117115

118-
# def save(self, directory, **kwargs):
119-
# file_name = directory + "/" + self.name + ".npz"
120-
# if self.bias_init != None:
121-
# jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
122-
# else:
123-
# jnp.savez(file_name, weights=self.weights.value)
124-
#
125-
# def load(self, directory, **kwargs):
126-
# file_name = directory + "/" + self.name + ".npz"
127-
# data = jnp.load(file_name)
128-
# self.weights.set(data['weights'])
129-
# if "biases" in data.keys():
130-
# self.biases.set(data['biases'])
131-
132116
@classmethod
133117
def help(cls): ## component help function
134118
properties = {

ngclearn/components/synapses/doubleExpSynapse.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.utils.weight_distribution import initialize_params
3-
from ngcsimlib.logger import info
42

53
from ngclearn.components.synapses import DenseSynapse
64
from ngcsimlib.compartment import Compartment
@@ -85,7 +83,7 @@ def __init__(
8583
self.weights.set(self.weights.get() * 0 + 1.)
8684

8785
@compilable
88-
def advance_state(self, t, dt): #dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
86+
def advance_state(self, t, dt):
8987
s = self.inputs.get()
9088
#A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay))
9189
A = 1. ## FIXME: scale factor to use?
@@ -121,20 +119,6 @@ def reset(self):
121119
self.h_syn.set(postVals)
122120
self.v.set(postVals)
123121

124-
# def save(self, directory, **kwargs):
125-
# file_name = directory + "/" + self.name + ".npz"
126-
# if self.bias_init != None:
127-
# jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
128-
# else:
129-
# jnp.savez(file_name, weights=self.weights.value)
130-
#
131-
# def load(self, directory, **kwargs):
132-
# file_name = directory + "/" + self.name + ".npz"
133-
# data = jnp.load(file_name)
134-
# self.weights.set(data['weights'])
135-
# if "biases" in data.keys():
136-
# self.biases.set(data['biases'])
137-
138122
@classmethod
139123
def help(cls): ## component help function
140124
properties = {

ngclearn/components/synapses/exponentialSynapse.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.utils.weight_distribution import initialize_params
3-
from ngcsimlib.logger import info
42

53
from ngclearn.components.synapses import DenseSynapse
64
from ngcsimlib.compartment import Compartment
@@ -107,20 +105,6 @@ def reset(self):
107105
self.g_syn.set(postVals)
108106
self.v.set(postVals)
109107

110-
# def save(self, directory, **kwargs):
111-
# file_name = directory + "/" + self.name + ".npz"
112-
# if self.bias_init != None:
113-
# jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
114-
# else:
115-
# jnp.savez(file_name, weights=self.weights.value)
116-
#
117-
# def load(self, directory, **kwargs):
118-
# file_name = directory + "/" + self.name + ".npz"
119-
# data = jnp.load(file_name)
120-
# self.weights.set(data['weights'])
121-
# if "biases" in data.keys():
122-
# self.biases.set(data['biases'])
123-
124108
@classmethod
125109
def help(cls): ## component help function
126110
properties = {

0 commit comments

Comments
 (0)