|
3 | 3 | from jax import numpy as jnp, random, jit, nn |
4 | 4 | from functools import partial |
5 | 5 | from ngclearn.utils import tensorstats |
6 | | -from ngcsimlib.deprecators import deprecate_args |
| 6 | +from ngcsimlib import deprecate_args |
7 | 7 | from ngcsimlib.logger import info, warn |
8 | 8 |
|
9 | | -from ngcsimlib.compilers.process import transition |
10 | | -#from ngcsimlib.component import Component |
| 9 | +from ngcsimlib.parser import compilable |
11 | 10 | from ngcsimlib.compartment import Compartment |
12 | 11 | from ngclearn.utils.model_utils import softmax |
13 | 12 |
|
@@ -88,42 +87,50 @@ def __init__( |
88 | 87 | self.rfr = Compartment(restVals + self.refract_T) |
89 | 88 | self.tols = Compartment(restVals) ## time-of-last-spike |
90 | 89 |
|
91 | | - @transition(output_compartments=["v", "s", "thr", "rfr", "tols"]) |
92 | | - @staticmethod |
93 | | - def advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols): |
94 | | - mask = (rfr >= refract_T) * 1. ## check refractory period |
95 | | - v = (j * R_m) * mask |
| 90 | + # @transition(output_compartments=["v", "s", "thr", "rfr", "tols"]) |
| 91 | + # @staticmethod |
| 92 | + @compilable |
| 93 | + def advance_state( |
| 94 | + self, t, dt #, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols |
| 95 | + ): |
| 96 | + mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period |
| 97 | + v = (self.j.get() * self.R_m) * mask |
96 | 98 | vp = softmax(v) # convert to Categorical (spike) probabilities |
97 | 99 | # s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike |
98 | | - s = (vp > thr) * 1. ## calculate action potential |
| 100 | + s = (vp > self.thr.get()) * 1. ## calculate action potential |
99 | 101 | q = 1. ## Note: thr_gain ==> "rho_b" |
100 | 102 | ## increment threshold upon spike(s) occurrence |
101 | 103 | dthr = jnp.sum(s, axis=1, keepdims=True) - q |
102 | | - thr = jnp.maximum(thr + dthr * thr_gain, 0.025) ## calc new threshold |
103 | | - rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt |
104 | | - |
105 | | - tols = (1. - s) * tols + (s * t) ## update tols |
106 | | - return v, s, thr, rfr, tols |
107 | | - |
108 | | - @transition(output_compartments=["j", "v", "s", "rfr", "tols"]) |
109 | | - @staticmethod |
110 | | - def reset(batch_size, n_units, refract_T): |
111 | | - restVals = jnp.zeros((batch_size, n_units)) |
112 | | - j = restVals #+ 0 |
113 | | - v = restVals #+ 0 |
114 | | - s = restVals #+ 0 |
115 | | - rfr = restVals + refract_T |
116 | | - tols = restVals #+ 0 |
117 | | - return j, v, s, rfr, tols |
118 | | - |
119 | | - def save(self, directory, **kwargs): |
120 | | - file_name = directory + "/" + self.name + ".npz" |
121 | | - jnp.savez(file_name, threshold=self.thr.value) |
122 | | - |
123 | | - def load(self, directory, seeded=False, **kwargs): |
124 | | - file_name = directory + "/" + self.name + ".npz" |
125 | | - data = jnp.load(file_name) |
126 | | - self.thr.set( data['threshold'] ) |
| 104 | + thr = jnp.maximum(self.thr.get() + dthr * self.thr_gain, 0.025) ## calc new threshold |
| 105 | + rfr = (self.rfr.get() + dt) * (1. - s) + s * dt # set refract to dt |
| 106 | + |
| 107 | + self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s) |
| 108 | + |
| 109 | + self.v.set(v) |
| 110 | + self.s.set(s) |
| 111 | + self.thr.set(thr) |
| 112 | + self.rfr.set(rfr) |
| 113 | + |
| 114 | + # @transition(output_compartments=["j", "v", "s", "rfr", "tols"]) |
| 115 | + # @staticmethod |
| 116 | + @compilable |
| 117 | + def reset(self): |
| 118 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 119 | + if not self.j.targeted: |
| 120 | + self.j.set(restVals) |
| 121 | + self.v.set(restVals) |
| 122 | + self.s.set(restVals) |
| 123 | + self.rfr.set(restVals + self.refract_T) |
| 124 | + self.tols.set(restVals) |
| 125 | + |
| 126 | + # def save(self, directory, **kwargs): |
| 127 | + # file_name = directory + "/" + self.name + ".npz" |
| 128 | + # jnp.savez(file_name, threshold=self.thr.value) |
| 129 | + # |
| 130 | + # def load(self, directory, seeded=False, **kwargs): |
| 131 | + # file_name = directory + "/" + self.name + ".npz" |
| 132 | + # data = jnp.load(file_name) |
| 133 | + # self.thr.set( data['threshold'] ) |
127 | 134 |
|
128 | 135 | @classmethod |
129 | 136 | def help(cls): ## component help function |
|
0 commit comments