Commit 55756d0

Author: Alexander Ororbia (committed)
refactored alpha and exp-synapses, tests passed; minor edit to __init__ for synapses
1 parent b86ae3d commit 55756d0

File tree

5 files changed: +144 -173 lines

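The recurring change in the two synapse modules below is a move away from `@transition`-decorated static methods, which receive compartment values as arguments and return the updated ones, toward `@compilable` instance methods that read compartments with `.get()` and write them with `.set()`. A minimal stand-in sketch of that calling pattern follows; the ToyCompartment and ToySynapse classes are illustrative stand-ins written for this note, not part of ngcsimlib or of the commit.

## Illustrative stand-ins (not ngcsimlib classes). They mimic the style the
## refactor adopts: state lives in compartments, and advance_state reads and
## writes it via .get()/.set() instead of taking and returning arrays.
class ToyCompartment:
    def __init__(self, value):
        self.value = value
    def get(self):
        return self.value
    def set(self, value):
        self.value = value

class ToySynapse:
    def __init__(self, tau_decay=10.):
        self.tau_decay = tau_decay
        self.inputs = ToyCompartment(0.)  ## pre-synaptic drive
        self.g_syn = ToyCompartment(0.)   ## conductance state

    def advance_state(self, t, dt):  ## new-style signature: only (t, dt)
        dg_dt = -self.g_syn.get() / self.tau_decay + self.inputs.get() * (1. / dt)
        self.g_syn.set(self.g_syn.get() + dg_dt * dt)

In the real components the method additionally carries the `@compilable` decorator from ngcsimlib.parser and the state variables are ngcsimlib `Compartment` objects, as the diffs below show.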

ngclearn/components/synapses/__init__.py

Lines changed: 31 additions & 31 deletions
@@ -3,36 +3,36 @@
 
 
 ## short-term plasticity components
-from .STPDenseSynapse import STPDenseSynapse
+# from .STPDenseSynapse import STPDenseSynapse
 from .exponentialSynapse import ExponentialSynapse
-from .doubleExpSynapse import DoupleExpSynapse
+# from .doubleExpSynapse import DoupleExpSynapse
 from .alphaSynapse import AlphaSynapse
-
-## dense synaptic components
-from .hebbian.hebbianSynapse import HebbianSynapse
-from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
-from .hebbian.expSTDPSynapse import ExpSTDPSynapse
-from .hebbian.eventSTDPSynapse import EventSTDPSynapse
-from .hebbian.BCMSynapse import BCMSynapse
-
-
-## conv/deconv synaptic components
-from .convolution.convSynapse import ConvSynapse
-from .convolution.staticConvSynapse import StaticConvSynapse
-from .convolution.hebbianConvSynapse import HebbianConvSynapse
-from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
-from .convolution.deconvSynapse import DeconvSynapse
-from .convolution.staticDeconvSynapse import StaticDeconvSynapse
-from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
-from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
-
-
-## modulated synaptic components
-from .modulated.MSTDPETSynapse import MSTDPETSynapse
-from .modulated.REINFORCESynapse import REINFORCESynapse
-
-## patched synaptic components
-from .patched.patchedSynapse import PatchedSynapse
-from .patched.staticPatchedSynapse import StaticPatchedSynapse
-from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
-
+#
+# ## dense synaptic components
+# from .hebbian.hebbianSynapse import HebbianSynapse
+# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
+# from .hebbian.expSTDPSynapse import ExpSTDPSynapse
+# from .hebbian.eventSTDPSynapse import EventSTDPSynapse
+# from .hebbian.BCMSynapse import BCMSynapse
+#
+#
+# ## conv/deconv synaptic components
+# from .convolution.convSynapse import ConvSynapse
+# from .convolution.staticConvSynapse import StaticConvSynapse
+# from .convolution.hebbianConvSynapse import HebbianConvSynapse
+# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
+# from .convolution.deconvSynapse import DeconvSynapse
+# from .convolution.staticDeconvSynapse import StaticDeconvSynapse
+# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
+# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
+#
+#
+# ## modulated synaptic components
+# from .modulated.MSTDPETSynapse import MSTDPETSynapse
+# from .modulated.REINFORCESynapse import REINFORCESynapse
+#
+# ## patched synaptic components
+# from .patched.patchedSynapse import PatchedSynapse
+# from .patched.staticPatchedSynapse import StaticPatchedSynapse
+# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse
+#
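After this edit only the dynamic synapse classes remain exported from the package namespace. A quick sanity check, assuming ngclearn is installed at this revision, might look like the following; treat it as a sketch rather than part of the commit.

## These names stay uncommented above, so they should still resolve:
from ngclearn.components.synapses import ExponentialSynapse, AlphaSynapse
## DenseSynapse is also expected to resolve, since the refactored modules below
## import it with `from ngclearn.components.synapses import DenseSynapse`:
from ngclearn.components.synapses import DenseSynapse
## The hebbian/conv/modulated/patched imports are commented out, so a line such as
## `from ngclearn.components.synapses import HebbianSynapse` would now raise ImportError.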

ngclearn/components/synapses/alphaSynapse.py

Lines changed: 51 additions & 65 deletions
@@ -1,12 +1,12 @@
 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
+from ngclearn.components.jaxComponent import JaxComponent
+from ngclearn.utils import tensorstats
 from ngclearn.utils.weight_distribution import initialize_params
 from ngcsimlib.logger import info
+
 from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
     """
@@ -64,8 +64,8 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
 
     # Define Functions
     def __init__(
-            self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
-            is_nonplastic=True, **kwargs
+            self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1.,
+            p_conn=1., is_nonplastic=True, **kwargs
     ):
         super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
         ## dynamic synapse meta-parameters
@@ -82,55 +82,55 @@ def __init__(
         self.g_syn = Compartment(postVals) ## conductance variable
         self.h_syn = Compartment(postVals) ## intermediate conductance variable
         if is_nonplastic:
-            self.weights.set(self.weights.value * 0 + 1.)
+            self.weights.set(self.weights.get() * 0 + 1.)
 
-    @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
-    @staticmethod
-    def advance_state(
-            dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
-    ):
-        s = inputs
+    @compilable
+    def advance_state(self, t, dt):
+        s = self.inputs.get()
         ## advance conductance variable(s)
-        _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
-        dhsyn_dt = -h_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
-        h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
+        _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron)
+        dhsyn_dt = -self.h_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt)
+        h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
 
-        dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay
-        g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
+        dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay
+        g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g
 
         ## compute derive electrical current variable
-        i_syn = -g_syn * Rscale
-        if syn_rest is not None:
-            i_syn = -(g_syn * Rscale) * (v - syn_rest)
-        outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
-        return outputs, i_syn, g_syn, h_syn
-
-    @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
-    @staticmethod
-    def reset(batch_size, shape):
-        preVals = jnp.zeros((batch_size, shape[0]))
-        postVals = jnp.zeros((batch_size, shape[1]))
-        inputs = preVals
-        outputs = postVals
-        i_syn = postVals
-        g_syn = postVals
-        h_syn = postVals
-        v = postVals
-        return inputs, outputs, i_syn, g_syn, h_syn, v
-
-    def save(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        if self.bias_init != None:
-            jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
-        else:
-            jnp.savez(file_name, weights=self.weights.value)
-
-    def load(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        data = jnp.load(file_name)
-        self.weights.set(data['weights'])
-        if "biases" in data.keys():
-            self.biases.set(data['biases'])
+        i_syn = -g_syn * self.resist_scale
+        if self.syn_rest is not None:
+            i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
+        outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases
+
+        self.outputs.set(outputs)
+        self.i_syn.set(i_syn)
+        self.g_syn.set(g_syn)
+        self.h_syn.set(h_syn)
+
+    @compilable
+    def reset(self):
+        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+        if not self.inputs.targeted:
+            self.inputs.set(preVals)
+        self.outputs.set(postVals)
+        self.i_syn.set(postVals)
+        self.g_syn.set(postVals)
+        self.h_syn.set(postVals)
+        self.v.set(postVals)
+
+    # def save(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     if self.bias_init != None:
+    #         jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
+    #     else:
+    #         jnp.savez(file_name, weights=self.weights.value)
+    #
+    # def load(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     data = jnp.load(file_name)
+    #     self.weights.set(data['weights'])
+    #     if "biases" in data.keys():
+    #         self.biases.set(data['biases'])
 
     @classmethod
     def help(cls): ## component help function
@@ -170,17 +170,3 @@ def help(cls): ## component help function
             "dgsyn_dt = -g_syn/tau_decay + h_syn",
             "hyperparameters": hyperparams}
         return info
-
-    def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
-        maxlen = max(len(c) for c in comps) + 5
-        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
-        for c in comps:
-            stats = tensorstats(getattr(self, c).value)
-            if stats is not None:
-                line = [f"{k}: {v}" for k, v in stats.items()]
-                line = ", ".join(line)
-            else:
-                line = "None"
-            lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
-        return lines
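Functionally, the refactored advance_state still integrates the alpha-synapse conductance pair, dh/dt = -h/tau_decay + (W s) * g_syn_bar / dt and dg/dt = -g/tau_decay + h / dt, with explicit Euler steps. Stripped of the compartment plumbing, the arithmetic amounts to the self-contained JAX sketch below; the helper name and the example inputs are illustrative, not part of the commit.

from jax import numpy as jnp

def alpha_synapse_step(s, W, h_syn, g_syn, v, dt, tau_decay, g_syn_bar,
                       syn_rest=None, Rscale=1.):
    ## one Euler step of the alpha-synapse dynamics used in advance_state above
    _out = jnp.matmul(s, W)  ## weighted sum of pre-synaptic spikes per post-neuron
    h_syn = h_syn + (-h_syn / tau_decay + (_out * g_syn_bar) * (1. / dt)) * dt
    g_syn = g_syn + (-g_syn / tau_decay + h_syn * (1. / dt)) * dt
    i_syn = -g_syn * Rscale
    if syn_rest is not None:  ## conductance-based current relative to a reversal potential
        i_syn = -(g_syn * Rscale) * (v - syn_rest)
    return i_syn, g_syn, h_syn

## example: one pre-synaptic spike driving 8 -> 4 unit (nonplastic) weights
s = jnp.zeros((1, 8)).at[0, 0].set(1.)
W = jnp.ones((8, 4))
zeros = jnp.zeros((1, 4))
i_syn, g_syn, h_syn = alpha_synapse_step(s, W, zeros, zeros, zeros, dt=1.,
                                         tau_decay=10., g_syn_bar=1.)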

ngclearn/components/synapses/denseSynapse.py

Lines changed: 1 addition & 2 deletions
@@ -41,8 +41,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
 
     # Define Functions
     def __init__(
-            self, name, shape, weight_init=None, bias_init=None, resist_scale=1.,
-            p_conn=1., batch_size=1, **kwargs
+            self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
     ):
         super().__init__(name, **kwargs)
 

ngclearn/components/synapses/exponentialSynapse.py

Lines changed: 49 additions & 62 deletions
@@ -1,12 +1,12 @@
 from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
+from ngclearn.components.jaxComponent import JaxComponent
+from ngclearn.utils import tensorstats
 from ngclearn.utils.weight_distribution import initialize_params
 from ngcsimlib.logger import info
+
 from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngcsimlib.compartment import Compartment
+from ngcsimlib.parser import compilable
 
 class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
     """
@@ -63,8 +63,8 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
 
     # Define Functions
     def __init__(
-            self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
-            is_nonplastic=True, **kwargs
+            self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1.,
+            p_conn=1., is_nonplastic=True, **kwargs
     ):
         super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
         ## dynamic synapse meta-parameters
@@ -80,50 +80,51 @@ def __init__(
         self.i_syn = Compartment(postVals) ## electrical current output
         self.g_syn = Compartment(postVals) ## conductance variable
         if is_nonplastic:
-            self.weights.set(self.weights.value * 0 + 1.)
+            self.weights.set(self.weights.get() * 0 + 1.)
 
-    @transition(output_compartments=["outputs", "i_syn", "g_syn"])
-    @staticmethod
-    def advance_state(
-            dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
-    ):
-        s = inputs
+    # @transition(output_compartments=["outputs", "i_syn", "g_syn"])
+    # @staticmethod
+    @compilable
+    def advance_state(self, t, dt): #dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
+        s = self.inputs.get()
         ## advance conductance variable
-        _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
-        dgsyn_dt = -g_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
-        g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance
+        _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron)
+        dgsyn_dt = -self.g_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt)
+        g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance
         ## compute derive electrical current variable
-        i_syn = -g_syn * Rscale
-        if syn_rest is not None:
-            i_syn = -(g_syn * Rscale) * (v - syn_rest)
-        outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
-        return outputs, i_syn, g_syn
-
-    @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "v"])
-    @staticmethod
-    def reset(batch_size, shape):
-        preVals = jnp.zeros((batch_size, shape[0]))
-        postVals = jnp.zeros((batch_size, shape[1]))
-        inputs = preVals
-        outputs = postVals
-        i_syn = postVals
-        g_syn = postVals
-        v = postVals
-        return inputs, outputs, i_syn, g_syn, v
-
-    def save(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        if self.bias_init != None:
-            jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
-        else:
-            jnp.savez(file_name, weights=self.weights.value)
-
-    def load(self, directory, **kwargs):
-        file_name = directory + "/" + self.name + ".npz"
-        data = jnp.load(file_name)
-        self.weights.set(data['weights'])
-        if "biases" in data.keys():
-            self.biases.set(data['biases'])
+        i_syn = -g_syn * self.resist_scale
+        if self.syn_rest is not None:
+            i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
+        outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases
+
+        self.outputs.set(outputs)
+        self.i_syn.set(i_syn)
+        self.g_syn.set(g_syn)
+
+    @compilable
+    def reset(self):
+        preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+        postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+        if not self.inputs.targeted:
+            self.inputs.set(preVals)
+        self.outputs.set(postVals)
+        self.i_syn.set(postVals)
+        self.g_syn.set(postVals)
+        self.v.set(postVals)
+
+    # def save(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     if self.bias_init != None:
+    #         jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
+    #     else:
+    #         jnp.savez(file_name, weights=self.weights.value)
+    #
+    # def load(self, directory, **kwargs):
+    #     file_name = directory + "/" + self.name + ".npz"
+    #     data = jnp.load(file_name)
+    #     self.weights.set(data['weights'])
+    #     if "biases" in data.keys():
+    #         self.biases.set(data['biases'])
 
     @classmethod
     def help(cls): ## component help function
@@ -162,17 +163,3 @@ def help(cls): ## component help function
             "dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_decay ",
             "hyperparameters": hyperparams}
         return info
-
-    def __repr__(self):
-        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
-        maxlen = max(len(c) for c in comps) + 5
-        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
-        for c in comps:
-            stats = tensorstats(getattr(self, c).value)
-            if stats is not None:
-                line = [f"{k}: {v}" for k, v in stats.items()]
-                line = ", ".join(line)
-            else:
-                line = "None"
-            lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
-        return lines
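The exponential synapse gets the same treatment but integrates a single conductance ODE, dg/dt = -g/tau_decay + (W s) * g_syn_bar / dt, with no intermediate h variable. A corresponding standalone sketch of the refactored arithmetic, again with an illustrative helper name rather than anything from the commit:

from jax import numpy as jnp

def exp_synapse_step(s, W, g_syn, v, dt, tau_decay, g_syn_bar,
                     syn_rest=None, Rscale=1.):
    ## one Euler step of the exponential-synapse dynamics used in advance_state above
    _out = jnp.matmul(s, W)  ## weighted sum of pre-synaptic spikes per post-neuron
    g_syn = g_syn + (-g_syn / tau_decay + (_out * g_syn_bar) * (1. / dt)) * dt
    i_syn = -g_syn * Rscale
    if syn_rest is not None:  ## voltage-dependent form when a reversal potential is given
        i_syn = -(g_syn * Rscale) * (v - syn_rest)
    return i_syn, g_syn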
