Skip to content

Commit 68483b8

Browse files
Merge pull request jax-ml#25710 from apaszke:mgpu_dialect_fix
PiperOrigin-RevId: 711430610
2 parents ac817b4 + 6443343 commit 68483b8

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

jaxlib/mlir/_mlir_libs/BUILD.bazel

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,10 @@ py_extension(
158158
copts = COPTS,
159159
linkopts = LINKOPTS,
160160
deps = [
161-
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
161+
":jaxlib_mlir_capi_shared_library",
162+
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers",
162163
"@llvm-project//mlir:CAPIIRHeaders",
163-
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
164+
"@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps",
164165
"@nanobind",
165166
],
166167
)
@@ -380,6 +381,7 @@ cc_library(
380381
name = "jaxlib_mlir_capi_objects",
381382
deps = [
382383
"//jaxlib/mosaic:tpu_dialect_capi_objects",
384+
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
383385
"@llvm-project//mlir:CAPIArithObjects",
384386
"@llvm-project//mlir:CAPIGPUObjects",
385387
"@llvm-project//mlir:CAPIIRObjects",

jaxlib/mosaic/dialect/gpu/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,26 @@ cc_library(
215215
"@llvm-project//mlir:CAPIIR",
216216
],
217217
)
218+
219+
# Header-only target, used when using the C API from a separate shared library.
220+
cc_library(
221+
name = "gpu_dialect_capi_headers",
222+
hdrs = DIALECT_CAPI_HEADERS,
223+
deps = [
224+
":mosaic_gpu_inc_gen",
225+
"@llvm-project//mlir:CAPIIRHeaders",
226+
],
227+
)
228+
229+
# Alwayslink target, used when exporting the C API from a shared library.
230+
cc_library(
231+
name = "gpu_dialect_capi_objects",
232+
srcs = DIALECT_CAPI_SOURCES,
233+
hdrs = DIALECT_CAPI_HEADERS,
234+
deps = [
235+
":mosaic_gpu",
236+
":mosaic_gpu_inc_gen",
237+
"@llvm-project//mlir:CAPIIRObjects",
238+
],
239+
alwayslink = True,
240+
)

jaxlib/mosaic/python/mosaic_gpu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@
3333
from mlir.dialects._ods_common import _cext # type: ignore[import-not-found]
3434

3535

36-
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
36+
# Add the parent module to the search prefix
37+
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])

0 commit comments

Comments
 (0)