Skip to content

Commit 2e1b2e8

Browse files
committed
Merge branches 'major_release_update' and 'major_release_update' of github.com:NACLab/ngc-learn into major_release_update
2 parents b594529 + cd9f12a commit 2e1b2e8

40 files changed

+141
-101
lines changed

docs/museum/snn_dc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ To use your saved model and examine its performance on the MNIST test-set, you
209209
can execute the evaluation script like so:
210210

211211
```console
212-
$ python analyze_dcsnn.py --dataX=../data/mnist/testX.npy --sample_idx=0
212+
$ python analyze_dcsnn.py --dataX=../../data/mnist/testX.npy --sample_idx=0
213213
```
214214

215215
while will produce a visualization of your DC-SNN's receptive fields (the

ngclearn/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from ngcsimlib.resolver import resolver
3838
from ngcsimlib import utils as sim_utils
3939

40+
from ngclearn.utils.jaxProcess import JaxProcess
41+
from ngcsimlib.compilers.process import transition, Process
42+
4043

4144
from ngcsimlib import configure, preload_modules
4245
from ngcsimlib import logger

ngclearn/components/base_monitor.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import json
22

3-
from ngclearn import Component, Compartment
3+
from ngclearn import Component, Compartment, transition
44
from ngclearn import numpy as np
5-
#from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
65
from ngcsimlib.utils import get_current_path
76
from ngcsimlib.logger import warn, critical
7+
88
import matplotlib.pyplot as plt
99

1010

@@ -44,9 +44,16 @@ class Base_Monitor(Component):
4444
"""
4545
auto_resolve = False
4646

47+
@staticmethod
48+
def build_reset(component):
49+
return Base_Monitor.reset(component)
50+
51+
@staticmethod
52+
def build_advance_state(component):
53+
return Base_Monitor.record(component)
4754

4855
@staticmethod
49-
def build_advance(compartments):
56+
def _record_internal(compartments):
5057
"""
5158
A method to build the method to advance the stored values.
5259
@@ -61,8 +68,9 @@ def build_advance(compartments):
6168
"monitor found in ngclearn.components or "
6269
"ngclearn.components.lava (If using lava)")
6370

71+
@transition(None, True)
6472
@staticmethod
65-
def build_reset(component):
73+
def reset(component):
6674
"""
6775
A method to build the method to reset the stored values.
6876
Args:
@@ -87,15 +95,16 @@ def _reset(**kwargs):
8795
# pure func, output compartments, args, params, input compartments
8896
return _reset, output_compartments, [], [], output_compartments
8997

98+
@transition(None, True)
9099
@staticmethod
91-
def build_advance_state(component):
100+
def record(component):
92101
output_compartments = []
93102
compartments = []
94103
for comp in component.compartments:
95104
output_compartments.append(comp.split("/")[-1] + "*store")
96105
compartments.append(comp.split("/")[-1])
97106

98-
_advance = component.build_advance(compartments)
107+
_advance = component._record_internal(compartments)
99108

100109
return _advance, output_compartments, [], [], compartments + output_compartments
101110

@@ -124,8 +133,15 @@ def watch(self, compartment, window_length):
124133
"""
125134
cs, end = self._add_path(compartment.path)
126135

127-
dtype = compartment.value.dtype
128-
shape = compartment.value.shape
136+
if hasattr(compartment.value, "dtype"):
137+
dtype = compartment.value.dtype
138+
else:
139+
dtype = type(compartment.value)
140+
141+
if hasattr(compartment.value, "shape"):
142+
shape = compartment.value.shape
143+
else:
144+
shape = (1,)
129145
new_comp = Compartment(np.zeros(shape, dtype=dtype))
130146
new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))
131147

@@ -268,7 +284,7 @@ def load(self, directory, **kwargs):
268284
setattr(self, compartment_path, new_comp)
269285
self.compartments.append(new_comp.path)
270286

271-
self._update_resolver()
287+
# self._update_resolver()
272288

273289
def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
274290
vals = self.view(compartment)

ngclearn/components/monitor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ngclearn.components.base_monitor import Base_Monitor
2+
from ngclearn import transition
23

