Skip to content

Commit b9cff51

Browse files
Further enzymexla bump (#785)
* Further enzymexla bump * more enzymejax bump * Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9ea612f commit b9cff51

File tree

3 files changed

+7
-13
lines changed

3 files changed

+7
-13
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
#include "xla/pjrt/pjrt_api.h"
7373
#include "xla/pjrt/pjrt_c_api_client.h"
7474
#include "xla/pjrt/pjrt_executable.h"
75-
#include "xla/pjrt/status_casters.h"
7675

7776
// shardy
7877
#include "shardy/dialect/sdy/ir/dialect.h"
@@ -180,15 +179,10 @@ extern "C" void (*ReactantThrowError)(const char *) = nullptr;
180179

181180
// Utilities for `StatusOr`.
182181
template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
183-
if (ReactantThrowError) {
184182
if (!v.ok()) {
185183
ReactantThrowError(v.status().ToString().c_str());
186-
throw xla::XlaRuntimeError(v.status().ToString().c_str());
187184
}
188185
return std::move(v).value();
189-
} else {
190-
return xla::ValueOrThrow(std::move(v));
191-
}
192186
}
193187

194188
extern "C" void ReactantHandleCuResult(uint32_t curesult) {

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "467d1f32b5747f9d9ab7b2314eafc011e6dd7b5b"
12+
ENZYMEXLA_COMMIT = "e748ca63ea8bd3b33354c43d289b253386189a0b"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

src/Compiler.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -644,16 +644,16 @@ function compile_mlir!(
644644
jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
645645
end
646646

647+
opt_passes = optimization_passes(; no_nan, sroa=true)
648+
opt_passes2 = optimization_passes(; no_nan, sroa=false)
649+
647650
raise = if Raise[]
648-
# "llvm-to-memref-access" # ,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize"
649-
"canonicalize"
651+
"canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," *
652+
opt_passes2
650653
else
651654
"canonicalize"
652655
end
653656

654-
opt_passes = optimization_passes(; no_nan, sroa=true)
655-
opt_passes2 = optimization_passes(; no_nan, sroa=false)
656-
657657
if optimize === :all
658658
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
659659
run_pass_pipeline!(
@@ -669,7 +669,7 @@ function compile_mlir!(
669669
opt_passes2,
670670
kern,
671671
raise,
672-
jit,
672+
jit
673673
],
674674
',',
675675
),

0 commit comments

Comments
 (0)