Skip to content

Commit 50f0db4

Browse files
author
Alexander Ororbia
committed
ported over/refactored WTASCell for v3
1 parent 0330504 commit 50f0db4

File tree

1 file changed

+41
-34
lines changed

1 file changed

+41
-34
lines changed

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
from jax import numpy as jnp, random, jit, nn
44
from functools import partial
55
from ngclearn.utils import tensorstats
6-
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib import deprecate_args
77
from ngcsimlib.logger import info, warn
88

9-
from ngcsimlib.compilers.process import transition
10-
#from ngcsimlib.component import Component
9+
from ngcsimlib.parser import compilable
1110
from ngcsimlib.compartment import Compartment
1211
from ngclearn.utils.model_utils import softmax
1312

@@ -88,42 +87,50 @@ def __init__(
8887
self.rfr = Compartment(restVals + self.refract_T)
8988
self.tols = Compartment(restVals) ## time-of-last-spike
9089

91-
@transition(output_compartments=["v", "s", "thr", "rfr", "tols"])
92-
@staticmethod
93-
def advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols):
94-
mask = (rfr >= refract_T) * 1. ## check refractory period
95-
v = (j * R_m) * mask
90+
# @transition(output_compartments=["v", "s", "thr", "rfr", "tols"])
91+
# @staticmethod
92+
@compilable
93+
def advance_state(
94+
self, t, dt #, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols
95+
):
96+
mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period
97+
v = (self.j.get() * self.R_m) * mask
9698
vp = softmax(v) # convert to Categorical (spike) probabilities
9799
# s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
98-
s = (vp > thr) * 1. ## calculate action potential
100+
s = (vp > self.thr.get()) * 1. ## calculate action potential
99101
q = 1. ## Note: thr_gain ==> "rho_b"
100102
## increment threshold upon spike(s) occurrence
101103
dthr = jnp.sum(s, axis=1, keepdims=True) - q
102-
thr = jnp.maximum(thr + dthr * thr_gain, 0.025) ## calc new threshold
103-
rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt
104-
105-
tols = (1. - s) * tols + (s * t) ## update tols
106-
return v, s, thr, rfr, tols
107-
108-
@transition(output_compartments=["j", "v", "s", "rfr", "tols"])
109-
@staticmethod
110-
def reset(batch_size, n_units, refract_T):
111-
restVals = jnp.zeros((batch_size, n_units))
112-
j = restVals #+ 0
113-
v = restVals #+ 0
114-
s = restVals #+ 0
115-
rfr = restVals + refract_T
116-
tols = restVals #+ 0
117-
return j, v, s, rfr, tols
118-
119-
def save(self, directory, **kwargs):
120-
file_name = directory + "/" + self.name + ".npz"
121-
jnp.savez(file_name, threshold=self.thr.value)
122-
123-
def load(self, directory, seeded=False, **kwargs):
124-
file_name = directory + "/" + self.name + ".npz"
125-
data = jnp.load(file_name)
126-
self.thr.set( data['threshold'] )
104+
thr = jnp.maximum(self.thr.get() + dthr * self.thr_gain, 0.025) ## calc new threshold
105+
rfr = (self.rfr.get() + dt) * (1. - s) + s * dt # set refract to dt
106+
107+
self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s)
108+
109+
self.v.set(v)
110+
self.s.set(s)
111+
self.thr.set(thr)
112+
self.rfr.set(rfr)
113+
114+
# @transition(output_compartments=["j", "v", "s", "rfr", "tols"])
115+
# @staticmethod
116+
@compilable
117+
def reset(self):
118+
restVals = jnp.zeros((self.batch_size, self.n_units))
119+
if not self.j.targeted:
120+
self.j.set(restVals)
121+
self.v.set(restVals)
122+
self.s.set(restVals)
123+
self.rfr.set(restVals + self.refract_T)
124+
self.tols.set(restVals)
125+
126+
# def save(self, directory, **kwargs):
127+
# file_name = directory + "/" + self.name + ".npz"
128+
# jnp.savez(file_name, threshold=self.thr.value)
129+
#
130+
# def load(self, directory, seeded=False, **kwargs):
131+
# file_name = directory + "/" + self.name + ".npz"
132+
# data = jnp.load(file_name)
133+
# self.thr.set( data['threshold'] )
127134

128135
@classmethod
129136
def help(cls): ## component help function

0 commit comments

Comments
 (0)