34
class Monitor(Base_Monitor):
45
"""
@@ -8,9 +9,9 @@ class Monitor(Base_Monitor):
89
auto_resolve = False
910

1011
@staticmethod
11-
def build_advance(compartments):
12+
def _record_internal(compartments):
1213
@staticmethod
13-
def _advance(**kwargs):
14+
def _record(**kwargs):
1415
return_vals = []
1516
for comp in compartments:
1617
new_val = kwargs[comp]
@@ -19,7 +20,7 @@ def _advance(**kwargs):
1920
current_store = current_store.at[-1].set(new_val)
2021
return_vals.append(current_store)
2122
return return_vals if len(compartments) > 1 else return_vals[0]
22-
return _advance
23+
return _record
2324

2425
@staticmethod
2526
def build_advance_state(component):

ngclearn/utils/jaxProcess.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from ngcsimlib.compilers.process import Process
2+
from jax.lax import scan as _scan
3+
4+
5+
class JaxProcess(Process):
6+
def scan(self, xs, arg_order=None, compartments_to_monitor=None, save_state=True):
7+
if compartments_to_monitor is None:
8+
compartments_to_monitor = []
9+
if arg_order is None:
10+
arg_order = list(self.get_required_args())
11+
12+
def _pure(current_state, x):
13+
v = self.pure(current_state, **{key: value for key, value in zip(arg_order, x)})
14+
return v, [v[c.path] for c in compartments_to_monitor]
15+
16+
vals, stacked = _scan(_pure, init=self.get_required_state(include_special_compartments=True), xs=xs)
17+
if save_state:
18+
self.updated_modified_state(vals)
19+
return stacked
20+

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def test_bernoulliCell1():
2424
with Context(name) as ctx:
2525
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
2626

27-
advance_process = (Process()
27+
advance_process = (Process("advance_proc")
2828
>> a.advance_state)
2929
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3030

31-
reset_process = (Process()
31+
reset_process = (Process("reset_proc")
3232
>> a.reset)
3333
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3434

tests/components/input_encoders/test_latencyCell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def test_latencyCell1():
2929
)
3030

3131
## create and compile core simulation commands
32-
advance_process = (Process()
32+
advance_process = (Process("advance_proc")
3333
>> a.advance_state)
3434
ctx.wrap_and_add_command(jit(advance_process.pure), name="advance")
35-
calc_spike_times_process = (Process()
35+
calc_spike_times_process = (Process("calc_sptimes_proc")
3636
>> a.calc_spike_times)
3737
ctx.wrap_and_add_command(jit(calc_spike_times_process.pure), name="calc_spike_times")
38-
reset_process = (Process()
38+
reset_process = (Process("reset_proc")
3939
>> a.reset)
4040
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
4141

tests/components/input_encoders/test_phasorCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def test_phasorCell1():
2424
with Context(name) as ctx:
2525
a = PhasorCell(name="a", n_units=1, target_freq=1000., disable_phasor=True, key=subkeys[0])
2626

27-
advance_process = (Process()
27+
advance_process = (Process("advance_proc")
2828
>> a.advance_state)
2929
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3030

31-
reset_process = (Process()
31+
reset_process = (Process("reset_proc")
3232
>> a.reset)
3333
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3434

tests/components/input_encoders/test_poissonCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def test_poissonCell1():
2424
with Context(name) as ctx:
2525
a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
2626

27-
advance_process = (Process()
27+
advance_process = (Process("advance_proc")
2828
>> a.advance_state)
2929
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3030

31-
reset_process = (Process()
31+
reset_process = (Process("reset_proc")
3232
>> a.reset)
3333
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3434

tests/components/neurons/graded/test_RateCell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def test_RateCell1():
2626
threshold=("none", 0.), integration_type="euler",
2727
batch_size=1, resist_scale=1., shape=None, is_stateful=True
2828
)
29-
advance_process = (Process() >> a.advance_state)
29+
advance_process = (Process("advance_proc") >> a.advance_state)
3030
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31-
reset_process = (Process() >> a.reset)
31+
reset_process = (Process("reset_proc") >> a.reset)
3232
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3333

3434
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")

0 commit comments

Comments
 (0)