Skip to content

Commit 098f3db

Browse files
author
Alexander Ororbia
committed
patched ode_utils backend wrt jax, cleaned up unit-tests, added disable flag for phasor-cell
1 parent 9ea587a commit 098f3db

File tree

4 files changed

+26
-23
lines changed

4 files changed

+26
-23
lines changed

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class PhasorCell(JaxComponent):
3434
"""
3535

3636
# Define Functions
37-
def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
38-
**kwargs):
37+
def __init__(
38+
self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs):
3939
super().__init__(name, **kwargs)
4040

4141
## Phasor meta-parameters
@@ -63,6 +63,7 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
6363
# beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
6464

6565
self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.value.shape) / target_freq
66+
self.disable_phasor = disable_phasor
6667

6768
def validate(self, dt=None, **validation_kwargs):
6869
valid = super().validate(**validation_kwargs)
@@ -87,7 +88,7 @@ def validate(self, dt=None, **validation_kwargs):
8788

8889
@transition(output_compartments=["outputs", "tols", "key", "angles"])
8990
@staticmethod
90-
def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
91+
def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale, disable_phasor):
9192
ms_per_second = 1000 # ms/s
9293
events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
9394
ms_per_event = 1 / events_per_ms # ms/e
@@ -108,6 +109,8 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
108109
updated_angles = jnp.where(updated_angles > angle_per_event,
109110
updated_angles - angle_per_event,
110111
updated_angles)
112+
if disable_phasor:
113+
outputs = inputs + 0
111114
tols = tols * (1. - outputs) + t * outputs
112115

113116
return outputs, tols, key, updated_angles

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#from ngcsimlib.component import Component
1515
from ngcsimlib.compartment import Compartment
1616

17-
@jit
17+
#@jit
1818
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
1919
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
2020
## update voltage / membrane potential
@@ -188,16 +188,16 @@ def advance_state(
188188
j = j * resist_m
189189
############################################################################
190190
### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
191-
_v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
191+
_v_thr = thr_theta + thr ## calc present voltage threshold
192192
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
193193
## update voltage / membrane potential
194194
v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay)
195195
if intgFlag == 1:
196196
_, _v = step_rk2(0., v, _dfv, dt, v_params)
197-
else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
197+
else:
198198
_, _v = step_euler(0., v, _dfv, dt, v_params)
199-
## obtain action potentials/spikes
200-
s = (_v > _v_thr).astype(jnp.float32)
199+
## obtain action potentials/spikes/pulses
200+
s = (_v > _v_thr) * 1.
201201
## update refractory variables
202202
_rfr = (rfr + dt) * (1. - s)
203203
## perform hyper-polarization of neuronal cells

ngclearn/utils/diffeq/ode_utils.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ def get_integrator_code(integrationType): ## integrator type decoding routine
4545
@jit
4646
def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition/summation
4747
_sum = 0
48-
for arg, val in zip(args, kwargs.values()):
48+
for arg, val in zip(args, kwargs.values()): ## Sigma^I_{i=1} a_i
4949
_sum = _sum + val * arg
5050
return _sum
5151

5252
@jit
5353
def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
54-
_t = t + dt
55-
_x = x * x_scale + dx_dt * dt
54+
_t = t + dt ## advance time forward by dt (denominator)
55+
_x = x * x_scale + dx_dt * dt ## advance variable(s) forward by dt (numerator)
5656
return _t, _x
5757

58-
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5)
58+
@partial(jit, static_argnums=(2))
5959
def step_euler(t, x, dfx, dt, params, x_scale=1.):
6060
"""
6161
Iteratively integrates one step forward via the Euler method, i.e., a
@@ -84,7 +84,7 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.):
8484
_t, _x = next_state
8585
return _t, _x
8686

