Skip to content

Commit 058bf90

Browse files
committed
JaxProcess update
1 parent aa3b52e commit 058bf90

File tree

2 files changed

+41
-171
lines changed

2 files changed

+41
-171
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from ngcsimlib import JointProcess, MethodProcess
2+
from ngcsimlib.global_state import stateManager
3+
import jax
4+
from typing import TYPE_CHECKING
5+
if TYPE_CHECKING:
6+
from ngcsimlib._src.process.baseProcess import BaseProcess
7+
8+
class JaxProcessesMixin:
9+
def __init__(self: "BaseProcess"):
10+
self._previous_result = None
11+
self._previous_state = None
12+
13+
@property
14+
def previous_result(self):
15+
return self._previous_result
16+
17+
@property
18+
def previous_state(self):
19+
return self._previous_state
20+
21+
def clear(self):
22+
self._previous_result = None
23+
self._previous_state = None
24+
25+
26+
def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True):
27+
state = current_state or stateManager.state
28+
final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
29+
if save_state:
30+
self._previous_state = final_state
31+
if store_results:
32+
self._previous_result = result
33+
return final_state, result
34+
35+
36+
37+
class JaxJointProcess(JointProcess, JaxProcessesMixin):
38+
pass
39+
40+
class JaxMethodProcess(MethodProcess, JaxProcessesMixin):
41+
pass

ngclearn/utils/jaxProcess.py

Lines changed: 0 additions & 171 deletions
This file was deleted.

0 commit comments

Comments
 (0)