Skip to content

Commit 5396eb3

Browse files
author
Alexander Ororbia
committed
Merge branch 'main' of github.com:NACLab/ngc-learn
2 parents 49bccb5 + 9fd17f1 commit 5396eb3

File tree

1 file changed

+75
-24
lines changed

1 file changed

+75
-24
lines changed

ngclearn/utils/jaxProcess.py

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,81 @@
1+
from ngcsimlib.compartment import Compartment
12
from ngcsimlib.compilers.process import Process
23
from jax.lax import scan as _scan
34
from ngcsimlib.logger import warn
45
from jax import numpy as jnp
56

7+
68
class JaxProcess(Process):
79
"""
810
The JaxProcess is a subclass of the ngcsimlib Process class. The
911
functionality added by this subclass is the use of the jax scanner to run a
1012
process quickly through the use of jax's JIT compiler.
1113
"""
12-
def scan(self, compartments_to_monitor=None,
13-
save_state=True, scan_length=None, **kwargs):
14+
15+
def __init__(self, name):
16+
super().__init__(name)
17+
self._process_scan_method = None
18+
self._monitoring = []
19+
20+
def _make_scanner(self):
21+
arg_order = self.get_required_args()
22+
23+
def _pure(current_state, x):
24+
v = self.pure(current_state,
25+
**{key: value for key, value in zip(arg_order, x)})
26+
return v, [v[m] for m in self._monitoring]
27+
28+
return _pure
29+
30+
def watch(self, compartment):
31+
"""
32+
Adds a compartment to the process to watch during a scan
33+
Args:
34+
compartment: The compartment to watch
35+
"""
36+
if not isinstance(compartment, Compartment):
37+
warn(
38+
"Jax Process trying to watch a value that is not a compartment")
39+
40+
self._monitoring.append(compartment.path)
41+
self._process_scan_method = self._make_scanner()
42+
43+
def clear_watch_list(self):
44+
"""
45+
Clears the watch list so no values are watched
46+
"""
47+
self._monitoring = []
48+
self._process_scan_method = self._make_scanner()
49+
50+
def transition(self, transition_call):
51+
"""
52+
Appends to the base transition call to create pure method for use by its
53+
scanner
54+
Args:
55+
transition_call: the transition being passed into the default
56+
process
57+
58+
Returns: this JaxProcess instance for chaining
59+
60+
"""
61+
super().transition(transition_call)
62+
self._process_scan_method = self._make_scanner()
63+
return self
64+
65+
def scan(self, save_state=True, scan_length=None, **kwargs):
1466
"""
1567
There a quite a few ways to initialize the scan method for the
16-
jaxProcess. To start the straight forward arguments are
17-
"compartments_to_monitor" and "save_state". Monitoring compartments
18-
means at the end of each process cycle record the value of each
19-
compartment in the list and then at the end a tuple of concatenated
20-
values will be returned that correspond to each compartment in the
21-
original list. The save_state flag is simply there to note if the state
68+
jaxProcess. To start the straight forward arguments is "save_state".
69+
The save_state flag is simply there to note if the state
2270
of the model should reflect the final state of the model after the scan
2371
is complete.
2472
73+
This scan method can also watch and report intermediate compartment
74+
values defined through calling the JaxProcess.watch() method watching a
75+
compartment means at the end of each process cycle record the value of
76+
the compartment and then at the end a tuple of concatenated values will
77+
be returned that correspond to each compartment the process is watching.
78+
2579
Where there are options for the arguments is when defining the keyword
2680
arguments for the process. The process will do its best to broadcast all
2781
the inputs to the largest size, so they can be scanned over. This means
@@ -39,7 +93,6 @@ def scan(self, compartments_to_monitor=None,
3993
4094
4195
Args:
42-
compartments_to_monitor: A list of compartments to monitor
4396
save_state: A boolean flag to indicate if the model state should be
4497
saved
4598
scan_length: a value to be used to denote the number of iterations
@@ -49,8 +102,6 @@ def scan(self, compartments_to_monitor=None,
49102
Returns: the final state of the model, the stacked output of the scan method
50103
51104
"""
52-
if compartments_to_monitor is None:
53-
compartments_to_monitor = []
54105
arg_order = list(self.get_required_args())
55106

56107
args = []
@@ -91,28 +142,28 @@ def scan(self, compartments_to_monitor=None,
91142
max_next_axis = 0
92143
new_args = []
93144
for a in args:
94-
if len(a.shape) >= axis+1:
145+
if len(a.shape) >= axis + 1:
95146
if a.shape[axis] == current_axis:
96147
new_args.append(a)
97148
else:
98149
warn("Keyword arguments must all be able to be "
99150
"broadcasted to the largest shape")
100151
return
101152
else:
102-
new_args.append(jnp.zeros(list(a.shape) + [current_axis], dtype=a.dtype) + a.reshape(*a.shape, 1))
153+
new_args.append(jnp.zeros(list(a.shape) + [current_axis],
154+
dtype=a.dtype) + a.reshape(
155+
*a.shape, 1))
103156

104-
if len(a.shape) > axis+1:
105-
max_next_axis = max(max_next_axis, a.shape[axis+1])
157+
if len(a.shape) > axis + 1:
158+
max_next_axis = max(max_next_axis, a.shape[axis + 1])
106159

107160
args = new_args
108161

109-
args = jnp.array(args).transpose([1, 0] + [i for i in range(2, max_axis+1)])
110-
111-
def _pure(current_state, x):
112-
v = self.pure(current_state, **{key: value for key, value in zip(arg_order, x)})
113-
return v, [v[c.path] for c in compartments_to_monitor]
114-
115-
vals, stacked = _scan(_pure, init=self.get_required_state(include_special_compartments=True), xs=args)
162+
args = jnp.array(args).transpose(
163+
[1, 0] + [i for i in range(2, max_axis + 1)])
164+
state, stacked = _scan(self._process_scan_method,
165+
init=self.get_required_state(
166+
include_special_compartments=True), xs=args)
116167
if save_state:
117-
self.updated_modified_state(vals)
118-
return vals, stacked
168+
self.updated_modified_state(state)
169+
return state, stacked

0 commit comments

Comments
 (0)