Skip to content

Commit cd9f12a

Browse files
committed
Added a JaxProcess
Added Jax Process to allow for scanning over the process.
1 parent 8dbf83a commit cd9f12a

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

ngclearn/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
from ngcsimlib.compartment import Compartment
3737
from ngcsimlib.resolver import resolver
3838
from ngcsimlib import utils as sim_utils
39-
from ngcsimlib.compilers.process import Process, transition
39+
40+
from ngclearn.utils.jaxProcess import JaxProcess
41+
from ngcsimlib.compilers.process import transition, Process
4042

4143

4244
from ngcsimlib import configure, preload_modules

ngclearn/components/base_monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def load(self, directory, **kwargs):
284284
setattr(self, compartment_path, new_comp)
285285
self.compartments.append(new_comp.path)
286286

287-
self._update_resolver()
287+
# self._update_resolver()
288288

289289
def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
290290
vals = self.view(compartment)

ngclearn/utils/jaxProcess.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from ngcsimlib.compilers.process import Process
2+
from jax.lax import scan as _scan
3+
4+
5+
class JaxProcess(Process):
6+
def scan(self, xs, arg_order=None, compartments_to_monitor=None, save_state=True):
7+
if compartments_to_monitor is None:
8+
compartments_to_monitor = []
9+
if arg_order is None:
10+
arg_order = list(self.get_required_args())
11+
12+
def _pure(current_state, x):
13+
v = self.pure(current_state, **{key: value for key, value in zip(arg_order, x)})
14+
return v, [v[c.path] for c in compartments_to_monitor]
15+
16+
vals, stacked = _scan(_pure, init=self.get_required_state(include_special_compartments=True), xs=xs)
17+
if save_state:
18+
self.updated_modified_state(vals)
19+
return stacked
20+

0 commit comments

Comments
 (0)