# Internal JAX imports. Note: `mlir` must come from `jax._src.interpreters`
# (the public `jax.interpreters.mlir` / `jax.interpreters.xla` aliases were
# removed in this change; eager dispatch now goes through `dispatch`).
from jax._src import dispatch
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax._src import xla_bridge
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import jax.numpy as jnp
# (diff context elided here: body of `def sharded_impl(*args)` and other code
#  between this point and the previous visible span is not shown)

# Forward primitive for cuDNN fused dot-product attention.
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd")
_dot_product_attention_fwd_p.multiple_results = True
# Eager mode compiles and runs the primitive via XLA. `dispatch.apply_primitive`
# is the `jax._src` replacement for the removed `jax.interpreters.xla.apply_primitive`.
_dot_product_attention_fwd_p.def_impl(
    functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p)
)
_dot_product_attention_fwd_p.def_abstract_eval(
    _dot_product_attention_fwd_abstract
)
# (diff context elided here: intervening lines between hunks are not shown)

# Backward (gradient) primitive for cuDNN fused dot-product attention.
_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd")
_dot_product_attention_bwd_p.multiple_results = True
# Eager fallback through XLA; `dispatch.apply_primitive` replaces the removed
# `jax.interpreters.xla.apply_primitive`.
_dot_product_attention_bwd_p.def_impl(
    functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p)
)
_dot_product_attention_bwd_p.def_abstract_eval(
    _dot_product_attention_bwd_abstract
)
# (diff context elided here: `_dot_product_attention_fp8_bwd_partition` and
#  surrounding code between hunks is not shown)

# Forward primitive for the FP8 variant of cuDNN fused dot-product attention.
_dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd")
_dot_product_attention_fp8_fwd_p.multiple_results = True
# Eager fallback through XLA; `dispatch.apply_primitive` replaces the removed
# `jax.interpreters.xla.apply_primitive`.
_dot_product_attention_fp8_fwd_p.def_impl(
    functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p)
)
_dot_product_attention_fp8_fwd_p.def_abstract_eval(
    _dot_product_attention_fp8_fwd_abstract
)
# (diff context elided here: intervening lines between hunks are not shown)

# Backward (gradient) primitive for the FP8 variant of cuDNN fused
# dot-product attention.
_dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd")
_dot_product_attention_fp8_bwd_p.multiple_results = True
# Eager fallback through XLA; `dispatch.apply_primitive` replaces the removed
# `jax.interpreters.xla.apply_primitive`.
_dot_product_attention_fp8_bwd_p.def_impl(
    functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p)
)
_dot_product_attention_fp8_bwd_p.def_abstract_eval(
    _dot_product_attention_fp8_bwd_abstract
)
# NOTE(review): "0 commit comments" is a GitHub page footer captured by the
# extraction — it is not part of the source file.