1+ from ngcsimlib .compartment import Compartment
12from ngcsimlib .compilers .process import Process
23from jax .lax import scan as _scan
34from ngcsimlib .logger import warn
45from jax import numpy as jnp
56
7+
68class JaxProcess (Process ):
79 """
810 The JaxProcess is a subclass of the ngcsimlib Process class. The
911 functionality added by this subclass is the use of the jax scanner to run a
1012 process quickly through the use of jax's JIT compiler.
1113 """
12- def scan (self , compartments_to_monitor = None ,
13- save_state = True , scan_length = None , ** kwargs ):
14+
15+ def __init__ (self , name ):
16+ super ().__init__ (name )
17+ self ._process_scan_method = None
18+ self ._monitoring = []
19+
20+ def _make_scanner (self ):
21+ arg_order = self .get_required_args ()
22+
23+ def _pure (current_state , x ):
24+ v = self .pure (current_state ,
25+ ** {key : value for key , value in zip (arg_order , x )})
26+ return v , [v [m ] for m in self ._monitoring ]
27+
28+ return _pure
29+
30+ def watch (self , compartment ):
31+ """
32+ Adds a compartment to the process to watch during a scan
33+ Args:
34+ compartment: The compartment to watch
35+ """
36+ if not isinstance (compartment , Compartment ):
37+ warn (
38+ "Jax Process trying to watch a value that is not a compartment" )
39+
40+ self ._monitoring .append (compartment .path )
41+ self ._process_scan_method = self ._make_scanner ()
42+
43+ def clear_watch_list (self ):
44+ """
45+ Clears the watch list so no values are watched
46+ """
47+ self ._monitoring = []
48+ self ._process_scan_method = self ._make_scanner ()
49+
50+ def transition (self , transition_call ):
51+ """
52+ Appends to the base transition call to create pure method for use by its
53+ scanner
54+ Args:
55+ transition_call: the transition being passed into the default
56+ process
57+
58+ Returns: this JaxProcess instance for chaining
59+
60+ """
61+ super ().transition (transition_call )
62+ self ._process_scan_method = self ._make_scanner ()
63+ return self
64+
65+ def scan (self , save_state = True , scan_length = None , ** kwargs ):
1466 """
1567 There a quite a few ways to initialize the scan method for the
16- jaxProcess. To start the straight forward arguments are
17- "compartments_to_monitor" and "save_state". Monitoring compartments
18- means at the end of each process cycle record the value of each
19- compartment in the list and then at the end a tuple of concatenated
20- values will be returned that correspond to each compartment in the
21- original list. The save_state flag is simply there to note if the state
68+ jaxProcess. To start the straight forward arguments is "save_state".
69+ The save_state flag is simply there to note if the state
2270 of the model should reflect the final state of the model after the scan
2371 is complete.
2472
73+ This scan method can also watch and report intermediate compartment
74+ values defined through calling the JaxProcess.watch() method watching a
75+ compartment means at the end of each process cycle record the value of
76+ the compartment and then at the end a tuple of concatenated values will
77+ be returned that correspond to each compartment the process is watching.
78+
2579 Where there are options for the arguments is when defining the keyword
2680 arguments for the process. The process will do its best to broadcast all
2781 the inputs to the largest size, so they can be scanned over. This means
@@ -39,7 +93,6 @@ def scan(self, compartments_to_monitor=None,
3993
4094
4195 Args:
42- compartments_to_monitor: A list of compartments to monitor
4396 save_state: A boolean flag to indicate if the model state should be
4497 saved
4598 scan_length: a value to be used to denote the number of iterations
@@ -49,8 +102,6 @@ def scan(self, compartments_to_monitor=None,
49102 Returns: the final state of the model, the stacked output of the scan method
50103
51104 """
52- if compartments_to_monitor is None :
53- compartments_to_monitor = []
54105 arg_order = list (self .get_required_args ())
55106
56107 args = []
@@ -91,28 +142,28 @@ def scan(self, compartments_to_monitor=None,
91142 max_next_axis = 0
92143 new_args = []
93144 for a in args :
94- if len (a .shape ) >= axis + 1 :
145+ if len (a .shape ) >= axis + 1 :
95146 if a .shape [axis ] == current_axis :
96147 new_args .append (a )
97148 else :
98149 warn ("Keyword arguments must all be able to be "
99150 "broadcasted to the largest shape" )
100151 return
101152 else :
102- new_args .append (jnp .zeros (list (a .shape ) + [current_axis ], dtype = a .dtype ) + a .reshape (* a .shape , 1 ))
153+ new_args .append (jnp .zeros (list (a .shape ) + [current_axis ],
154+ dtype = a .dtype ) + a .reshape (
155+ * a .shape , 1 ))
103156
104- if len (a .shape ) > axis + 1 :
105- max_next_axis = max (max_next_axis , a .shape [axis + 1 ])
157+ if len (a .shape ) > axis + 1 :
158+ max_next_axis = max (max_next_axis , a .shape [axis + 1 ])
106159
107160 args = new_args
108161
109- args = jnp .array (args ).transpose ([1 , 0 ] + [i for i in range (2 , max_axis + 1 )])
110-
111- def _pure (current_state , x ):
112- v = self .pure (current_state , ** {key : value for key , value in zip (arg_order , x )})
113- return v , [v [c .path ] for c in compartments_to_monitor ]
114-
115- vals , stacked = _scan (_pure , init = self .get_required_state (include_special_compartments = True ), xs = args )
162+ args = jnp .array (args ).transpose (
163+ [1 , 0 ] + [i for i in range (2 , max_axis + 1 )])
164+ state , stacked = _scan (self ._process_scan_method ,
165+ init = self .get_required_state (
166+ include_special_compartments = True ), xs = args )
116167 if save_state :
117- self .updated_modified_state (vals )
118- return vals , stacked
168+ self .updated_modified_state (state )
169+ return state , stacked
0 commit comments