Skip to content

Commit 5272fdc

Browse files
author
Alexander Ororbia
committed
refactored bcm syn w/ unit-test
1 parent 29b49ff commit 5272fdc

File tree

3 files changed

+93
-22
lines changed

3 files changed

+93
-22
lines changed

docs/modeling/synapses.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ used for fixed value deconvolution/transposed convolution synaptic filters.
6363
:noindex:
6464
```
6565

66+
## Dynamic Synapse Types
67+
68+
### Short-Term Plasticity(Dense) Synapse
69+
70+
This synapse performs a linear transform of its input signals. Note that this
71+
synapse is "dynamic" in the sense that it engages in short-term plasticity (STP), meaning that its efficacy values change as a function of its inputs (and simulated consumed resources), but it does not provide any long-term form of plasticity/adjustment.
72+
73+
```{eval-rst}
74+
.. autoclass:: ngclearn.components.STPDenseSynapse
75+
:noindex:
76+
77+
.. automethod:: advance_state
78+
:noindex:
79+
.. automethod:: reset
80+
:noindex:
81+
```
82+
6683
## Multi-Factor Learning Synapse Types
6784

6885
Hebbian rules operate in a local manner -- they generally use information more

ngclearn/components/synapses/hebbian/BCMSynapse.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
36
from ngclearn.components.synapses import DenseSynapse
47
from ngclearn.utils import tensorstats
58

@@ -64,8 +67,10 @@ class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable
6467
"""
6568

6669
# Define Functions
67-
def __init__(self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0.,
68-
weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs):
70+
def __init__(
71+
self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0., weight_init=None, resist_scale=1.,
72+
p_conn=1., batch_size=1, **kwargs
73+
):
6974
super().__init__(name, shape, weight_init, None, resist_scale, p_conn,
7075
batch_size=batch_size, **kwargs)
7176

@@ -87,8 +92,9 @@ def __init__(self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_deca
8792
self.theta = Compartment(postVals + self.theta0) ## synaptic modification thresholds
8893
self.dWeights = Compartment(self.weights.value * 0)
8994

95+
@transition(output_compartments=["weights", "theta", "dWeights", "post_term"])
9096
@staticmethod
91-
def _evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
97+
def evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
9298
eps = 1e-7
9399
post_term = post * (post - theta) # post - theta
94100
post_term = post_term * (1. / (theta + eps))
@@ -101,18 +107,11 @@ def _evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights
101107
## update synaptic modification threshold as a leaky ODE
102108
dtheta = jnp.mean(jnp.square(post), axis=0, keepdims=True) ## batch avg
103109
theta = theta + (-theta + dtheta) * dt / tau_theta
104-
105110
return weights, theta, dWeights, post_term
106111

107-
@resolver(_evolve)
108-
def evolve(self, weights, theta, dWeights, post_term):
109-
self.weights.set(weights)
110-
self.theta.set(theta)
111-
self.dWeights.set(dWeights)
112-
self.post_term.set(post_term)
113-
112+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "post_term"])
114113
@staticmethod
115-
def _reset(batch_size, shape):
114+
def reset(batch_size, shape):
116115
preVals = jnp.zeros((batch_size, shape[0]))
117116
postVals = jnp.zeros((batch_size, shape[1]))
118117
inputs = preVals
@@ -123,15 +122,6 @@ def _reset(batch_size, shape):
123122
post_term = postVals
124123
return inputs, outputs, pre, post, dWeights, post_term
125124

126-
@resolver(_reset)
127-
def reset(self, inputs, outputs, pre, post, dWeights, post_term):
128-
self.inputs.set(inputs)
129-
self.outputs.set(outputs)
130-
self.pre.set(pre)
131-
self.post.set(post)
132-
self.dWeights.set(dWeights)
133-
self.post_term.set(post_term)
134-
135125
def save(self, directory, **kwargs):
136126
file_name = directory + "/" + self.name + ".npz"
137127
jnp.savez(file_name,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import BCMSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
14+
def test_BCMSynapse1():
15+
name = "bcm_stdp_ctx"
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
# ---- build a simple Poisson cell system ----
21+
with Context(name) as ctx:
22+
a = BCMSynapse(
23+
name="a", shape=(1,1), tau_w=40., tau_theta=20., key=subkeys[0]
24+
)
25+
26+
#"""
27+
evolve_process = (Process()
28+
>> a.evolve)
29+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
30+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
31+
32+
advance_process = (Process()
33+
>> a.advance_state)
34+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
35+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
36+
37+
reset_process = (Process()
38+
>> a.reset)
39+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
40+
#"""
41+
42+
"""
43+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
44+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
45+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
46+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
47+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
48+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
49+
"""
50+
51+
pre_value = jnp.ones((1, 1)) * 0.425
52+
post_value = jnp.ones((1, 1)) * 1.55
53+
54+
truth = jnp.array([[-1.6798127]])
55+
ctx.reset()
56+
a.pre.set(pre_value)
57+
a.post.set(post_value)
58+
ctx.run(t=1., dt=dt)
59+
ctx.adapt(t=1., dt=dt)
60+
#print(a.dWeights.value)
61+
assert_array_equal(a.dWeights.value, truth)
62+
63+
64+
#test_BCMSynapse1()

0 commit comments

Comments
 (0)