|
2 | 2 | from jax import numpy as jnp, random, jit, nn |
3 | 3 | from functools import partial |
4 | 4 | from ngclearn.utils import tensorstats |
5 | | -from ngcsimlib.deprecators import deprecate_args |
| 5 | +from ngcsimlib import deprecate_args |
6 | 6 | from ngcsimlib.logger import info, warn |
7 | 7 | from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ |
8 | 8 | step_euler, step_rk2 |
9 | | -from ngcsimlib.compilers.process import transition |
10 | | -#from ngcsimlib.component import Component |
| 9 | + |
| 10 | +from ngcsimlib.parser import compilable |
11 | 11 | from ngcsimlib.compartment import Compartment |
12 | 12 |
|
13 | 13 | @jit |
@@ -97,7 +97,7 @@ class AdExCell(JaxComponent): |
97 | 97 | at an increase in computational cost (and simulation time) |
98 | 98 | """ |
99 | 99 |
|
100 | | - @deprecate_args(v_thr="thr") |
| 100 | + #@deprecate_args(v_thr="thr") |
101 | 101 | def __init__( |
102 | 102 | self, name, n_units, tau_m=15., resist_m=1., tau_w=400., v_sharpness=2., intrinsic_mem_thr=-55., thr=5., |
103 | 103 | v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0., integration_type="euler", batch_size=1, **kwargs |
@@ -136,39 +136,40 @@ def __init__( |
136 | 136 | self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", |
137 | 137 | units="ms") ## time-of-last-spike |
138 | 138 |
|
139 | | - @transition(output_compartments=["j", "v", "w", "s", "tols"]) |
140 | | - @staticmethod |
141 | | - def advance_state( |
142 | | - t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT, v_rest, v_reset, intgFlag, j, v, w, tols |
143 | | - ): |
144 | | - if intgFlag == 1: ## RK-2/midpoint |
145 | | - v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m) |
146 | | - _, _v = step_rk2(0., v, _dfv, dt, v_params) |
147 | | - w_params = (j, v, a, tau_w, v_rest) |
148 | | - _, _w = step_rk2(0., w, _dfw, dt, w_params) |
| 139 | + @compilable |
| 140 | + def advance_state(self, t, dt): |
| 141 | + if self.intgFlag == 1: ## RK-2/midpoint |
| 142 | + v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m) |
| 143 | + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) |
| 144 | + w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest) |
| 145 | + _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) |
149 | 146 | else: # intgFlag == 0 (default -- Euler) |
150 | | - v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m) |
151 | | - _, _v = step_euler(0., v, _dfv, dt, v_params) |
152 | | - w_params = (j, v, a, tau_w, v_rest) |
153 | | - _, _w = step_euler(0., w, _dfw, dt, w_params) |
154 | | - s = (_v > thr) * 1. ## emit spikes/pulses |
| 147 | + v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m) |
| 148 | + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) |
| 149 | + w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest) |
| 150 | + _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) |
| 151 | + s = (_v > self.thr) * 1. ## emit spikes/pulses |
155 | 152 | ## hyperpolarize/reset/snap variables |
156 | | - v = _v * (1. - s) + s * v_reset |
157 | | - w = _w * (1. - s) + s * (_w + b) |
158 | | - |
159 | | - tols = (1. - s) * tols + (s * t) ## update time-of-last spike variable(s) |
160 | | - return j, v, w, s, tols |
161 | | - |
162 | | - @transition(output_compartments=["j", "v", "w", "s", "tols"]) |
163 | | - @staticmethod |
164 | | - def reset(batch_size, n_units, v0, w0): |
165 | | - restVals = jnp.zeros((batch_size, n_units)) |
166 | | - j = restVals # None |
167 | | - v = restVals + v0 |
168 | | - w = restVals + w0 |
169 | | - s = restVals #+ 0 |
170 | | - tols = restVals #+ 0 |
171 | | - return j, v, w, s, tols |
| 153 | + v = _v * (1. - s) + s * self.v_reset |
| 154 | + w = _w * (1. - s) + s * (_w + self.b) |
| 155 | + |
| 156 | + ## update time-of-last spike variable(s) |
| 157 | + self.tols.set((1. - s) * self.tols.get() + (s * t)) |
| 158 | + |
| 159 | + #self.j.set(j) ## j is not getting modified in these dynamics |
| 160 | + self.v.set(v) |
| 161 | + self.w.set(w) |
| 162 | + self.s.set(s) |
| 163 | + |
| 164 | + @compilable |
| 165 | + def reset(self): |
| 166 | + restVals = jnp.zeros((self.batch_size, self.n_units)) |
| 167 | + if not self.j.targeted: |
| 168 | + self.j.set(restVals) |
| 169 | + self.v.set(restVals + self.v0) |
| 170 | + self.w.set(restVals + self.w0) |
| 171 | + self.s.set(restVals) |
| 172 | + self.tols.set(restVals) |
172 | 173 |
|
173 | 174 | @classmethod |
174 | 175 | def help(cls): ## component help function |
|
0 commit comments