87-
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4))
87+
@partial(jit, static_argnums=(1))
8888
def _euler(carry, dfx, dt, params, x_scale=1.):
8989
"""
9090
Iteratively integrates one step forward via the Euler method, i.e., a
@@ -112,7 +112,7 @@ def _euler(carry, dfx, dt, params, x_scale=1.):
112112
new_carry = (_t, _x)
113113
return new_carry, (new_carry, carry)
114114

115-
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
115+
@partial(jit, static_argnums=(2))
116116
def step_heun(t, x, dfx, dt, params, x_scale=1.):
117117
"""
118118
Iteratively integrates one step forward via Heun's method, i.e., a
@@ -150,7 +150,7 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.):
150150
_t, _x = next_state
151151
return _t, _x
152152

153-
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
153+
@partial(jit, static_argnums=(1))
154154
def _heun(carry, dfx, dt, params, x_scale=1.):
155155
"""
156156
Iteratively integrates one step forward via Heun's method, i.e., a
@@ -189,7 +189,7 @@ def _heun(carry, dfx, dt, params, x_scale=1.):
189189
new_carry = (_t, _x)
190190
return new_carry, (new_carry, carry)
191191

192-
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
192+
@partial(jit, static_argnums=(2))
193193
def step_rk2(t, x, dfx, dt, params, x_scale=1.):
194194
"""
195195
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -224,7 +224,7 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.):
224224
_t, _x = next_state
225225
return _t, _x
226226

227-
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
227+
@partial(jit, static_argnums=(1))
228228
def _rk2(carry, dfx, dt, params, x_scale=1.):
229229
"""
230230
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -260,7 +260,7 @@ def _rk2(carry, dfx, dt, params, x_scale=1.):
260260
new_carry = (_t, _x)
261261
return new_carry, (new_carry, carry)
262262

263-
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
263+
@partial(jit, static_argnums=(2))
264264
def step_rk4(t, x, dfx, dt, params, x_scale=1.):
265265
"""
266266
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -295,7 +295,7 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.):
295295
_t, _x = next_state
296296
return _t, _x
297297

298-
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
298+
@partial(jit, static_argnums=(1))
299299
def _rk4(carry, dfx, dt, params, x_scale=1.):
300300
"""
301301
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -338,7 +338,7 @@ def _rk4(carry, dfx, dt, params, x_scale=1.):
338338
new_carry = (_t, _x)
339339
return new_carry, (new_carry, carry)
340340

341-
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
341+
@partial(jit, static_argnums=(2))
342342
def step_ralston(t, x, dfx, dt, params, x_scale=1.):
343343
"""
344344
Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -375,7 +375,7 @@ def step_ralston(t, x, dfx, dt, params, x_scale=1.):
375375
_t, _x = next_state
376376
return _t, _x
377377

378-
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4,))
378+
@partial(jit, static_argnums=(1))
379379
def _ralston(carry, dfx, dt, params, x_scale=1.):
380380
"""
381381
Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -416,7 +416,6 @@ def _ralston(carry, dfx, dt, params, x_scale=1.):
416416
new_carry = (_t, _x)
417417
return new_carry, (new_carry, carry)
418418

419-
420419
@partial(jit, static_argnums=(0, 3, 4, 5, 6, 7, 8))
421420
def solve_ode(method_name, t0, x0, T, dfx, dt, params=None, x_scale=1., sols_only=True):
422421
if method_name =='euler':

tests/components/input_encoders/test_phasorCell.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_phasorCell1():
2222
# T = 300 # ms
2323
# ---- build a simple Poisson cell system ----
2424
with Context(name) as ctx:
25-
a = PhasorCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
25+
a = PhasorCell(name="a", n_units=1, target_freq=1000., disable_phasor=True, key=subkeys[0])
2626

2727
advance_process = (Process()
2828
>> a.advance_state)
@@ -49,6 +49,7 @@ def clamp(x):
4949
outs.append(a.outputs.value)
5050
#print(a.outputs.value)
5151
outs = jnp.concatenate(outs, axis=1)
52+
#print(outs)
5253

5354
## output should equal input
5455
assert_array_equal(outs, x_seq)

0 commit comments

Comments
 (0)