Skip to content

Commit f480ca2

Browse files
author
Alexander Ororbia
committed
cleaned up a few unit tests to use deterministic syn init vals
1 parent 9f4f7f9 commit f480ca2

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

tests/components/synapses/hebbian/test_eventSTDPSynapse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@ def test_eventSTDPSynapse1():
4848
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
4949
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
5050
"""
51+
a.weights.set(jnp.ones((1, 1)) * 0.1)
5152

5253
t = 12. ## fake out current time
5354
## Case 1: outside of pre-syn time window
5455
input_tols = jnp.ones((1, 1,)) * 9.
5556
out_spike = jnp.ones((1, 1))
5657

5758
## check pre-synaptic STDP only
58-
truth = jnp.array([[-0.6296545]])
59+
truth = jnp.array([[-0.101]])
5960
ctx.reset()
6061
a.pre_tols.set(input_tols)
6162
a.postSpike.set(out_spike)
@@ -69,7 +70,7 @@ def test_eventSTDPSynapse1():
6970
out_spike = jnp.ones((1, 1))
7071

7172
## check pre-synaptic STDP only
72-
truth = jnp.array([[0.37034547]])
73+
truth = jnp.array([[0.899]])
7374
ctx.reset()
7475
a.pre_tols.set(input_tols)
7576
a.postSpike.set(out_spike)
@@ -79,3 +80,4 @@ def test_eventSTDPSynapse1():
7980
assert_array_equal(a.dWeights.value, truth)
8081

8182
#test_eventSTDPSynapse1()
83+

tests/components/synapses/hebbian/test_expSTDPSynapse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,15 @@ def test_expSTDPSynapse1():
4747
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
4848
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
4949
"""
50+
a.weights.set(jnp.ones((1, 1)) * 0.1)
5051

5152
in_spike = jnp.ones((1, 1))
5253
in_trace = jnp.ones((1, 1,)) * 1.25
5354
out_spike = jnp.ones((1, 1))
5455
out_trace = jnp.ones((1, 1,)) * 0.65
5556

5657
## check pre-synaptic STDP only
57-
truth = jnp.array([[0.57342285]])
58+
truth = jnp.array([[1.1031212]])
5859
ctx.reset()
5960
a.preSpike.set(in_spike * 0)
6061
a.preTrace.set(in_trace)
@@ -65,7 +66,7 @@ def test_expSTDPSynapse1():
6566
#print(a.dWeights.value)
6667
assert_array_equal(a.dWeights.value, truth)
6768

68-
truth = jnp.array([[-0.29817986]])
69+
truth = jnp.array([[-0.57362294]])
6970
ctx.reset()
7071
a.preSpike.set(in_spike)
7172
a.preTrace.set(in_trace)
@@ -77,3 +78,4 @@ def test_expSTDPSynapse1():
7778
assert_array_equal(a.dWeights.value, truth)
7879

7980
#test_expSTDPSynapse1()
81+

tests/components/synapses/hebbian/test_traceSTDPSynapse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_traceSTDPSynapse1():
4747
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
4848
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
4949
"""
50+
a.weights.set(jnp.ones((1, 1)) * 0.1)
5051

5152
in_spike = jnp.ones((1, 1))
5253
in_trace = jnp.ones((1, 1,)) * 1.25

0 commit comments

Comments
 (0)