Skip to content

Commit 0d1a35f

Browse files
author
Alexander Ororbia
committed
refactored short-term syn, tests passed - including stp-dense-syn and minor cleanup/edit to synapse __init__
1 parent 55756d0 commit 0d1a35f

File tree

7 files changed

+122
-171
lines changed

7 files changed

+122
-171
lines changed
Lines changed: 58 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
4-
from ngcsimlib.compartment import Compartment
5-
62
from ngclearn.utils.weight_distribution import initialize_params
73
from ngcsimlib.logger import info
4+
85
from ngclearn.components.synapses import DenseSynapse
9-
from ngclearn.utils import tensorstats
6+
from ngcsimlib.compartment import Compartment
7+
from ngcsimlib.parser import compilable
108

119
class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable
1210
"""
@@ -56,80 +54,82 @@ class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable
5654
resources_int: initialization kernel for synaptic resources matrix
5755
"""
5856

59-
# Define Functions
60-
def __init__(self, name, shape, weight_init=None, bias_init=None,
61-
resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
62-
resources_init=None, **kwargs):
57+
def __init__(
58+
self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
59+
resources_init=None, **kwargs
60+
):
6361
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
6462
## STP meta-parameters
6563
self.resources_init = resources_init
6664
self.tau_f = tau_f
6765
self.tau_d = tau_d
6866

6967
## Set up short-term plasticity / dynamic synapse compartment values
70-
tmp_key, *subkeys = random.split(self.key.value, 4)
68+
tmp_key, *subkeys = random.split(self.key.get(), 4)
7169
preVals = jnp.zeros((self.batch_size, shape[0]))
7270
self.u = Compartment(preVals) ## release prob variables
7371
self.x = Compartment(preVals + 1) ## resource availability variables
74-
self.Wdyn = Compartment(self.weights.value * 0) ## dynamic synapse values
72+
self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values
7573
if self.resources_init is None:
7674
info(self.name, "is using default resources value initializer!")
7775
self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15
7876
self.resources = Compartment(
7977
initialize_params(subkeys[2], self.resources_init, shape)
8078
) ## matrix U - synaptic resources matrix
8179

82-
@transition(output_compartments=["outputs", "u", "x", "Wdyn"])
83-
@staticmethod
84-
def advance_state(
85-
tau_f, tau_d, Rscale, inputs, weights, biases, resources, u, x, Wdyn
86-
):
87-
s = inputs
80+
@compilable
81+
def advance_state(self, t, dt):
82+
s = self.inputs.get()
8883
## compute short-term facilitation
8984
#u = u - u * (1./tau_f) + (resources * (1. - u)) * s
90-
if tau_f > 0.: ## compute short-term facilitation
91-
u = u - u * (1./tau_f) + (resources * (1. - u)) * s
85+
if self.tau_f > 0.: ## compute short-term facilitation
86+
u = self.u.get() - self.u.get() * (1./self.tau_f) + (self.resources.get() * (1. - self.u.get())) * s
9287
else:
93-
u = resources ## disabling STF yields fixed resource u variables
88+
u = self.resources.get() ## disabling STF yields fixed resource u variables
9489
## compute dynamic synaptic values/conductances
95-
Wdyn = (weights * u * x) * s + Wdyn * (1. - s) ## OR: -W/tau_w + W * u * x
96-
if tau_d > 0.:
97-
## compute short-term depression
98-
x = x + (1. - x) * (1./tau_d) - u * x * s
99-
outputs = jnp.matmul(inputs, Wdyn * Rscale) + biases
100-
return outputs, u, x, Wdyn
101-
102-
@transition(output_compartments=["inputs", "outputs", "u", "x", "Wdyn"])
103-
@staticmethod
104-
def reset(batch_size, shape):
105-
preVals = jnp.zeros((batch_size, shape[0]))
106-
postVals = jnp.zeros((batch_size, shape[1]))
107-
inputs = preVals
108-
outputs = postVals
109-
u = preVals
110-
x = preVals + 1
111-
Wdyn = jnp.zeros(shape)
112-
return inputs, outputs, u, x, Wdyn
113-
114-
def save(self, directory, **kwargs):
115-
file_name = directory + "/" + self.name + ".npz"
116-
if self.bias_init != None:
117-
jnp.savez(file_name,
118-
weights=self.weights.value,
119-
biases=self.biases.value,
120-
resources=self.resources.value)
121-
else:
122-
jnp.savez(file_name,
123-
weights=self.weights.value,
124-
resources=self.resources.value)
125-
126-
def load(self, directory, **kwargs):
127-
file_name = directory + "/" + self.name + ".npz"
128-
data = jnp.load(file_name)
129-
self.weights.set(data['weights'])
130-
self.resources.set(data['resources'])
131-
if "biases" in data.keys():
132-
self.biases.set(data['biases'])
90+
Wdyn = (self.weights.get() * u * self.x.get()) * s + self.Wdyn.get() * (1. - s) ## OR: -W/tau_w + W * u * x
91+
## compute short-term depression
92+
x = self.x.get()
93+
if self.tau_d > 0.:
94+
x = x + (1. - x) * (1./self.tau_d) - u * x * s
95+
## else, do nothing with x (keep it pointing to current x compartment)
96+
outputs = jnp.matmul(self.inputs.get(), Wdyn * self.resist_scale) + self.biases.get()
97+
98+
self.outputs.set(outputs)
99+
self.u.set(u)
100+
self.x.set(x)
101+
self.Wdyn.set(Wdyn)
102+
103+
@compilable
104+
def reset(self):
105+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
106+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
107+
if not self.inputs.targeted:
108+
self.inputs.set(preVals)
109+
self.outputs.set(postVals)
110+
self.u.set(preVals)
111+
self.x.set(preVals + 1)
112+
self.Wdyn.set(jnp.zeros(self.shape.get()))
113+
114+
# def save(self, directory, **kwargs):
115+
# file_name = directory + "/" + self.name + ".npz"
116+
# if self.bias_init != None:
117+
# jnp.savez(file_name,
118+
# weights=self.weights.value,
119+
# biases=self.biases.value,
120+
# resources=self.resources.value)
121+
# else:
122+
# jnp.savez(file_name,
123+
# weights=self.weights.value,
124+
# resources=self.resources.value)
125+
#
126+
# def load(self, directory, **kwargs):
127+
# file_name = directory + "/" + self.name + ".npz"
128+
# data = jnp.load(file_name)
129+
# self.weights.set(data['weights'])
130+
# self.resources.set(data['resources'])
131+
# if "biases" in data.keys():
132+
# self.biases.set(data['biases'])
133133

