diff --git a/patches/jax.patch b/patches/jax.patch index 46b176ebf..b34974886 100644 --- a/patches/jax.patch +++ b/patches/jax.patch @@ -1,21 +1,23 @@ +diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD +index 088411939..8a8ae857e 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD -@@ -20,7 +20,7 @@ licenses(["notice"]) - package( - default_applicable_licenses = [], - default_visibility = [ -- "//jax/experimental:mosaic_users", -+ "//visibility:public", +@@ -48,13 +48,11 @@ cc_library( + "dialect/tpu/util.cc", + "dialect/tpu/vreg_util.cc", ], - ) - ---- a/jaxlib/gpu/BUILD -+++ b/jaxlib/gpu/BUILD -@@ -81,6 +81,7 @@ proto_library( - - cc_proto_library( - name = "triton_cc_proto", -+ visibility = "//visibility:public", - compatible_with = None, - deps = [":triton_proto"], - ) +- hdrs = [ +- "dialect/tpu/array_util.h", +- "dialect/tpu/layout.h", +- "dialect/tpu/tpu_dialect.h", +- "dialect/tpu/util.h", +- "dialect/tpu/vreg_util.h", +- ], ++ hdrs = glob( ++ [ ++ "**/*.h", ++ ]) ++ , + # compatible with libtpu + deps = [ + ":tpu_inc_gen", diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 0d57dbd5e..975f166f9 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -1127,7 +1127,7 @@ cc_library( "@enzyme//:EnzymeMLIR", # Mosaic - "@jax//jaxlib/mosaic:tpu_dialect", + # "@jax//jaxlib/mosaic:tpu_dialect", # SHLO "@stablehlo//:stablehlo_ops", @@ -1280,7 +1280,8 @@ cc_library( "@com_google_absl//absl/status:statusor", # Mosaic - "@jax//jaxlib/mosaic:tpu_dialect", + # Upstream is broken, re-enable when working + # "@jax//jaxlib/mosaic:tpu_dialect", ], ) diff --git a/src/enzyme_ad/jax/gpu.cc b/src/enzyme_ad/jax/gpu.cc index 2cf68fc14..68b22868f 100644 --- a/src/enzyme_ad/jax/gpu.cc +++ b/src/enzyme_ad/jax/gpu.cc @@ -37,8 +37,6 @@ struct CuFuncWrapper { void *func; }; -void noop(void *){}; - template XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) { assert(call_frame->attrs.size == 1); @@ -83,7 +81,7 @@ XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) { auto *execution_state = reinterpret_cast( internal_api->XLA_FFI_INTERNAL_ExecutionState_Get(ctx)); (void)execution_state->Set(xla::ffi::TypeRegistry::GetTypeId(), - cufunc, noop); + cufunc); return nullptr; } diff --git a/third_party/jax/workspace.bzl b/third_party/jax/workspace.bzl index 2690dd370..175f9f963 100644 --- a/third_party/jax/workspace.bzl +++ b/third_party/jax/workspace.bzl @@ -30,4 +30,6 @@ def repo(extra_patches = [], override_commit = ""): strip_prefix = "jax-" + commit, urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = commit)], patch_cmds = JAX_PATCHES + extra_patches, + patches = ["//:patches/jax.patch"], + patch_args = ["-p1"], ) diff --git a/workspace.bzl b/workspace.bzl index 67bad561c..9aed429c7 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -1,4 +1,4 @@ -JAX_COMMIT = "24e80c494cb5464794730818cea05b60d7a956d7" +JAX_COMMIT = "d79c1c43fe8c40c3c51743e1796f2d2b43ebfb82" JAX_SHA256 = "" ENZYME_COMMIT = "6b4a73e3c71e6451c919850acf2999ee04daab12" @@ -46,9 +46,9 @@ XLA_PATCHES = [ sed -i.bak0 "s/DCHECK_NE(runtime, nullptr/DCHECK_NE(runtime.get(), nullptr/g" xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc """, # TODO remove - """ - sed -i.bak0 "s/^bool IsSupportedType/static inline bool IsSupportedType/g" xla/backends/cpu/runtime/convolution_lib.cc - """, + #""" + #sed -i.bak0 "s/^bool IsSupportedType/static inline bool IsSupportedType/g" xla/backends/cpu/runtime/convolution_lib.cc + #""", """ sed -i.bak0 "s/Node::Leaf(std::forward/Node::Leaf(std::forward/g" xla/tuple_tree.h """,