Skip to content

Commit f72a063

Browse files
author
Alexander Ororbia
committed
cleaned-up/revised leaky-noise-cell
1 parent a8b156a commit f72a063

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

ngclearn/components/neurons/graded/leakyNoiseCell.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,23 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
5353
:Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of
5454
the cell's evolution at an increase in computational cost (and simulation time)
5555
56-
sigma_rec: noise scaling factor / standard deviation (Default: 1)
56+
sigma_pre: pre-rectification noise scaling factor / standard deviation (Default: 0.1)
57+
58+
sigma_post: post-rectification noise scaling factor / standard deviation (Default: 0.)
59+
60+
leak_scale: degree to which membrane leak should be scaled (Default: 1)
5761
"""
5862

59-
# Define Functions
6063
def __init__(
61-
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1.,
62-
leak_scale=1., shape=None, **kwargs
64+
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1,
65+
sigma_post=0.1, leak_scale=1., shape=None, **kwargs
6366
):
6467
super().__init__(name, **kwargs)
6568

6669

6770
self.tau_x = tau_x
68-
self.sigma_rec = sigma_rec ## a "resistance" scaling factor
71+
self.sigma_pre = sigma_pre ## a pre-rectification scaling factor
72+
self.sigma_post = sigma_post ## a post-rectification scaling factor
6973
self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
7074

7175
## integration properties
@@ -94,9 +98,12 @@ def __init__(
9498

9599
@compilable
96100
def advance_state(self, t, dt):
97-
### run a step of integration over neuronal dynamics
101+
## run a step of integration over neuronal dynamics
102+
### Note: self.fx is the "rectifier" (rectification function)
103+
key, skey = random.split(self.key.get(), 2)
104+
eps_pre = random.normal(skey, shape=self.x.get().shape) ## pre-rectifier distributional noise
98105
key, skey = random.split(self.key.get(), 2)
99-
eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise
106+
eps_post = random.normal(skey, shape=self.x.get().shape) ## post-rectifier distributional noise
100107

101108
#x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
102109
_step_fns = {
@@ -105,13 +112,13 @@ def advance_state(self, t, dt):
105112
2: step_rk4,
106113
}
107114
_step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
108-
params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale)
115+
params = (self.j_input.get(), self.j_recurrent.get(), eps_pre, self.tau_x, self.sigma_pre, self.leak_scale)
109116
_, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
110-
r = self.fx(x) ## calculate (rectified) activity rates; f(x)
117+
r = self.fx(x) + (eps_post * self.sigma_post) ## calculate (rectified) activity rates; f(x)
111118
r_prime = self.dfx(x) ## calculate local deriv of activity rates; f'(x)
112119

113120
## set compartments to next state values in accordance with dynamics
114-
self.key.set(key)
121+
self.key.set(key) ## carry noise key over transition (to next state of component)
115122
self.x.set(x)
116123
self.r.set(r)
117124
self.r_prime.set(r_prime)
@@ -146,7 +153,7 @@ def help(cls): ## component help function
146153
"n_units": "Number of neuronal cells to model in this layer",
147154
"batch_size": "Batch size dimension of this component",
148155
"tau_x": "State time constant",
149-
"sigma_rec": "The non-zero degree/scale of noise to inject into this neuron"
156+
"sigma_pre": "The non-zero degree/scale of (pre-rectification) noise to inject into this neuron"
150157
}
151158
info = {cls.__name__: properties,
152159
"compartments": compartment_props,

0 commit comments

Comments
 (0)