Skip to content

Commit 54ac172

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass. PiperOrigin-RevId: 720371033
1 parent bc130c7 commit 54ac172

File tree

7 files changed

+519
-309
lines changed

7 files changed

+519
-309
lines changed

jax/_src/pallas/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ py_library(
3232
"core.py",
3333
"cost_estimate.py",
3434
"helpers.py",
35+
"hlo_interpreter.py",
3536
"pallas_call.py",
3637
"primitives.py",
3738
"utils.py",

jax/_src/pallas/core.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -324,22 +324,6 @@ def current_grid_env() -> GridEnv | None:
324324
return _pallas_tracing_env.grid_env_stack[-1]
325325

326326

327-
@contextlib.contextmanager
328-
def interpret_mode_env(interpret_mode: bool) -> Iterator[None]:
329-
prev_interpret = _pallas_tracing_env.is_interpret_mode
330-
if interpret_mode:
331-
_pallas_tracing_env.is_interpret_mode = True
332-
try:
333-
yield
334-
finally:
335-
if interpret_mode:
336-
_pallas_tracing_env.is_interpret_mode = prev_interpret
337-
338-
def is_interpret_mode() -> bool:
339-
"""Returns whether the kernel is executing in interpret mode."""
340-
return _pallas_tracing_env.is_interpret_mode
341-
342-
343327
class Mapped:
344328
"""Used as a block shape dimension to denote a mapped dimension.
345329
A mapped dimension behaves like `1` except it is squeezed from the block.

0 commit comments

Comments
 (0)