Skip to content

Commit 70485e3

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
These are available via jax.extend.mlir.dialects. No deprecation period because jax.interpreters.mlir is not a stable API. PiperOrigin-RevId: 744712537
1 parent 83572e1 commit 70485e3

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
3333
* Implemented host callback handlers for CPU and GPU devices using XLA's FFI
3434
and removed existing CPU/GPU handlers using XLA's custom call.
3535
* All APIs in `jax.lib.xla_extension` are now deprecated.
36+
* `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`,
37+
which were accidental exports, have been removed. If needed, they are
38+
available from `jax.extend.mlir`.
3639
* Several previously-deprecated APIs have been removed, including:
3740
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
3841
and `shape_from_pyval`.

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from jax._src import xla_bridge
2929
from jax.interpreters import mlir
3030
from jax.interpreters import xla
31-
from jax.interpreters.mlir import hlo
32-
from jax.interpreters.mlir import ir
31+
from jax._src.lib.mlir import ir
32+
from jax._src.lib.mlir.dialects import hlo
3333
import jax.numpy as jnp
3434
from jax.sharding import NamedSharding, PartitionSpec
3535

jax/_src/cudnn/fusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import jax
1717
from jax._src import core as jax_core
1818
from jax.interpreters import mlir
19-
from jax.interpreters.mlir import hlo
20-
from jax.interpreters.mlir import ir
19+
from jax._src.lib.mlir import ir
20+
from jax._src.lib.mlir.dialects import hlo
2121

2222

2323

jax/interpreters/mlir.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@
4343
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401
4444
flatten_ir_values as flatten_ir_values,
4545
unflatten_ir_values_like_types as unflatten_ir_values_like_types,
46-
func_dialect as func_dialect,
47-
hlo as hlo,
4846
i32_attr as i32_attr,
4947
i64_attr as i64_attr,
5048
ir as ir,

0 commit comments

Comments
 (0)