Skip to content

Commit e45b7b6

Browse files
Re-export AbstractRef from hijax
PiperOrigin-RevId: 834868498
1 parent 22fbb0c commit e45b7b6

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

jax/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,7 @@ pytype_strict_library(
513513
"//jax/_src:core",
514514
"//jax/_src:effects",
515515
"//jax/_src:hijax",
516+
"//jax/_src:lax",
516517
],
517518
)
518519

jax/experimental/hijax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
HiPrimitive as HiPrimitive,
3232
HiType as HiType,
3333
MutableHiType as MutableHiType,
34-
register_hitype as register_hitype,
3534
VJPHiPrimitive as VJPHiPrimitive,
35+
register_hitype as register_hitype,
36+
)
37+
from jax._src.state import (
38+
AbstractRef as AbstractRef,
3639
)

0 commit comments

Comments
 (0)