Skip to content

Commit edc1803

Browse files
committed
update phasor cell
1 parent b98fd1a commit edc1803

File tree

2 files changed

+79
-29
lines changed

2 files changed

+79
-29
lines changed

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import jax
44
from typing import Union
55

6+
from ngcsimlib.logger import info, warn
67
from ngcsimlib.compartment import Compartment
78
from ngcsimlib.parser import compilable
89

9-
1010
class PhasorCell(JaxComponent):
1111
"""
1212
A phasor cell that emits a pulse at a regular interval.
@@ -33,25 +33,19 @@ class PhasorCell(JaxComponent):
3333

3434
# Define Functions
3535
def __init__(
36-
self, name: str, n_units: int, target_freq: float = 63.75,
37-
batch_size: int = 1, key: Union[jax.Array, None] = None):
38-
super().__init__(name=name, key=key)
39-
40-
_key, subkey = random.split(self.key.get(), 2)
41-
self.key.set(_key)
36+
self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs):
37+
super().__init__(name, **kwargs)
4238

4339
## Phasor meta-parameters
44-
self.target_freq = Compartment(target_freq, fixed=True) ## maximum frequency (in Hertz/Hz)
45-
self.base_scale = Compartment(random.poisson(subkey[0], lam=target_freq, shape=(batch_size, n_units)) / target_freq, fixed=True)
40+
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
4641

4742
## Layer Size Setup
48-
self.batch_size = Compartment(batch_size, fixed=True)
49-
self.n_units = Compartment(n_units, fixed=True)
50-
51-
52-
43+
self.batch_size = batch_size
44+
self.n_units = n_units
45+
_key, *subkey = random.split(self.key.get(), 3)
46+
self.key.set(_key)
5347
## Compartment setup
54-
restVals = jnp.zeros((batch_size, n_units))
48+
restVals = jnp.zeros((self.batch_size, self.n_units))
5549
self.inputs = Compartment(restVals,
5650
display_name="Input Stimulus") # input
5751
# compartment
@@ -60,44 +54,100 @@ def __init__(
6054
self.tols = Compartment(initial_value=restVals,
6155
display_name="Time-of-Last-Spike", units="ms") # time of last spike
6256
self.angles = Compartment(restVals, display_name="Angles", units="deg")
63-
57+
# self.base_scale = random.uniform(subkey, self.angles.value.shape,
58+
# minval=0.75, maxval=1.25)
59+
# self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1)
60+
# alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
61+
# beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
62+
63+
self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.get().shape) / target_freq
64+
self.disable_phasor = disable_phasor
65+
66+
def validate(self, dt=None, **validation_kwargs):
67+
valid = super().validate(**validation_kwargs)
68+
if dt is None:
69+
warn(f"{self.name} requires a validation kwarg of `dt`")
70+
return False
71+
## check for unstable combinations of dt and target-frequency
72+
# meta-params
73+
events_per_timestep = (dt / 1000.) * self.target_freq ##
74+
# compute scaled probability
75+
if events_per_timestep > 1.:
76+
valid = False
77+
warn(
78+
f"{self.name} will be unable to make as many temporal events "
79+
f"as "
80+
f"requested! ({events_per_timestep} events/timestep) Unstable "
81+
f"combination of dt = {dt} and target_freq = "
82+
f"{self.target_freq} "
83+
f"being used!"
84+
)
85+
return valid
86+
87+
# @transition(output_compartments=["outputs", "tols", "key", "angles"])
88+
# @staticmethod
6489
@compilable
65-
def advance_state(self, t, dt):
90+
def advance_state(self, t, dt, ):
91+
92+
inputs = self.inputs.get()
93+
angles = self.angles.get()
94+
tols = self.tols.get()
95+
6696
ms_per_second = 1000 # ms/s
67-
events_per_ms = self.target_freq.get() / ms_per_second # e/s s/ms -> e/ms
97+
events_per_ms = self.target_freq / ms_per_second # e/s s/ms -> e/ms
6898
ms_per_event = 1 / events_per_ms # ms/e
6999
time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
70100
angle_per_event = 2 * jnp.pi # rad / e
71101
angle_per_timestep = angle_per_event / time_step_per_event # rad / e
72102
# * e/ts -> rad / ts
73103
key, *subkey = random.split(self.key.get(), 3)
104+
# scatter = random.uniform(subkey, angles.shape, minval=0.5,
105+
# maxval=1.5) * base_scale
74106

75-
scatter = ((random.normal(subkey[0], self.angles.get().shape) * 0.2) + 1) * self.base_scale.get()
107+
scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * self.base_scale
76108
scattered_update = angle_per_timestep * scatter
77-
scaled_scattered_update = scattered_update * self.inputs.get()
78-
79-
updated_angles = self.angles.get() + scaled_scattered_update
80-
self.outputs.set(jnp.where(updated_angles > angle_per_event, 1., 0.))
109+
scaled_scattered_update = scattered_update * inputs
81110

82-
self.angles.set(jnp.where(updated_angles > angle_per_event,
111+
updated_angles = angles + scaled_scattered_update
112+
outputs = jnp.where(updated_angles > angle_per_event, 1., 0.)
113+
updated_angles = jnp.where(updated_angles > angle_per_event,
83114
updated_angles - angle_per_event,
84-
updated_angles))
115+
updated_angles)
116+
if self.disable_phasor:
117+
outputs = inputs + 0
118+
tols = tols * (1. - outputs) + t * outputs
119+
120+
self.outputs.set(outputs)
121+
self.tols.set(tols)
122+
self.key.set(key)
123+
self.angles.set(updated_angles)
85124

86-
self.tols.set(self.tols.get() * (1. - self.outputs.get()) + t * self.outputs.get())
87125

126+
# @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"])
127+
# @staticmethod
88128
@compilable
89129
def reset(self):
90-
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
130+
restVals = jnp.zeros((self.batch_size, self.n_units))
131+
key, *subkey = random.split(self.key.get(), 3)
132+
91133
# BUG: the self.inputs here does not have the targeted field
92134
# NOTE: Quick workaround is to check if targeted is in the input or not
93135
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
94136
self.outputs.set(restVals)
95137
self.tols.set(restVals)
96138
self.angles.set(restVals)
97-
key, _ = random.split(self.key.get(), 2)
98139
self.key.set(key)
99140

100141

142+
def save(self, directory, **kwargs):
143+
file_name = directory + "/" + self.name + ".npz"
144+
jnp.savez(file_name, key=self.key.value)
145+
146+
def load(self, directory, **kwargs):
147+
file_name = directory + "/" + self.name + ".npz"
148+
data = jnp.load(file_name)
149+
self.key.set(data['key'])
150+
101151
@classmethod
102152
def help(cls): ## component help function
103153
properties = {

tests/components/input_encoders/test_phasorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def clamp(x):
4444
## output should equal input
4545
assert_array_equal(outs, x_seq)
4646

47-
#test_phasorCell1()
47+
test_phasorCell1()

0 commit comments

Comments
 (0)