Skip to content

Commit 5a3fc60

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Deprecate public export of mlir.custom_call.
PiperOrigin-RevId: 744722183
1 parent a099b28 commit 5a3fc60

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
3636
* `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`,
3737
which were accidental exports, have been removed. If needed, they are
3838
available from `jax.extend.mlir`.
39+
* `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
40+
{mod}`jax.ffi` should be used instead.
3941
* Several previously-deprecated APIs have been removed, including:
4042
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
4143
and `shape_from_pyval`.

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@
2424
from jax._src import dispatch
2525
from jax._src.custom_partitioning import custom_partitioning
2626
from jax._src.interpreters import batching
27+
from jax._src.interpreters import mlir
2728
from jax._src.lib import cuda_versions
2829
from jax._src import xla_bridge
29-
from jax.interpreters import mlir
30-
from jax.interpreters import xla
3130
from jax._src.lib.mlir import ir
3231
from jax._src.lib.mlir.dialects import hlo
3332
import jax.numpy as jnp
@@ -1018,7 +1017,7 @@ def sharded_impl(*args):
10181017
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd")
10191018
_dot_product_attention_fwd_p.multiple_results = True
10201019
_dot_product_attention_fwd_p.def_impl(
1021-
functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p)
1020+
functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p)
10221021
)
10231022
_dot_product_attention_fwd_p.def_abstract_eval(
10241023
_dot_product_attention_fwd_abstract
@@ -1043,7 +1042,7 @@ def sharded_impl(*args):
10431042
_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd")
10441043
_dot_product_attention_bwd_p.multiple_results = True
10451044
_dot_product_attention_bwd_p.def_impl(
1046-
functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p)
1045+
functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p)
10471046
)
10481047
_dot_product_attention_bwd_p.def_abstract_eval(
10491048
_dot_product_attention_bwd_abstract
@@ -1604,7 +1603,7 @@ def _dot_product_attention_fp8_bwd_partition(
16041603
_dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd")
16051604
_dot_product_attention_fp8_fwd_p.multiple_results = True
16061605
_dot_product_attention_fp8_fwd_p.def_impl(
1607-
functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p)
1606+
functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p)
16081607
)
16091608
_dot_product_attention_fp8_fwd_p.def_abstract_eval(
16101609
_dot_product_attention_fp8_fwd_abstract
@@ -1629,7 +1628,7 @@ def _dot_product_attention_fp8_bwd_partition(
16291628
_dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd")
16301629
_dot_product_attention_fp8_bwd_p.multiple_results = True
16311630
_dot_product_attention_fp8_bwd_p.def_impl(
1632-
functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p)
1631+
functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p)
16331632
)
16341633
_dot_product_attention_fp8_bwd_p.def_abstract_eval(
16351634
_dot_product_attention_fp8_bwd_abstract

jax/interpreters/mlir.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
aval_to_ir_type as aval_to_ir_type,
3434
aval_to_ir_types as aval_to_ir_types,
3535
core_call_lowering as core_call_lowering,
36-
custom_call as custom_call,
36+
custom_call as _custom_call,
3737
dense_bool_elements as dense_bool_elements,
3838
dense_bool_array as dense_bool_array,
3939
dense_int_array as dense_int_array,
@@ -77,3 +77,23 @@
7777
from jax._src.callback import (
7878
emit_python_callback as emit_python_callback,
7979
)
80+
81+
_deprecations = {
82+
# Added Apr 7 2025
83+
"custom_call": (
84+
"mlir.custom_call is deprecated; use the APIs provided by jax.ffi instead.",
85+
_custom_call,
86+
)
87+
}
88+
89+
import typing as _typing
90+
91+
if _typing.TYPE_CHECKING:
92+
custom_call = _custom_call
93+
else:
94+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
95+
96+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
97+
del _deprecation_getattr
98+
del _typing
99+
del _custom_call

0 commit comments

Comments
 (0)