Skip to content

Commit 9b75266

Browse files
author
Alexander Ororbia
committed
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn into major_release_update
2 parents f556293 + 4e9176c commit 9b75266

File tree

3 files changed

+22
-21
lines changed

3 files changed

+22
-21
lines changed

ngclearn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ngcsimlib.compartment import Compartment
3737
from ngcsimlib.resolver import resolver
3838
from ngcsimlib import utils as sim_utils
39+
from ngcsimlib.compilers.process import Process, transition
3940

4041

4142
from ngcsimlib import configure, preload_modules

ngclearn/components/base_monitor.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import json
22

3-
from ngclearn import Component, Compartment
3+
from ngclearn import Component, Compartment, transition
44
from ngclearn import numpy as np
5-
#from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
65
from ngcsimlib.utils import get_current_path
76
from ngcsimlib.logger import warn, critical
7+
88
import matplotlib.pyplot as plt
99

1010

@@ -44,9 +44,8 @@ class Base_Monitor(Component):
4444
"""
4545
auto_resolve = False
4646

47-
4847
@staticmethod
49-
def build_advance(compartments):
48+
def _record_internal(compartments):
5049
"""
5150
A method to build the method to advance the stored values.
5251
@@ -61,8 +60,9 @@ def build_advance(compartments):
6160
"monitor found in ngclearn.components or "
6261
"ngclearn.components.lava (If using lava)")
6362

63+
@transition(None, True)
6464
@staticmethod
65-
def build_reset(component):
65+
def reset(component):
6666
"""
6767
A method to build the method to reset the stored values.
6868
Args:
@@ -87,15 +87,16 @@ def _reset(**kwargs):
8787
# pure func, output compartments, args, params, input compartments
8888
return _reset, output_compartments, [], [], output_compartments
8989

90+
@transition(None, True)
9091
@staticmethod
91-
def build_advance_state(component):
92+
def record(component):
9293
output_compartments = []
9394
compartments = []
9495
for comp in component.compartments:
9596
output_compartments.append(comp.split("/")[-1] + "*store")
9697
compartments.append(comp.split("/")[-1])
9798

98-
_advance = component.build_advance(compartments)
99+
_advance = component._record_internal(compartments)
99100

100101
return _advance, output_compartments, [], [], compartments + output_compartments
101102

@@ -124,8 +125,15 @@ def watch(self, compartment, window_length):
124125
"""
125126
cs, end = self._add_path(compartment.path)
126127

127-
dtype = compartment.value.dtype
128-
shape = compartment.value.shape
128+
if hasattr(compartment.value, "dtype"):
129+
dtype = compartment.value.dtype
130+
else:
131+
dtype = type(compartment.value)
132+
133+
if hasattr(compartment.value, "shape"):
134+
shape = compartment.value.shape
135+
else:
136+
shape = (1,)
129137
new_comp = Compartment(np.zeros(shape, dtype=dtype))
130138
new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))
131139

ngclearn/components/monitor.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from ngclearn.components.base_monitor import Base_Monitor
2+
from ngclearn import transition
23

34
class Monitor(Base_Monitor):
45
"""
56
A jax implementation of `Base_Monitor`. Designed to be used with all
67
non-lava ngclearn components
78
"""
8-
auto_resolve = False
99

1010
@staticmethod
11-
def build_advance(compartments):
11+
def _record_internal(compartments):
1212
@staticmethod
13-
def _advance(**kwargs):
13+
def _record(**kwargs):
1414
return_vals = []
1515
for comp in compartments:
1616
new_val = kwargs[comp]
@@ -19,12 +19,4 @@ def _advance(**kwargs):
1919
current_store = current_store.at[-1].set(new_val)
2020
return_vals.append(current_store)
2121
return return_vals if len(compartments) > 1 else return_vals[0]
22-
return _advance
23-
24-
@staticmethod
25-
def build_advance_state(component):
26-
return super().build_advance_state(component)
27-
28-
@staticmethod
29-
def build_reset(component):
30-
return super().build_reset(component)
22+
return _record

0 commit comments

Comments
 (0)