Skip to content

Commit 03371ec

Browse files
author
Alexander Ororbia
committed
fixed some minor bugs in rate-coded cells/hebb-syn
1 parent e163c37 commit 03371ec

File tree

7 files changed

+51
-93
lines changed

7 files changed

+51
-93
lines changed

docs/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ Welcome to ngc-learn's documentation!
66
=====================================
77

88
**ngc-learn** is a Python library for building, simulating, and analyzing
9-
biomimetic computational models, arbitrary predictive processing/coding models,
10-
and spiking neural networks. This toolkit is built on top of
9+
biomimetic and NeuroAI computational models, arbitrary predictive processing/coding models,
10+
spiking neural networks, and general dynamical systems. This toolkit is built on top of
1111
`JAX <https://github.com/google/jax>`_ and is distributed under the 3-Clause BSD license.
1212

1313
.. toctree::

ngclearn/components/neurons/graded/bernoulliErrorCell.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ def advance_state(self, dt): ## compute Bernoulli error cell output
110110

111111
# @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
112112
@compilable
113-
def reset(self, batch_size): ## reset core components/statistics
114-
_shape = (batch_size, self.shape[0])
113+
def reset(self): ## reset core components/statistics
114+
_shape = (self.batch_size, self.shape[0])
115115
if len(self.shape) > 1:
116-
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
116+
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
117117
restVals = jnp.zeros(_shape) ## "rest"/reset values
118118
dp = restVals
119119
dtarget = restVals
@@ -161,20 +161,6 @@ def help(cls): ## component help function
161161
"hyperparameters": hyperparams}
162162
return info
163163

164-
def __repr__(self):
165-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
166-
maxlen = max(len(c) for c in comps) + 5
167-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
168-
for c in comps:
169-
stats = tensorstats(getattr(self, c).get())
170-
if stats is not None:
171-
line = [f"{k}: {v}" for k, v in stats.items()]
172-
line = ", ".join(line)
173-
else:
174-
line = "None"
175-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
176-
return lines
177-
178164
if __name__ == '__main__':
179165
from ngcsimlib.context import Context
180166
with Context("Bar") as bar:

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,14 @@ def advance_state(self, dt): ## compute Gaussian error cell output
111111
# @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
112112
# @staticmethod
113113
@compilable
114-
def reset(self, batch_size, shape, sigma_shape): ## reset core components/statistics
115-
_shape = (batch_size, shape[0])
116-
if len(shape) > 1:
117-
_shape = (batch_size, shape[0], shape[1], shape[2])
114+
def reset(self): ## reset core components/statistics
115+
_shape = (self.batch_size, self.shape[0])
116+
if len(self.shape) > 1:
117+
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
118118
restVals = jnp.zeros(_shape)
119119
dmu = restVals
120120
dtarget = restVals
121-
dSigma = jnp.zeros(sigma_shape)
121+
dSigma = jnp.zeros(self.sigma_shape)
122122
target = restVals
123123
mu = restVals
124124
modulator = mu + 1.
@@ -164,20 +164,6 @@ def help(cls): ## component help function
164164
"hyperparameters": hyperparams}
165165
return info
166166

167-
def __repr__(self):
168-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
169-
maxlen = max(len(c) for c in comps) + 5
170-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
171-
for c in comps:
172-
stats = tensorstats(getattr(self, c).get())
173-
if stats is not None:
174-
line = [f"{k}: {v}" for k, v in stats.items()]
175-
line = ", ".join(line)
176-
else:
177-
line = "None"
178-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
179-
return lines
180-
181167
if __name__ == '__main__':
182168
from ngcsimlib.context import Context
183169
with Context("Bar") as bar:

ngclearn/components/neurons/graded/laplacianErrorCell.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,16 @@ def advance_state(self, dt): ## compute Laplacian error cell output
103103
self.L.set(jnp.squeeze(L))
104104
self.mask.set(mask)
105105

106-
def reset(self, batch_size, n_units, scale_shape):
107-
restVals = jnp.zeros((batch_size, n_units))
106+
def reset(self): ## reset core components/statistics
107+
restVals = jnp.zeros((self.batch_size, self.n_units))
108108
dshift = restVals
109109
dtarget = restVals
110-
dScale = jnp.zeros(scale_shape)
110+
dScale = jnp.zeros(self.scale_shape)
111111
target = restVals
112112
shift = restVals
113113
modulator = shift + 1.
114114
L = 0.
115-
mask = jnp.ones((batch_size, n_units))
115+
mask = jnp.ones((self.batch_size, self.n_units))
116116

117117
self.dshift.set(dshift)
118118
self.dtarget.set(dtarget)
@@ -152,20 +152,6 @@ def help(cls): ## component help function
152152
"hyperparameters": hyperparams}
153153
return info
154154

155-
def __repr__(self):
156-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
157-
maxlen = max(len(c) for c in comps) + 5
158-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
159-
for c in comps:
160-
stats = tensorstats(getattr(self, c).get())
161-
if stats is not None:
162-
line = [f"{k}: {v}" for k, v in stats.items()]
163-
line = ", ".join(line)
164-
else:
165-
line = "None"
166-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
167-
return lines
168-
169155
if __name__ == '__main__':
170156
from ngcsimlib.context import Context
171157
with Context("Bar") as bar:

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -252,40 +252,39 @@ def advance_state(self, dt):
252252
self.zF.set(zF)
253253