134134
@classmethod
135135
def help(cls): ## component help function
@@ -166,17 +166,3 @@ def help(cls): ## component help function
166166
"dW/dt = W_full * u * x * inputs",
167167
"hyperparameters": hyperparams}
168168
return info
169-
170-
def __repr__(self):
171-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
172-
maxlen = max(len(c) for c in comps) + 5
173-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
174-
for c in comps:
175-
stats = tensorstats(getattr(self, c).value)
176-
if stats is not None:
177-
line = [f"{k}: {v}" for k, v in stats.items()]
178-
line = ", ".join(line)
179-
else:
180-
line = "None"
181-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
182-
return lines

ngclearn/components/synapses/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44

55
## short-term plasticity components
6-
# from .STPDenseSynapse import STPDenseSynapse
6+
from .STPDenseSynapse import STPDenseSynapse
77
from .exponentialSynapse import ExponentialSynapse
8-
# from .doubleExpSynapse import DoupleExpSynapse
8+
from .doubleExpSynapse import DoupleExpSynapse
99
from .alphaSynapse import AlphaSynapse
1010
#
1111
# ## dense synaptic components

ngclearn/components/synapses/alphaSynapse.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngclearn.utils import tensorstats
42
from ngclearn.utils.weight_distribution import initialize_params
53
from ngcsimlib.logger import info
64

ngclearn/components/synapses/doubleExpSynapse.py

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
4-
from ngcsimlib.compartment import Compartment
5-
62
from ngclearn.utils.weight_distribution import initialize_params
73
from ngcsimlib.logger import info
4+
85
from ngclearn.components.synapses import DenseSynapse
9-
from ngclearn.utils import tensorstats
6+
from ngcsimlib.compartment import Compartment
7+
from ngcsimlib.parser import compilable
108

