Skip to content

Commit b86ae3d

Browse files
author
Alexander Ororbia
committed
cleaned up dunder repr method, moved to JaxComponent parent; fixed __init__ pointer to tensorstats
1 parent 33599fd commit b86ae3d

File tree

11 files changed

+15
-142
lines changed

11 files changed

+15
-142
lines changed

ngclearn/components/jaxComponent.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from jax import random
77
from ngcsimlib.compartment import Compartment
88
from ngcsimlib import Component
9+
from ngclearn.utils import tensorstats
910

1011

1112
class JaxComponent(Component):
@@ -57,3 +58,16 @@ def load(self, directory: str):
5758
if d is not None:
5859
comp.set(d)
5960

61+
def __repr__(self):
62+
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
63+
maxlen = max(len(c) for c in comps) + 5
64+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
65+
for c in comps:
66+
stats = tensorstats(getattr(self, c).value)
67+
if stats is not None:
68+
line = [f"{k}: {v}" for k, v in stats.items()]
69+
line = ", ".join(line)
70+
else:
71+
line = "None"
72+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
73+
return lines

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, nn, Array, jit
3-
from ngclearn.utils import tensorstats
43
from ngcsimlib import deprecate_args
54
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
65
step_euler, step_rk2
@@ -231,17 +230,3 @@ def help(cls): ## component help function
231230
"hyperparameters": hyperparams}
232231
return info
233232

234-
def __repr__(self):
235-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
236-
maxlen = max(len(c) for c in comps) + 5
237-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
238-
for c in comps:
239-
stats = tensorstats(getattr(self, c).value)
240-
if stats is not None:
241-
line = [f"{k}: {v}" for k, v in stats.items()]
242-
line = ", ".join(line)
243-
else:
244-
line = "None"
245-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
246-
return lines
247-

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, nn, Array
3-
from ngclearn.utils import tensorstats
43
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
54
step_euler, step_rk2
65
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
@@ -268,20 +267,6 @@ def help(cls): ## component help function
268267
"hyperparameters": hyperparams}
269268
return info
270269

271-
def __repr__(self):
272-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
273-
maxlen = max(len(c) for c in comps) + 5
274-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
275-
for c in comps:
276-
stats = tensorstats(getattr(self, c).value)
277-
if stats is not None:
278-
line = [f"{k}: {v}" for k, v in stats.items()]
279-
line = ", ".join(line)
280-
else:
281-
line = "None"
282-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
283-
return lines
284-
285270
if __name__ == '__main__':
286271
from ngcsimlib.context import Context
287272
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/RAFCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
@@ -214,17 +212,3 @@ def help(cls): ## component help function
214212
"tau_w * dw/dt = w * dampen_factor - v * omega + j",
215213
"hyperparameters": hyperparams}
216214
return info
217-
218-
def __repr__(self):
219-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
220-
maxlen = max(len(c) for c in comps) + 5
221-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
222-
for c in comps:
223-
stats = tensorstats(getattr(self, c).value)
224-
if stats is not None:
225-
line = [f"{k}: {v}" for k, v in stats.items()]
226-
line = ", ".join(line)
227-
else:
228-
line = "None"
229-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
230-
return lines

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from jax import numpy as jnp, random, jit, nn
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, random, jit, nn
4-
from functools import partial
5-
from ngclearn.utils import tensorstats
64
from ngcsimlib import deprecate_args
75
from ngcsimlib.logger import info, warn
86

@@ -159,20 +157,6 @@ def help(cls): ## component help function
159157
"hyperparameters": hyperparams}
160158
return info
161159

162-
def __repr__(self):
163-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
164-
maxlen = max(len(c) for c in comps) + 5
165-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
166-
for c in comps:
167-
stats = tensorstats(getattr(self, c).value)
168-
if stats is not None:
169-
line = [f"{k}: {v}" for k, v in stats.items()]
170-
line = ", ".join(line)
171-
else:
172-
line = "None"
173-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
174-
return lines
175-
176160
if __name__ == '__main__':
177161
from ngcsimlib.context import Context
178162
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
@@ -211,20 +209,6 @@ def help(cls): ## component help function
211209
"hyperparameters": hyperparams}
212210
return info
213211

214-
def __repr__(self):
215-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
216-
maxlen = max(len(c) for c in comps) + 5
217-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
218-
for c in comps:
219-
stats = tensorstats(getattr(self, c).value)
220-
if stats is not None:
221-
line = [f"{k}: {v}" for k, v in stats.items()]
222-
line = ", ".join(line)
223-
else:
224-
line = "None"
225-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
226-
return lines
227-
228212
if __name__ == '__main__':
229213
from ngcsimlib.context import Context
230214
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/fitzhughNagumoCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
@@ -214,20 +212,6 @@ def help(cls): ## component help function
214212
"hyperparameters": hyperparams}
215213
return info
216214

217-
def __repr__(self):
218-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
219-
maxlen = max(len(c) for c in comps) + 5
220-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
221-
for c in comps:
222-
stats = tensorstats(getattr(self, c).value)
223-
if stats is not None:
224-
line = [f"{k}: {v}" for k, v in stats.items()]
225-
line = ", ".join(line)
226-
else:
227-
line = "None"
228-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
229-
return lines
230-
231215
if __name__ == '__main__':
232216
from ngcsimlib.context import Context
233217
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4
@@ -273,20 +271,6 @@ def help(cls): ## component help function
273271
"hyperparameters": hyperparams}
274272
return info
275273

276-
def __repr__(self):
277-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
278-
maxlen = max(len(c) for c in comps) + 5
279-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
280-
for c in comps:
281-
stats = tensorstats(getattr(self, c).value)
282-
if stats is not None:
283-
line = [f"{k}: {v}" for k, v in stats.items()]
284-
line = ", ".join(line)
285-
else:
286-
line = "None"
287-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
288-
return lines
289-
290274
if __name__ == '__main__':
291275
from ngcsimlib.context import Context
292276
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/izhikevichCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
@@ -229,20 +227,6 @@ def help(cls): ## component help function
229227
"hyperparameters": hyperparams}
230228
return info
231229

232-
def __repr__(self):
233-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
234-
maxlen = max(len(c) for c in comps) + 5
235-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
236-
for c in comps:
237-
stats = tensorstats(getattr(self, c).value)
238-
if stats is not None:
239-
line = [f"{k}: {v}" for k, v in stats.items()]
240-
line = ", ".join(line)
241-
else:
242-
line = "None"
243-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
244-
return lines
245-
246230
if __name__ == '__main__':
247231
from ngcsimlib.context import Context
248232
with Context("Bar") as bar:

ngclearn/components/neurons/spiking/quadLIFCell.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from ngclearn.components.jaxComponent import JaxComponent
22
from jax import numpy as jnp, random, jit, nn, Array
3-
from functools import partial
4-
from ngclearn.utils import tensorstats
53
from ngcsimlib import deprecate_args
64
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
@@ -241,20 +239,6 @@ def help(cls): ## component help function
241239
"hyperparameters": hyperparams}
242240
return info
243241

244-
def __repr__(self):
245-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
246-
maxlen = max(len(c) for c in comps) + 5
247-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
248-
for c in comps:
249-
stats = tensorstats(getattr(self, c).value)
250-
if stats is not None:
251-
line = [f"{k}: {v}" for k, v in stats.items()]
252-
line = ", ".join(line)
253-
else:
254-
line = "None"
255-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
256-
return lines
257-
258242
if __name__ == '__main__':
259243
from ngcsimlib.context import Context
260244
with Context("Bar") as bar:

0 commit comments

Comments
 (0)