Skip to content

Commit 2118c23

Browse files
Update jax-ml/jax to commit d79c1c43fe8c40c3c51743e1796f2d2b43ebfb82 (#1611)
* Update jax-ml/jax to commit d79c1c43fe8c40c3c51743e1796f2d2b43ebfb82 Diff: jax-ml/jax@24e80c4...d79c1c4 * tmp * fix * fix * fix * exclude * fix * fix * fix * fmt * Fix --------- Co-authored-by: enzymead-bot[bot] <238314553+enzymead-bot[bot]@users.noreply.github.com> Co-authored-by: William S. Moses <[email protected]>
1 parent bb0db13 commit 2118c23

File tree

5 files changed

+30
-27
lines changed

5 files changed

+30
-27
lines changed

patches/jax.patch

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1+
diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD
2+
index 088411939..8a8ae857e 100644
13
--- a/jaxlib/mosaic/BUILD
24
+++ b/jaxlib/mosaic/BUILD
3-
@@ -20,7 +20,7 @@ licenses(["notice"])
4-
package(
5-
default_applicable_licenses = [],
6-
default_visibility = [
7-
- "//jax/experimental:mosaic_users",
8-
+ "//visibility:public",
5+
@@ -48,13 +48,11 @@ cc_library(
6+
"dialect/tpu/util.cc",
7+
"dialect/tpu/vreg_util.cc",
98
],
10-
)
11-
12-
--- a/jaxlib/gpu/BUILD
13-
+++ b/jaxlib/gpu/BUILD
14-
@@ -81,6 +81,7 @@ proto_library(
15-
16-
cc_proto_library(
17-
name = "triton_cc_proto",
18-
+ visibility = "//visibility:public",
19-
compatible_with = None,
20-
deps = [":triton_proto"],
21-
)
9+
- hdrs = [
10+
- "dialect/tpu/array_util.h",
11+
- "dialect/tpu/layout.h",
12+
- "dialect/tpu/tpu_dialect.h",
13+
- "dialect/tpu/util.h",
14+
- "dialect/tpu/vreg_util.h",
15+
- ],
16+
+ hdrs = glob(
17+
+ [
18+
+ "**/*.h",
19+
+ ])
20+
+ ,
21+
# compatible with libtpu
22+
deps = [
23+
":tpu_inc_gen",

src/enzyme_ad/jax/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ cc_library(
11271127
"@enzyme//:EnzymeMLIR",
11281128

11291129
# Mosaic
1130-
"@jax//jaxlib/mosaic:tpu_dialect",
1130+
# "@jax//jaxlib/mosaic:tpu_dialect",
11311131

11321132
# SHLO
11331133
"@stablehlo//:stablehlo_ops",
@@ -1280,7 +1280,8 @@ cc_library(
12801280
"@com_google_absl//absl/status:statusor",
12811281

12821282
# Mosaic
1283-
"@jax//jaxlib/mosaic:tpu_dialect",
1283+
# Upstream is broken, re-enable when working
1284+
# "@jax//jaxlib/mosaic:tpu_dialect",
12841285
],
12851286
)
12861287

src/enzyme_ad/jax/gpu.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ struct CuFuncWrapper {
3737
void *func;
3838
};
3939

40-
void noop(void *){};
41-
4240
template <bool withError>
4341
XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) {
4442
assert(call_frame->attrs.size == 1);
@@ -83,7 +81,7 @@ XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) {
8381
auto *execution_state = reinterpret_cast<xla::ffi::ExecutionState *>(
8482
internal_api->XLA_FFI_INTERNAL_ExecutionState_Get(ctx));
8583
(void)execution_state->Set(xla::ffi::TypeRegistry::GetTypeId<CuFuncWrapper>(),
86-
cufunc, noop);
84+
cufunc);
8785

8886
return nullptr;
8987
}

third_party/jax/workspace.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,6 @@ def repo(extra_patches = [], override_commit = ""):
3030
strip_prefix = "jax-" + commit,
3131
urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = commit)],
3232
patch_cmds = JAX_PATCHES + extra_patches,
33+
patches = ["//:patches/jax.patch"],
34+
patch_args = ["-p1"],
3335
)

workspace.bzl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
JAX_COMMIT = "24e80c494cb5464794730818cea05b60d7a956d7"
1+
JAX_COMMIT = "d79c1c43fe8c40c3c51743e1796f2d2b43ebfb82"
22
JAX_SHA256 = ""
33

44
ENZYME_COMMIT = "6b4a73e3c71e6451c919850acf2999ee04daab12"
@@ -46,9 +46,9 @@ XLA_PATCHES = [
4646
sed -i.bak0 "s/DCHECK_NE(runtime, nullptr/DCHECK_NE(runtime.get(), nullptr/g" xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc
4747
""",
4848
# TODO remove
49-
"""
50-
sed -i.bak0 "s/^bool IsSupportedType/static inline bool IsSupportedType/g" xla/backends/cpu/runtime/convolution_lib.cc
51-
""",
49+
#"""
50+
#sed -i.bak0 "s/^bool IsSupportedType/static inline bool IsSupportedType/g" xla/backends/cpu/runtime/convolution_lib.cc
51+
#""",
5252
"""
5353
sed -i.bak0 "s/Node::Leaf(std::forward<decltype(value)>/Node::Leaf(std::forward<T>/g" xla/tuple_tree.h
5454
""",

0 commit comments

Comments
 (0)