119
class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable
1210
"""
@@ -66,8 +64,8 @@ class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cabl
6664

6765
# Define Functions
6866
def __init__(
69-
self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
70-
is_nonplastic=True, **kwargs
67+
self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None,
68+
resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs
7169
):
7270
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
7371
## dynamic synapse meta-parameters
@@ -85,57 +83,58 @@ def __init__(
8583
self.g_syn = Compartment(postVals) ## conductance variable
8684
self.h_syn = Compartment(postVals) ## intermediate conductance variable
8785
if is_nonplastic:
88-
self.weights.set(self.weights.value * 0 + 1.)
86+
self.weights.set(self.weights.get() * 0 + 1.)
8987

90-
@transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
91-
@staticmethod
92-
def advance_state(
93-
dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
94-
):
95-
s = inputs
88+
@compilable
89+
def advance_state(self, t, dt): #dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
90+
s = self.inputs.get()
9691
#A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay))
97-
A = 1.
92+
A = 1. ## FIXME: scale factor to use?
9893
## advance conductance variable(s)
99-
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
100-
dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt)
101-
h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
94+
_out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron)
95+
dhsyn_dt = (-self.h_syn.get()/self.tau_rise +
96+
((_out * self.g_syn_bar) * (1. / self.tau_rise - 1. / self.tau_decay) * A) * (1./dt))
97+
h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
10298

103-
dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt)
104-
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
99+
dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt)
100+
g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g
105101

106102
## compute derive electrical current variable
107-
i_syn = -g_syn * Rscale
108-
if syn_rest is not None:
109-
i_syn = -(g_syn * Rscale) * (v - syn_rest)
103+
i_syn = -g_syn * self.resist_scale
104+
if self.syn_rest is not None:
105+
i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
110106
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
111-
return outputs, i_syn, g_syn, h_syn
112-
113-
@transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
114-
@staticmethod
115-
def reset(batch_size, shape):
116-
preVals = jnp.zeros((batch_size, shape[0]))
117-
postVals = jnp.zeros((batch_size, shape[1]))
118-
inputs = preVals
119-
outputs = postVals
120-
i_syn = postVals
121-
g_syn = postVals
122-
h_syn = postVals
123-
v = postVals
124-
return inputs, outputs, i_syn, g_syn, h_syn, v
125-
126-
def save(self, directory, **kwargs):
127-
file_name = directory + "/" + self.name + ".npz"
128-
if self.bias_init != None:
129-
jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
130-
else:
131-
jnp.savez(file_name, weights=self.weights.value)
132-
133-
def load(self, directory, **kwargs):
134-
file_name = directory + "/" + self.name + ".npz"
135-
data = jnp.load(file_name)
136-
self.weights.set(data['weights'])
137-
if "biases" in data.keys():
138-
self.biases.set(data['biases'])
107+
108+
self.outputs.set(outputs)
109+
self.i_syn.set(i_syn)
110+
self.g_syn.set(g_syn)
111+
self.h_syn.set(h_syn)
112+
113+
@compilable
114+
def reset(self):
115+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
116+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
117+
if not self.inputs.targeted:
118+
self.inputs.set(preVals)
119+
self.outputs.set(postVals)
120+
self.i_syn.set(postVals)
121+
self.g_syn.set(postVals)
122+
self.h_syn.set(postVals)
123+
self.v.set(postVals)
124+
125+
# def save(self, directory, **kwargs):
126+
# file_name = directory + "/" + self.name + ".npz"
127+
# if self.bias_init != None:
128+
# jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
129+
# else:
130+
# jnp.savez(file_name, weights=self.weights.value)
131+
#
132+
# def load(self, directory, **kwargs):
133+
# file_name = directory + "/" + self.name + ".npz"
134+
# data = jnp.load(file_name)
135+
# self.weights.set(data['weights'])
136+
# if "biases" in data.keys():
137+
# self.biases.set(data['biases'])
139138

140139
@classmethod
141140
def help(cls): ## component help function
@@ -176,17 +175,3 @@ def help(cls): ## component help function
176175
"dgsyn_dt = -g_syn/tau_decay + h_syn",
177176
"hyperparameters": hyperparams}
178177
return info
179-
180-
def __repr__(self):
181-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
182-
maxlen = max(len(c) for c in comps) + 5
183-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
184-
for c in comps:
185-
stats = tensorstats(getattr(self, c).value)
186-
if stats is not None:
187-
line = [f"{k}: {v}" for k, v in stats.items()]
188-
line = ", ".join(line)
189-
else:
190-
line = "None"
191-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
192-
return lines

ngclearn/components/synapses/exponentialSynapse.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngclearn.utils import tensorstats
42
from ngclearn.utils.weight_distribution import initialize_params
53
from ngcsimlib.logger import info
64

@@ -82,10 +80,8 @@ def __init__(
8280
if is_nonplastic:
8381
self.weights.set(self.weights.get() * 0 + 1.)
8482

85-
# @transition(output_compartments=["outputs", "i_syn", "g_syn"])
86-
# @staticmethod
8783
@compilable
88-
def advance_state(self, t, dt): #dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
84+
def advance_state(self, t, dt):
8985
s = self.inputs.get()
9086
## advance conductance variable
9187
_out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron)

0 commit comments

Comments
 (0)