Skip to content

Commit c80f2b5

Browse files
committed
update components and their related test cases
1 parent b96139f commit c80f2b5

File tree

8 files changed

+72
-115
lines changed

8 files changed

+72
-115
lines changed

ngclearn/components/other/expKernel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def advance_state(self, t):
8686
def reset(self):
8787
restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp
8888
restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf
89-
not self.inputs.targeted and self.inputs.set(restVals)
89+
# BUG: the self.inputs here does not have the targeted field
90+
# NOTE: Quick workaround is to check if targeted is in the input or not
91+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
9092
self.epsp.set(restVals)
9193
self.tf.set(restTensor)
9294

ngclearn/components/other/varTrace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def advance_state(self, dt):
126126
@compilable
127127
def reset(self):
128128
restVals = jnp.zeros((self.batch_size, self.n_units))
129-
not self.inputs.targeted and self.inputs.set(restVals)
129+
# BUG: the self.inputs here does not have the targeted field
130+
# NOTE: Quick workaround is to check if targeted is in the input or not
131+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
130132
self.outputs.set(restVals)
131133
self.trace.set(restVals)
132134

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,18 @@ def evolve(self):
277277
self.dBiases.set(dBiases)
278278

279279
@compilable
280-
def reset(self, batch_size, shape):
281-
preVals = jnp.zeros((batch_size, shape[0]))
282-
postVals = jnp.zeros((batch_size, shape[1]))
283-
not self.inputs.targeted and self.inputs.set(preVals) # inputs
280+
def reset(self):
281+
preVals = jnp.zeros((self.batch_size, self.shape[0]))
282+
postVals = jnp.zeros((self.batch_size, self.shape[1]))
283+
# BUG: the self.inputs here does not have the targeted field
284+
# NOTE: Quick workaround is to check if targeted is in the input or not
285+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs
284286
self.outputs.set(postVals) # outputs
285287
self.pre.set(preVals) # pre
286288
self.post.set(postVals) # post
287-
self.dWeights.set(jnp.zeros(shape)) # dW
288-
self.dBiases.set(jnp.zeros(shape[1])) # db
289+
self.dWeights.set(jnp.zeros(self.shape)) # dW
290+
self.dBiases.set(jnp.zeros(self.shape[1])) # db
291+
289292

290293
@classmethod
291294
def help(cls): ## component help function

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,23 @@ def advance_state(self):
154154
self.outputs.set(outputs)
155155

156156
@compilable
157-
def reset(self, batch_size, shape):
158-
preVals = jnp.zeros((batch_size, shape[0]))
159-
postVals = jnp.zeros((batch_size, shape[1]))
157+
def reset(self):
158+
preVals = jnp.zeros((self.batch_size, self.shape[0]))
159+
postVals = jnp.zeros((self.batch_size, self.shape[1]))
160160
inputs = preVals
161161
outputs = postVals
162-
not self.inputs.targeted and self.inputs.set(inputs)
162+
# BUG: the self.inputs here does not have the targeted field
163+
# NOTE: Quick workaround is to check if targeted is in the input or not
164+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(inputs)
163165
self.outputs.set(outputs)
164166

165167
def save(self, directory, **kwargs):
166168
file_name = directory + "/" + self.name + ".npz"
167169
if self.bias_init != None:
168-
jnp.savez(file_name, weights=self.weights.value,
169-
biases=self.biases.value)
170+
jnp.savez(file_name, weights=self.weights.get(),
171+
biases=self.biases.get())
170172
else:
171-
jnp.savez(file_name, weights=self.weights.value)
173+
jnp.savez(file_name, weights=self.weights.get())
172174

173175
def load(self, directory, **kwargs):
174176
file_name = directory + "/" + self.name + ".npz"
Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
32
import numpy as np
43
np.random.seed(42)
54
from ngclearn.components import ExpKernel
6-
from ngcsimlib.compilers import compile_command, wrap_command
7-
from numpy.testing import assert_array_equal
8-
9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
5+
from ngclearn import MethodProcess, Context
146

157
def test_expKernel1():
168
name = "expKernel_ctx"
@@ -25,16 +17,12 @@ def test_expKernel1():
2517
name="a", n_units=1, dt=1., tau_w=500., nu=4., key=subkeys[0]
2618
)
2719

