@@ -53,19 +53,23 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
5353 :Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of
5454 the cell's evolution at an increase in computational cost (and simulation time)
5555
56- sigma_rec: noise scaling factor / standard deviation (Default: 1)
56+ sigma_pre: pre-rectification noise scaling factor / standard deviation (Default: 0.1)
57+
58+ sigma_post: post-rectification noise scaling factor / standard deviation (Default: 0.)
59+
60+ leak_scale: degree to which membrane leak should be scaled (Default: 1)
5761 """
5862
59- # Define Functions
6063 def __init__ (
61- self , name , n_units , tau_x , act_fx = "relu" , integration_type = "euler" , batch_size = 1 , sigma_rec = 1. ,
62- leak_scale = 1. , shape = None , ** kwargs
64+ self , name , n_units , tau_x , act_fx = "relu" , integration_type = "euler" , batch_size = 1 , sigma_pre = 0.1 ,
65+ sigma_post = 0.1 , leak_scale = 1. , shape = None , ** kwargs
6366 ):
6467 super ().__init__ (name , ** kwargs )
6568
6669
6770 self .tau_x = tau_x
68- self .sigma_rec = sigma_rec ## a "resistance" scaling factor
71+ self .sigma_pre = sigma_pre ## a pre-rectification scaling factor
72+ self .sigma_post = sigma_post ## a post-rectification scaling factor
6973 self .leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
7074
7175 ## integration properties
@@ -94,9 +98,12 @@ def __init__(
9498
9599 @compilable
96100 def advance_state (self , t , dt ):
97- ### run a step of integration over neuronal dynamics
101+ ## run a step of integration over neuronal dynamics
102+ ### Note: self.fx is the "rectifier" (rectification function)
103+ key , skey = random .split (self .key .get (), 2 )
104+ eps_pre = random .normal (skey , shape = self .x .get ().shape ) ## pre-rectifier distributional noise
98105 key , skey = random .split (self .key .get (), 2 )
99- eps = random .normal (skey , shape = self .x .get ().shape ) ## sample of unit distributional noise
106+ eps_post = random .normal (skey , shape = self .x .get ().shape ) ## post-rectifier distributional noise
100107
101108 #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
102109 _step_fns = {
@@ -105,13 +112,13 @@ def advance_state(self, t, dt):
105112 2 : step_rk4 ,
106113 }
107114 _step_fn = _step_fns [self .intgFlag ] #_step_fns.get(self.intgFlag, step_euler)
108- params = (self .j_input .get (), self .j_recurrent .get (), eps , self .tau_x , self .sigma_rec , self .leak_scale )
115+ params = (self .j_input .get (), self .j_recurrent .get (), eps_pre , self .tau_x , self .sigma_pre , self .leak_scale )
109116 _ , x = _step_fn (0. , self .x .get (), _dfz , dt , params ) ## update state activation dynamics
110- r = self .fx (x ) ## calculate (rectified) activity rates; f(x)
117+ r = self .fx (x ) + ( eps_post * self . sigma_post ) ## calculate (rectified) activity rates; f(x)
111118 r_prime = self .dfx (x ) ## calculate local deriv of activity rates; f'(x)
112119
113120 ## set compartments to next state values in accordance with dynamics
114- self .key .set (key )
121+ self .key .set (key ) ## carry noise key over transition (to next state of component)
115122 self .x .set (x )
116123 self .r .set (r )
117124 self .r_prime .set (r_prime )
@@ -146,7 +153,7 @@ def help(cls): ## component help function
146153 "n_units" : "Number of neuronal cells to model in this layer" ,
147154 "batch_size" : "Batch size dimension of this component" ,
148155 "tau_x" : "State time constant" ,
149- "sigma_rec " : "The non-zero degree/scale of noise to inject into this neuron"
156+ "sigma_pre " : "The non-zero degree/scale of (pre-rectification) noise to inject into this neuron"
150157 }
151158 info = {cls .__name__ : properties ,
152159 "compartments" : compartment_props ,
0 commit comments