Skip to content

Commit 68f663b

Browse files
author
Alexander Ororbia
committed
minor edits to exp-kernel/wtas-cell
1 parent 0d7c24b commit 68f663b

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class WTASCell(JaxComponent): ## winner-take-all spiking cell
5050
thr_jitter: scale of uniform jitter to add to initialization of thresholds
5151
"""
5252

53-
# Define Functions
53+
#@deprecate_args(thr_base="thrBase")
5454
def __init__(
5555
self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002, refract_time=0., thr_jitter=0.05,
5656
**kwargs

ngclearn/components/other/expKernel.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,22 @@ def advance_state(self, t):
7474

7575
s = inputs
7676
## update spike time window and corresponding window volume
77-
tf, epsp = _apply_kernel(tf, s, t, self.tau_w, self.win_len, krn_start=0,
78-
krn_end=self.win_len-1) #0:win_len-1)
77+
tf, epsp = _apply_kernel(
78+
tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1
79+
) #0:win_len-1)
7980

8081
# Update compartments
8182
self.epsp.set(epsp)
8283
self.tf.set(tf)
8384

85+
@compilable
86+
def reset(self):
87+
restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp
88+
restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf
89+
self.inputs.set(restVals)
90+
self.epsp.set(restVals)
91+
self.tf.set(restTensor)
92+
8493
@classmethod
8594
def help(cls): ## component help function
8695
properties = {

ngclearn/components/synapses/denseSynapse.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from jax import random, numpy as jnp, jit
22
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngclearn.utils import tensorstats
4-
from ngclearn.utils.weight_distribution import initialize_params
3+
from ngclearn.utils.distribution_generator import DistributionGenerator
54
from ngcsimlib.logger import info
65

76
from ngcsimlib.compartment import Compartment
@@ -58,10 +57,13 @@ def __init__(
5857

5958
if self.weight_init is None:
6059
info(self.name, "is using default weight initializer!")
61-
self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
62-
weights = initialize_params(subkeys[0], self.weight_init, shape)
60+
# self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
61+
# weights = initialize_params(subkeys[0], self.weight_init, shape)
62+
self.weight_init = DistributionGenerator.uniform(0.025, 0.8)
63+
#weights = initialize_params(subkeys[0], self.weight_init, shape)
64+
weights = self.weight_init(shape, subkeys[0])
6365

64-
if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
66+
if 0. < p_conn < 1.: ## Modifier/constraint: only non-zero and <1 probs allowed
6567
p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
6668
weights = weights * p_mask ## sparsify matrix
6769

@@ -76,9 +78,10 @@ def __init__(
7678
if self.bias_init is None:
7779
info(self.name, "is using default bias value of zero (no bias "
7880
"kernel provided)!")
79-
self.biases = Compartment(initialize_params(subkeys[2], bias_init,
80-
(1, shape[1]))
81-
if bias_init else 0.0)
81+
self.biases = Compartment(self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0)
82+
# self.biases = Compartment(initialize_params(subkeys[2], bias_init,
83+
# (1, shape[1]))
84+
# if bias_init else 0.0)
8285

8386
@compilable
8487
def advance_state(self):

0 commit comments

Comments
 (0)