Skip to content

Commit e3791df

Browse files
committed
push hebbian synapse
1 parent bc713f6 commit e3791df

File tree

2 files changed

+90
-19
lines changed

2 files changed

+90
-19
lines changed

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import partial
33
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
44
from ngclearn import resolver, Component, Compartment
5+
from ngcsimlib.compilers.process import transition
56
from ngclearn.components.synapses import DenseSynapse
67
from ngclearn.utils import tensorstats
78
from ngcsimlib.deprecators import deprecate_args
@@ -216,8 +217,9 @@ def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda
216217
post_wght=post_wght)
217218
return dW, db
218219

220+
@transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
219221
@staticmethod
220-
def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
222+
def evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
221223
post_wght, bias_init, pre, post, weights, biases, opt_params):
222224
## calculate synaptic update values
223225
dWeights, dBiases = HebbianSynapse._compute_update(
@@ -234,16 +236,9 @@ def _evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, p
234236
weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative)
235237
return opt_params, weights, biases, dWeights, dBiases
236238

237-
@resolver(_evolve)
238-
def evolve(self, opt_params, weights, biases, dWeights, dBiases):
239-
self.opt_params.set(opt_params)
240-
self.weights.set(weights)
241-
self.biases.set(biases)
242-
self.dWeights.set(dWeights)
243-
self.dBiases.set(dBiases)
244-
239+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
245240
@staticmethod
246-
def _reset(batch_size, shape):
241+
def reset(batch_size, shape):
247242
preVals = jnp.zeros((batch_size, shape[0]))
248243
postVals = jnp.zeros((batch_size, shape[1]))
249244
return (
@@ -255,15 +250,6 @@ def _reset(batch_size, shape):
255250
jnp.zeros(shape[1]), # db
256251
)
257252

258-
@resolver(_reset)
259-
def reset(self, inputs, outputs, pre, post, dWeights, dBiases):
260-
self.inputs.set(inputs)
261-
self.outputs.set(outputs)
262-
self.pre.set(pre)
263-
self.post.set(post)
264-
self.dWeights.set(dWeights)
265-
self.dBiases.set(dBiases)
266-
267253
@classmethod
268254
def help(cls): ## component help function
269255
properties = {
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import HebbianSynapse
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_hebbianSynapse():
19+
np.random.seed(42)
20+
name = "hebbian_synapse_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
25+
# model hyper
26+
shape = (10, 5)
27+
batch_size = 1
28+
resist_scale = 1.0
29+
30+
with Context(name) as ctx:
31+
a = HebbianSynapse(
32+
name="a",
33+
shape=shape,
34+
resist_scale=resist_scale,
35+
batch_size=batch_size,
36+
prior = ("gaussian", 0.01)
37+
)
38+
39+
advance_process = (Process() >> a.advance_state)
40+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
41+
reset_process = (Process() >> a.reset)
42+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
43+
evolve_process = (Process() >> a.evolve)
44+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
45+
46+
# Compile and add commands
47+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
48+
# ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
49+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
50+
# ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
51+
# evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
52+
# ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
53+
54+
@Context.dynamicCommand
55+
def clamp_inputs(x):
56+
a.inputs.set(x)
57+
58+
@Context.dynamicCommand
59+
def clamp_pre(x):
60+
a.pre.set(x)
61+
62+
@Context.dynamicCommand
63+
def clamp_post(x):
64+
a.post.set(x)
65+
66+
# Test input sequence
67+
# Initial weights
68+
a.weights.set(jnp.ones((10, 5)) * 0.5)
69+
70+
in_pre = jnp.ones((1, 10)) * 1.0
71+
in_post = jnp.ones((1, 5)) * 0.75
72+
73+
ctx.reset()
74+
clamp_pre(in_pre)
75+
clamp_post(in_post)
76+
ctx.run(t=1. * dt, dt=dt)
77+
ctx.evolve(t=1. * dt, dt=dt)
78+
79+
print(a.weights.value)
80+
81+
# Basic assertions to check learning dynamics
82+
assert a.weights.value.shape == (10, 5), ""
83+
assert a.weights.value[0, 0] == 0.5, ""
84+
85+
# test_hebbianSynapse()

0 commit comments

Comments
 (0)