Skip to content

Commit 6d9bb3a

Browse files
committed
fix
1 parent 9c686e1 commit 6d9bb3a

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

src/enzyme_ad/jax/gpu.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ struct CuFuncWrapper {
3737
void *func;
3838
};
3939

40-
void noop(void *){};
41-
4240
template <bool withError>
4341
XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) {
4442
assert(call_frame->attrs.size == 1);
@@ -83,7 +81,7 @@ XLA_FFI_Error *initialize(XLA_FFI_CallFrame *call_frame) {
8381
auto *execution_state = reinterpret_cast<xla::ffi::ExecutionState *>(
8482
internal_api->XLA_FFI_INTERNAL_ExecutionState_Get(ctx));
8583
(void)execution_state->Set(xla::ffi::TypeRegistry::GetTypeId<CuFuncWrapper>(),
86-
cufunc, noop);
84+
cufunc);
8785

8886
return nullptr;
8987
}

third_party/jax/workspace.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ def repo(extra_patches = [], override_commit = ""):
3030
strip_prefix = "jax-" + commit,
3131
urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = commit)],
3232
patch_cmds = JAX_PATCHES + extra_patches,
33-
patches = ["//:patches/jax.patch"],
33+
patches = ["//:patches/jax.patch"],
3434
patch_args = ["-p1"],
3535
)

0 commit comments

Comments
 (0)