Skip to content

Commit d26b417

Browse files
author
Alexander Ororbia
committed
refactored bcm-syn and test passed
1 parent 0d1a35f commit d26b417

File tree

3 files changed

+61
-95
lines changed

3 files changed

+61
-95
lines changed

ngclearn/components/synapses/hebbian/BCMSynapse.py

Lines changed: 44 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
42
from ngcsimlib.compartment import Compartment
3+
from ngcsimlib.parser import compilable
54

6-
from ngclearn.components.synapses import DenseSynapse
7-
from ngclearn.utils import tensorstats
5+
from ngclearn.components.synapses.denseSynapse import DenseSynapse
86

97
class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable
108
"""
@@ -71,8 +69,7 @@ def __init__(
7169
self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0., weight_init=None, resist_scale=1.,
7270
p_conn=1., batch_size=1, **kwargs
7371
):
74-
super().__init__(name, shape, weight_init, None, resist_scale, p_conn,
75-
batch_size=batch_size, **kwargs)
72+
super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
7673

7774
## Synapse and BCM hyper-parameters
7875
self.shape = shape ## shape of synaptic efficacy matrix
@@ -90,48 +87,51 @@ def __init__(
9087
self.post = Compartment(postVals) ## post-synaptic statistic
9188
self.post_term = Compartment(postVals)
9289
self.theta = Compartment(postVals + self.theta0) ## synaptic modification thresholds
93-
self.dWeights = Compartment(self.weights.value * 0)
90+
self.dWeights = Compartment(self.weights.get() * 0)
9491

95-
@transition(output_compartments=["weights", "theta", "dWeights", "post_term"])
96-
@staticmethod
97-
def evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
92+
@compilable
93+
def evolve(self, t, dt): #t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
9894
eps = 1e-7
99-
post_term = post * (post - theta) # post - theta
100-
post_term = post_term * (1. / (theta + eps))
101-
dWeights = jnp.matmul(pre.T, post_term)
102-
if w_bound > 0.:
103-
dWeights = dWeights * (w_bound - jnp.abs(weights))
95+
post_term = self.post.get() * (self.post.get() - self.theta.get()) # post - theta
96+
post_term = post_term * (1. / (self.theta.get() + eps))
97+
dWeights = jnp.matmul(self.pre.get().T, post_term)
98+
if self.w_bound > 0.:
99+
dWeights = dWeights * (self.w_bound - jnp.abs(self.weights.get()))
104100
## update synaptic efficacies according to a leaky ODE
105-
dWeights = -weights * w_decay + dWeights
106-
_W = weights + dWeights * dt / tau_w
101+
dWeights = -self.weights.get() * self.w_decay + dWeights
102+
_W = self.weights.get() + dWeights * dt / self.tau_w
107103
## update synaptic modification threshold as a leaky ODE
108-
dtheta = jnp.mean(jnp.square(post), axis=0, keepdims=True) ## batch avg
109-
theta = theta + (-theta + dtheta) * dt / tau_theta
110-
return weights, theta, dWeights, post_term
111-
112-
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "post_term"])
113-
@staticmethod
114-
def reset(batch_size, shape):
115-
preVals = jnp.zeros((batch_size, shape[0]))
116-
postVals = jnp.zeros((batch_size, shape[1]))
117-
inputs = preVals
118-
outputs = postVals
119-
pre = preVals
120-
post = postVals
121-
dWeights = jnp.zeros(shape)
122-
post_term = postVals
123-
return inputs, outputs, pre, post, dWeights, post_term
124-
125-
def save(self, directory, **kwargs):
126-
file_name = directory + "/" + self.name + ".npz"
127-
jnp.savez(file_name,
128-
weights=self.weights.value, theta=self.theta.value)
129-
130-
def load(self, directory, **kwargs):
131-
file_name = directory + "/" + self.name + ".npz"
132-
data = jnp.load(file_name)
133-
self.weights.set(data['weights'])
134-
self.theta.set(data['theta'])
104+
dtheta = jnp.mean(jnp.square(self.post.get()), axis=0, keepdims=True) ## batch avg
105+
theta = self.theta.get() + (-self.theta.get() + dtheta) * dt / self.tau_theta
106+
107+
#self.weights.set(weights)
108+
self.theta.set(theta)
109+
self.dWeights.set(dWeights)
110+
self.post_term.set(post_term)
111+
112+
@compilable
113+
def reset(self):
114+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
115+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
116+
117+
if not self.inputs.targeted:
118+
self.inputs.set(preVals)
119+
self.outputs.set(postVals)
120+
self.pre.set(preVals)
121+
self.post.set(postVals)
122+
self.dWeights.set(jnp.zeros(self.shape.get()))
123+
self.post_term.set(postVals)
124+
125+
# def save(self, directory, **kwargs):
126+
# file_name = directory + "/" + self.name + ".npz"
127+
# jnp.savez(file_name,
128+
# weights=self.weights.value, theta=self.theta.value)
129+
#
130+
# def load(self, directory, **kwargs):
131+
# file_name = directory + "/" + self.name + ".npz"
132+
# data = jnp.load(file_name)
133+
# self.weights.set(data['weights'])
134+
# self.theta.set(data['theta'])
135135

136136
@classmethod
137137
def help(cls): ## component help function
@@ -175,21 +175,6 @@ def help(cls): ## component help function
175175
"hyperparameters": hyperparams}
176176
return info
177177

178-
def __repr__(self):
179-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
180-
maxlen = max(len(c) for c in comps) + 5
181-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
182-
for c in comps:
183-
stats = tensorstats(getattr(self, c).value)
184-
if stats is not None:
185-
line = [f"{k}: {v}" for k, v in stats.items()]
186-
line = ", ".join(line)
187-
else:
188-
line = "None"
189-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
190-
return lines
191-
192-
193178
if __name__ == '__main__':
194179
from ngcsimlib.context import Context
195180
with Context("Bar") as bar:
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .hebbianSynapse import HebbianSynapse
1+
#from .hebbianSynapse import HebbianSynapse
22
from .traceSTDPSynapse import TraceSTDPSynapse
3-
from .expSTDPSynapse import ExpSTDPSynapse
4-
from .eventSTDPSynapse import EventSTDPSynapse
3+
#from .expSTDPSynapse import ExpSTDPSynapse
4+
#from .eventSTDPSynapse import EventSTDPSynapse
55
from .BCMSynapse import BCMSynapse
66

tests/components/synapses/hebbian/test_BCMSynapse.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
from ngcsimlib.context import Context
33
import numpy as np
44
np.random.seed(42)
5-
from ngclearn.components import BCMSynapse
6-
from ngcsimlib.compilers import compile_command, wrap_command
7-
from numpy.testing import assert_array_equal
85

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
6+
from ngclearn import Context, MethodProcess
7+
import ngclearn.utils.weight_distribution as dist
8+
from ngclearn.components.synapses.hebbian.BCMSynapse import BCMSynapse
9+
from numpy.testing import assert_array_equal
1310

1411
def test_BCMSynapse1():
1512
name = "bcm_stdp_ctx"
@@ -23,42 +20,26 @@ def test_BCMSynapse1():
2320
name="a", shape=(1,1), tau_w=40., tau_theta=20., key=subkeys[0]
2421
)
2522

26-
#"""
27-
evolve_process = (Process("evolve_proc")
23+
evolve_process = (MethodProcess("evolve_process")
2824
>> a.evolve)
29-
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
30-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
3125

32-
advance_process = (Process("advance_proc")
26+
advance_process = (MethodProcess("advance_proc")
3327
>> a.advance_state)
34-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
35-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3628

37-
reset_process = (Process("reset_proc")
29+
reset_process = (MethodProcess("reset_proc")
3830
>> a.reset)
39-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
40-
#"""
41-
42-
"""
43-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
44-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
45-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
46-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
47-
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
48-
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
49-
"""
5031

5132
pre_value = jnp.ones((1, 1)) * 0.425
5233
post_value = jnp.ones((1, 1)) * 1.55
5334

5435
truth = jnp.array([[-1.6798127]])
55-
ctx.reset()
36+
reset_process.run() # ctx.reset()
5637
a.pre.set(pre_value)
5738
a.post.set(post_value)
58-
ctx.run(t=1., dt=dt)
59-
ctx.adapt(t=1., dt=dt)
60-
#print(a.dWeights.value)
61-
assert_array_equal(a.dWeights.value, truth)
62-
39+
advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
40+
evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
41+
# print(a.dWeights.get())
42+
# print(truth)
43+
assert_array_equal(a.dWeights.get(), truth)
6344

64-
#test_BCMSynapse1()
45+
test_BCMSynapse1()

0 commit comments

Comments
 (0)