@@ -44,7 +44,7 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
4444 ## Layer Size Setup
4545 self .batch_size = batch_size
4646 self .n_units = n_units
47- _key , subkey = random .split (self .key .value , 2 )
47+ _key , * subkey = random .split (self .key .value , 3 )
4848 self .key .set (_key )
4949 ## Compartment setup
5050 restVals = jnp .zeros ((self .batch_size , self .n_units ))
@@ -62,7 +62,7 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
6262 # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
6363 # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
6464
65- self .base_scale = random .poisson (subkey , lam = target_freq , shape = self .angles .value .shape ) / target_freq
65+ self .base_scale = random .poisson (subkey [ 0 ] , lam = target_freq , shape = self .angles .value .shape ) / target_freq
6666
6767 def validate (self , dt = None , ** validation_kwargs ):
6868 valid = super ().validate (** validation_kwargs )
@@ -95,11 +95,11 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
9595 angle_per_event = 2 * jnp .pi # rad / e
9696 angle_per_timestep = angle_per_event / time_step_per_event # rad / e
9797 # * e/ts -> rad / ts
98- key , subkey = random .split (key , 2 )
98+ key , * subkey = random .split (key , 3 )
9999 # scatter = random.uniform(subkey, angles.shape, minval=0.5,
100100 # maxval=1.5) * base_scale
101101
102- scatter = ((random .normal (subkey , angles .shape ) * 0.2 ) + 1 ) * base_scale
102+ scatter = ((random .normal (subkey [ 0 ] , angles .shape ) * 0.2 ) + 1 ) * base_scale
103103 scattered_update = angle_per_timestep * scatter
104104 scaled_scattered_update = scattered_update * inputs
105105
@@ -116,7 +116,7 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale):
116116 @staticmethod
117117 def reset (batch_size , n_units , key , target_freq ):
118118 restVals = jnp .zeros ((batch_size , n_units ))
119- key , subkey = random .split (key , 2 )
119+ key , * subkey = random .split (key , 3 )
120120 return restVals , restVals , restVals , restVals , key
121121
122122 def save (self , directory , ** kwargs ):
0 commit comments