254254
@compilable
255-
def reset(self, batch_size, shape): #n_units
256-
_shape = (batch_size, shape[0])
257-
if len(shape) > 1:
258-
_shape = (batch_size, shape[0], shape[1], shape[2])
255+
def reset(self): #, batch_size, shape): #n_units
256+
_shape = (self.batch_size, self.shape[0])
257+
if len(self.shape) > 1:
258+
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
259259
restVals = jnp.zeros(_shape)
260260
self.j.set(restVals)
261261
self.j_td.set(restVals)
262262
self.z.set(restVals)
263263
self.zF.set(restVals)
264264

265-
266-
def save(self, directory, **kwargs):
267-
## do a protected save of constants, depending on whether they are floats or arrays
268-
tau_m = (self.tau_m if isinstance(self.tau_m, float)
269-
else jnp.ones([[self.tau_m]]))
270-
priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
271-
else jnp.ones([[self.priorLeakRate]]))
272-
resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
273-
else jnp.ones([[self.resist_scale]]))
274-
275-
file_name = directory + "/" + self.name + ".npz"
276-
jnp.savez(file_name,
277-
tau_m=tau_m, priorLeakRate=priorLeakRate,
278-
resist_scale=resist_scale) #, key=self.key.value)
279-
280-
def load(self, directory, seeded=False, **kwargs):
281-
file_name = directory + "/" + self.name + ".npz"
282-
data = jnp.load(file_name)
283-
## constants loaded in
284-
self.tau_m = data['tau_m']
285-
self.priorLeakRate = data['priorLeakRate']
286-
self.resist_scale = data['resist_scale']
287-
#if seeded:
288-
# self.key.set(data['key'])
265+
# def save(self, directory, **kwargs):
266+
# ## do a protected save of constants, depending on whether they are floats or arrays
267+
# tau_m = (self.tau_m if isinstance(self.tau_m, float)
268+
# else jnp.ones([[self.tau_m]]))
269+
# priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
270+
# else jnp.ones([[self.priorLeakRate]]))
271+
# resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
272+
# else jnp.ones([[self.resist_scale]]))
273+
#
274+
# file_name = directory + "/" + self.name + ".npz"
275+
# jnp.savez(file_name,
276+
# tau_m=tau_m, priorLeakRate=priorLeakRate,
277+
# resist_scale=resist_scale) #, key=self.key.value)
278+
#
279+
# def load(self, directory, seeded=False, **kwargs):
280+
# file_name = directory + "/" + self.name + ".npz"
281+
# data = jnp.load(file_name)
282+
# ## constants loaded in
283+
# self.tau_m = data['tau_m']
284+
# self.priorLeakRate = data['priorLeakRate']
285+
# self.resist_scale = data['resist_scale']
286+
# #if seeded:
287+
# # self.key.set(data['key'])
289288

290289
@classmethod
291290
def help(cls): ## component help function

ngclearn/components/neurons/graded/rewardErrorCell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,12 @@ def evolve(self, dt):
9292
self.mu.set(mu)
9393

9494
@compilable
95-
def reset(self, batch_size, n_units):
96-
restVals = jnp.zeros((batch_size, n_units))
95+
def reset(self): ## reset core components/statistics
96+
restVals = jnp.zeros((self.batch_size, self.n_units))
9797
mu = restVals
9898
rpe = restVals
9999
accum_reward = restVals
100-
n_ep_steps = jnp.zeros((batch_size, 1))
100+
n_ep_steps = jnp.zeros((self.batch_size, 1))
101101
self.mu.set(mu)
102102
self.rpe.set(rpe)
103103
self.accum_reward.set(accum_reward)

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,16 @@ def evolve(self):
254254
self.dBiases.set(dBiases)
255255

256256
@compilable
257-
def reset(self, batch_size, shape):
258-
preVals = jnp.zeros((batch_size, shape[0]))
259-
postVals = jnp.zeros((batch_size, shape[1]))
260-
not self.inputs.targeted and self.inputs.set(preVals) # inputs
257+
def reset(self): #, batch_size, shape):
258+
preVals = jnp.zeros((self.batch_size, self.shape[0]))
259+
postVals = jnp.zeros((self.batch_size, self.shape[1]))
260+
#not self.inputs.targeted and self.inputs.set(preVals) # inputs
261+
self.inputs.set(preVals)
261262
self.outputs.set(postVals) # outputs
262263
self.pre.set(preVals) # pre
263264
self.post.set(postVals) # post
264-
self.dWeights.set(jnp.zeros(shape)) # dW
265-
self.dBiases.set(jnp.zeros(shape[1])) # db
265+
self.dWeights.set(jnp.zeros(self.shape)) # dW
266+
self.dBiases.set(jnp.zeros(self.shape[1])) # db
266267

267268
@classmethod
268269
def help(cls): ## component help function

0 commit comments

Comments
 (0)