|
| 1 | +from ngcsimlib import JointProcess, MethodProcess |
| 2 | +from ngcsimlib.global_state import stateManager |
| 3 | +import jax |
| 4 | +from typing import TYPE_CHECKING |
| 5 | +if TYPE_CHECKING: |
| 6 | + from ngcsimlib._src.process.baseProcess import BaseProcess |
| 7 | + |
| 8 | +class JaxProcessesMixin: |
| 9 | + def __init__(self: "BaseProcess"): |
| 10 | + self._previous_result = None |
| 11 | + self._previous_state = None |
| 12 | + |
| 13 | + @property |
| 14 | + def previous_result(self): |
| 15 | + return self._previous_result |
| 16 | + |
| 17 | + @property |
| 18 | + def previous_state(self): |
| 19 | + return self._previous_state |
| 20 | + |
| 21 | + def clear(self): |
| 22 | + self._previous_result = None |
| 23 | + self._previous_state = None |
| 24 | + |
| 25 | + |
| 26 | + def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True): |
| 27 | + state = current_state or stateManager.state |
| 28 | + final_state, result = jax.lax.scan(self.run.compiled, state, inputs) |
| 29 | + if save_state: |
| 30 | + self._previous_state = final_state |
| 31 | + if store_results: |
| 32 | + self._previous_result = result |
| 33 | + return final_state, result |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | +class JaxJointProcess(JointProcess, JaxProcessesMixin): |
| 38 | + pass |
| 39 | + |
| 40 | +class JaxMethodProcess(MethodProcess, JaxProcessesMixin): |
| 41 | + pass |
0 commit comments