28-
advance_process = (Process("advance_proc")
20+
advance_process = (MethodProcess("advance_proc")
2921
>> a.advance_state)
30-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31-
32-
reset_process = (Process("reset_proc")
22+
reset_process = (MethodProcess("reset_proc")
3323
>> a.reset)
34-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3524

3625
## set up non-compiled utility commands
37-
@Context.dynamicCommand
3826
def clamp(x):
3927
a.inputs.set(x)
4028

@@ -44,16 +32,16 @@ def clamp(x):
4432
y_seq = jnp.asarray([[0., 1., 0.998002, 0.996008, 1.9940181]], dtype=jnp.float32)
4533

4634
outs = []
47-
ctx.reset()
35+
reset_process.run()
4836
for ts in range(x_seq.shape[1]):
4937
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
50-
ctx.clamp(x_t)
51-
ctx.run(t=ts * 1., dt=dt)
52-
outs.append(a.epsp.value)
38+
clamp(x_t)
39+
advance_process.run(t=ts * 1., dt=dt)
40+
outs.append(a.epsp.get())
5341
outs = jnp.concatenate(outs, axis=1)
5442
#print(outs)
5543

5644
## output should equal input
5745
np.testing.assert_allclose(outs, y_seq, atol=1e-8)
5846

59-
#test_expKernel1()
47+
test_expKernel1()
Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
32
import numpy as np
43
np.random.seed(42)
54
from ngclearn.components import VarTrace
6-
from ngcsimlib.compilers import compile_command, wrap_command
75
from numpy.testing import assert_array_equal
86

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
7+
from ngclearn import MethodProcess, Context
8+
149

1510
def test_varTrace1():
1611
name = "trace_ctx"
@@ -26,35 +21,32 @@ def test_varTrace1():
2621
key=subkeys[0]
2722
)
2823

29-
advance_process = (Process("advance_proc")
24+
advance_process = (MethodProcess("advance_proc")
3025
>> a.advance_state)
31-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3226

33-
reset_process = (Process("reset_proc")
27+
reset_process = (MethodProcess("reset_proc")
3428
>> a.reset)
35-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3629

3730
## set up non-compiled utility commands
38-
@Context.dynamicCommand
3931
def clamp(x):
4032
a.inputs.set(x)
4133

4234
## input spike train
4335
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
4436
## desired output pulses
45-
y_seq = x_seq * trace_increment
37+
y_seq = x_seq * trace_increment
4638

4739
outs = []
48-
ctx.reset()
40+
reset_process.run()
4941
for ts in range(x_seq.shape[1]):
5042
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
51-
ctx.clamp(x_t)
52-
ctx.run(t=ts * 1., dt=dt)
53-
outs.append(a.outputs.value)
43+
clamp(x_t)
44+
advance_process.run(t=ts * 1., dt=dt)
45+
outs.append(a.outputs.get())
5446
outs = jnp.concatenate(outs, axis=1)
5547
#print(outs)
5648

5749
## output should equal input
5850
assert_array_equal(outs, y_seq)
5951

60-
#test_varTrace1()
52+
test_varTrace1()
Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
11
# %%
22

33
from jax import numpy as jnp, random, jit
4-
from ngcsimlib.context import Context
54
import numpy as np
65
np.random.seed(42)
76
from ngclearn.components import HebbianPatchedSynapse
8-
from ngcsimlib.compilers import compile_command, wrap_command
97
from numpy.testing import assert_array_equal
108

11-
from ngcsimlib.compilers.process import Process, transition
12-
from ngcsimlib.component import Component
13-
from ngcsimlib.compartment import Compartment
14-
from ngcsimlib.context import Context
15-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16-
9+
from ngclearn import MethodProcess, Context
1710

1811
def test_hebbianPatchedSynapse():
1912
np.random.seed(42)
@@ -31,58 +24,45 @@ def test_hebbianPatchedSynapse():
3124

3225
with Context(name) as ctx:
3326
a = HebbianPatchedSynapse(
34-
name="a",
35-
shape=shape,
36-
n_sub_models=n_sub_models,
27+
name="a",
28+
shape=shape,
29+
n_sub_models=n_sub_models,
3730
stride_shape=stride_shape,
3831
resist_scale=resist_scale,
3932
batch_size=batch_size
4033
)
4134

42-
advance_process = (Process("advance_proc") >> a.advance_state)
43-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
44-
reset_process = (Process("reset_proc") >> a.reset)
45-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
46-
evolve_process = (Process("evolve_proc") >> a.evolve)
47-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
48-
49-
# Compile and add commands
50-
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
51-
# ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
52-
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
53-
# ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
54-
# evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
55-
# ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
56-
57-
@Context.dynamicCommand
35+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
36+
reset_process = (MethodProcess("reset_proc") >> a.reset)
37+
evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
38+
5839
def clamp_inputs(x):
5940
a.inputs.set(x)
6041

61-
@Context.dynamicCommand
6242
def clamp_pre(x):
6343
a.pre.set(x)
6444

65-
@Context.dynamicCommand
6645
def clamp_post(x):
6746
a.post.set(x)
6847

69-
a.weights.set(jnp.ones((12, 12)) * 0.5)
48+
a.weights.set(jnp.ones((12, 12)) * 0.5)
7049

7150
in_pre = jnp.ones((10, 12)) * 1.0
7251
in_post = jnp.ones((10, 12)) * 0.75
7352

74-
ctx.reset()
53+
reset_process.run()
7554
clamp_pre(in_pre)
7655
clamp_post(in_post)
77-
ctx.run(t=1. * dt, dt=dt)
78-
ctx.evolve(t=1. * dt, dt=dt)
56+
advance_process.run(t=1. * dt, dt=dt)
57+
evolve_process.run(t=1. * dt, dt=dt)
7958

80-
print(a.weights.value)
59+
print(a.weights.get())
8160

8261
# Basic assertions to check learning dynamics
83-
assert a.weights.value.shape == (12, 12), ""
84-
assert a.weights.value[0, 0] == 0.5, ""
62+
assert a.weights.get().shape == (12, 12), ""
63+
assert a.weights.get()[0, 0] == 0.5, ""
64+
8565

66+
test_hebbianPatchedSynapse()
8667

8768

88-
# test_hebbianPatchedSynapse()
Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
# %%
22

33
from jax import numpy as jnp, random, jit
4-
from ngcsimlib.context import Context
54
import numpy as np
65
np.random.seed(42)
76
from ngclearn.components import PatchedSynapse
8-
from ngcsimlib.compilers import compile_command, wrap_command
9-
from numpy.testing import assert_array_equal
107

11-
from ngcsimlib.compilers.process import Process, transition
12-
from ngcsimlib.component import Component
13-
from ngcsimlib.compartment import Compartment
14-
from ngcsimlib.context import Context
15-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
8+
from ngclearn import MethodProcess, Context
9+
1610

1711

1812
def test_patchedSynapse():
@@ -39,31 +33,25 @@ def test_patchedSynapse():
3933
bias_init={"dist": "constant", "value": 0.0}
4034
)
4135

42-
advance_process = (Process("advance_proc") >> a.advance_state)
43-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
44-
reset_process = (Process("reset_proc") >> a.reset)
45-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
46-
47-
# Compile and add commands
48-
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
49-
# ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
50-
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
51-
# ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
36+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
37+
reset_process = (MethodProcess("reset_proc") >> a.reset)
5238

53-
@Context.dynamicCommand
5439
def clamp_inputs(x):
5540
a.inputs.set(x)
5641

5742
inputs_seq = jnp.asarray(np.random.randn(1, 12))
58-
weights = a.weights.value
59-
biases = a.biases.value
43+
weights = a.weights.get()
44+
biases = a.biases.get()
6045
expected_outputs = (jnp.matmul(inputs_seq, weights) * resist_scale) + biases
6146
outputs_outs = []
62-
ctx.reset()
63-
ctx.clamp_inputs(inputs_seq)
64-
ctx.run(t=0., dt=dt)
65-
outputs_outs.append(a.outputs.value)
47+
reset_process.run()
48+
clamp_inputs(inputs_seq)
49+
advance_process.run(t=0., dt=dt)
50+
outputs_outs.append(a.outputs.get())
6651
outputs_outs = jnp.concatenate(outputs_outs, axis=1)
6752
# Verify outputs match expected values
6853
np.testing.assert_allclose(outputs_outs, expected_outputs, atol=1e-5)
6954

55+
56+
test_patchedSynapse()
57+

0 commit comments

Comments
 (0)