Skip to content

Commit 35eae76

Browse files
committed
Additions for inhibition stuff
1 parent 6408ee0 commit 35eae76

File tree

5 files changed

+22
-17
lines changed

5 files changed

+22
-17
lines changed

ngclearn/components/base_monitor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ def watch(self, compartment, window_length):
124124
"""
125125
cs, end = self._add_path(compartment.path)
126126

127+
dtype = compartment.value.dtype
127128
shape = compartment.value.shape
128-
new_comp = Compartment(np.zeros(shape))
129-
new_comp_store = Compartment(np.zeros((window_length, *shape)))
129+
new_comp = Compartment(np.zeros(shape, dtype=dtype))
130+
new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))
130131

131132
comp_key = "*".join(compartment.path.split("/"))
132133
store_comp_key = comp_key + "*store"

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class PoissonCell(JaxComponent):
4646
"""
4747

4848
@deprecate_args(max_freq="target_freq")
49-
def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
49+
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
5050
super().__init__(name, **kwargs)
5151

5252
## Constrained Bernoulli meta-parameters

ngclearn/components/other/varTrace.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ngclearn.utils import tensorstats
66

77
@partial(jit, static_argnums=[4])
8-
def _run_varfilter(dt, x, x_tr, decayFactor, a_delta=0.):
8+
def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.):
99
"""
1010
Run variable trace filter (low-pass filter) dynamics one step forward.
1111
@@ -22,7 +22,7 @@ def _run_varfilter(dt, x, x_tr, decayFactor, a_delta=0.):
2222
Returns:
2323
updated trace/filter value/state
2424
"""
25-
_x_tr = x_tr * decayFactor
25+
_x_tr = gamma_tr * x_tr * decayFactor
2626
#x_tr + (-x_tr) * (dt / tau_tr) = (1 - dt/tau_tr) * x_tr
2727
if a_delta > 0.: ## perform additive form of trace ODE
2828
_x_tr = _x_tr + x * a_delta
@@ -64,13 +64,14 @@ class VarTrace(JaxComponent): ## low-pass filter
6464
"""
6565

6666
# Define Functions
67-
def __init__(self, name, n_units, tau_tr, a_delta, decay_type="exp",
67+
def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
6868
batch_size=1, **kwargs):
6969
super().__init__(name, **kwargs)
7070

7171
## Trace control coefficients
7272
self.tau_tr = tau_tr ## trace time constant
7373
self.a_delta = a_delta ## trace increment (if spike occurred)
74+
self.gamma_tr = gamma_tr
7475
self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
7576

7677
## Layer Size Setup
@@ -83,17 +84,20 @@ def __init__(self, name, n_units, tau_tr, a_delta, decay_type="exp",
8384
self.trace = Compartment(restVals)
8485

8586
@staticmethod
86-
def _advance_state(dt, decay_type, tau_tr, a_delta, inputs, trace):
87-
## compute the decay factor
88-
decayFactor = 0. ## <-- pulse filter decay (default)
87+
def _advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
88+
decayFactor = 0.
8989
if "exp" in decay_type:
9090
decayFactor = jnp.exp(-dt/tau_tr)
9191
elif "lin" in decay_type:
9292
decayFactor = (1. - dt/tau_tr)
93-
## else "step" == decay_type, yielding a step/pulse-like filter
94-
trace = _run_varfilter(dt, inputs, trace, decayFactor, a_delta)
95-
outputs = trace
96-
return outputs, trace
93+
94+
_x_tr = gamma_tr * trace * decayFactor
95+
if a_delta > 0.:
96+
_x_tr = _x_tr + inputs * a_delta
97+
else:
98+
_x_tr = _x_tr * (1. - inputs) + inputs
99+
100+
return trace, trace
97101

98102
@resolver(_advance_state)
99103
def advance_state(self, outputs, trace):

ngclearn/components/synapses/denseSynapse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
4747
self.bias_init = bias_init
4848

4949
## Synapse meta-parameters
50-
self.shape = shape ## shape of synaptic efficacy matrix
51-
self.Rscale = resist_scale ## post-transformation scale factor
50+
self.shape = shape
51+
self.Rscale = resist_scale
5252

5353
## Set up synaptic weight values
5454
tmp_key, *subkeys = random.split(self.key.value, 4)

ngclearn/utils/viz/synapse_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs):
138138
_frames = [f.astype(jnp.uint8) for f in frames]
139139
iio.imwrite(path + '/' + name + '.gif', _frames, **kwargs)
140140

141-
def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1):
141+
def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs):
142142
images = []
143143
for i in range(f_start, f_end+1, skip):
144144
print("Reading frame " + str(i))
145145
images.append(iio.imread(path + "/" + prefix + str(i) + suffix))
146146
print("writing gif")
147-
iio.imwrite(path + '/training.gif', images, loop=0, duration=200)
147+
iio.imwrite(path + '/training.gif', images, **kwargs)
148148

149149

150150
# def visualize_norm(thetas, sizes, prefix, suffix='.jpg'):

0 commit comments

Comments
 (0)