Skip to content

Commit 80f2417

Browse files
committed
add not self.inputs.targeted and to required components. Fixing general __repr__ bug in jaxcomponent
1 parent de10028 commit 80f2417

File tree

10 files changed

+12
-10
lines changed

10 files changed

+12
-10
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def advance_state(self, t):
4848
@compilable
4949
def reset(self):
5050
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
51-
self.inputs.set(restVals)
51+
not self.inputs.targeted and self.inputs.set(restVals)
5252
self.outputs.set(restVals)
5353
self.tols.set(restVals)
5454

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def advance_state(self, t):
211211
@compilable
212212
def reset(self):
213213
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
214-
self.inputs.set(restVals)
214+
not self.inputs.targeted and self.inputs.set(restVals)
215215
self.outputs.set(restVals)
216216
self.tols.set(restVals)
217217
self.mask.set(restVals)

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def advance_state(self, t, dt):
8888
@compilable
8989
def reset(self):
9090
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
91-
self.inputs.set(restVals)
91+
not self.inputs.targeted and self.inputs.set(restVals)
9292
self.outputs.set(restVals)
9393
self.tols.set(restVals)
9494
self.angles.set(restVals)

ngclearn/components/jaxComponent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __repr__(self):
6363
maxlen = max(len(c) for c in comps) + 5
6464
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
6565
for c in comps:
66-
stats = tensorstats(getattr(self, c).value)
66+
stats = tensorstats(getattr(self, c).get())
6767
if stats is not None:
6868
line = [f"{k}: {v}" for k, v in stats.items()]
6969
line = ", ".join(line)

ngclearn/components/other/expKernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def advance_state(self, t):
8686
def reset(self):
8787
restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp
8888
restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf
89-
self.inputs.set(restVals)
89+
not self.inputs.targeted and self.inputs.set(restVals)
9090
self.epsp.set(restVals)
9191
self.tf.set(restTensor)
9292

ngclearn/components/other/varTrace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# %%
2+
13
from ngclearn.components.jaxComponent import JaxComponent
24
from jax import numpy as jnp, random, jit
35
from functools import partial
@@ -124,7 +126,7 @@ def advance_state(self, dt):
124126
@compilable
125127
def reset(self):
126128
restVals = jnp.zeros((self.batch_size, self.n_units))
127-
self.inputs.set(restVals)
129+
not self.inputs.targeted and self.inputs.set(restVals)
128130
self.outputs.set(restVals)
129131
self.trace.set(restVals)
130132

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def evolve(self):
257257
def reset(self, batch_size, shape):
258258
preVals = jnp.zeros((batch_size, shape[0]))
259259
postVals = jnp.zeros((batch_size, shape[1]))
260-
self.inputs.set(preVals) # inputs
260+
not self.inputs.targeted and self.inputs.set(preVals) # inputs
261261
self.outputs.set(postVals) # outputs
262262
self.pre.set(preVals) # pre
263263
self.post.set(postVals) # post

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def reset(self, batch_size, shape):
207207
seed = jax.random.PRNGKey(42)
208208

209209

210-
self.inputs.set(inputs)
210+
not self.inputs.targeted and self.inputs.set(inputs)
211211
self.outputs.set(outputs)
212212
self.objective.set(objective)
213213
self.rewards.set(rewards)

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def evolve(self):
280280
def reset(self, batch_size, shape):
281281
preVals = jnp.zeros((batch_size, shape[0]))
282282
postVals = jnp.zeros((batch_size, shape[1]))
283-
self.inputs.set(preVals) # inputs
283+
not self.inputs.targeted and self.inputs.set(preVals) # inputs
284284
self.outputs.set(postVals) # outputs
285285
self.pre.set(preVals) # pre
286286
self.post.set(postVals) # post

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def reset(self, batch_size, shape):
159159
postVals = jnp.zeros((batch_size, shape[1]))
160160
inputs = preVals
161161
outputs = postVals
162-
self.inputs.set(inputs)
162+
not self.inputs.targeted and self.inputs.set(inputs)
163163
self.outputs.set(outputs)
164164

165165
def save(self, directory, **kwargs):

0 commit comments

Comments
 (0)