Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
d653c5a
Update llvm, 2025 Q3
paul0403 Jul 17, 2025
bdf648c
Merge remote-tracking branch 'origin/main' into llvm_update_2025Q3
paul0403 Jul 22, 2025
8e6f2e6
patch mhlo std::sort bug
paul0403 Jul 22, 2025
c104695
update enzyme version
paul0403 Jul 22, 2025
f3b4f65
GreedyRewriteConfig.stuff = blah -> GreedyRewriteConfig.setStuff(blah)
paul0403 Jul 22, 2025
b76c38b
bufferization.to_memref -> bufferization.to_buffer
paul0403 Jul 22, 2025
e1ecd75
gep inbounds update
paul0403 Jul 22, 2025
cab23c8
arith.contant[int, float]op builder arg order swap
paul0403 Jul 22, 2025
06e9c28
`lookupOrCreateFn` takes in builder
paul0403 Jul 22, 2025
e44f5b3
gep inbounds flag seems to not need `()`
paul0403 Jul 22, 2025
cc077c1
bufferization.to_tensor op now needs the result type to be explicit
paul0403 Jul 22, 2025
b214802
two missed builder passing into `lookupOrCreateFn`
paul0403 Jul 22, 2025
8919702
things with bufferization
paul0403 Jul 22, 2025
d4d2548
update creations of GEPOp to use LLVM::GEPNoWrapFlags::inbounds
paul0403 Jul 22, 2025
6ac0a96
getStridedElementPtr taking in rewriter as first instead of last arg
paul0403 Jul 22, 2025
f6832cf
`catalyst::convertToDestinationPassingStyle` returns `LogicalResult`
paul0403 Jul 22, 2025
6f17313
propagate logical result for funcop insert/eraseArgument in a few mor…
paul0403 Jul 22, 2025
4b70caf
`getBackwardSlice()` now returns logical result
paul0403 Jul 22, 2025
85202c8
lit test. "to_memref" -> "to_buffer"
paul0403 Jul 22, 2025
05eb3de
a few more to_memref->to_buffer lit tests
paul0403 Jul 22, 2025
dc53788
pipeline printer now wraps lines at round brakets
paul0403 Jul 22, 2025
7c65d90
small fix for a warnings as errors in CI
paul0403 Jul 22, 2025
9325efb
apply_registered_pass op now takes in dict as options instead of strings
paul0403 Jul 23, 2025
4b08c5a
fix a test
paul0403 Jul 23, 2025
dba5a26
change a lit test for new dict option format
paul0403 Jul 23, 2025
3a8fdce
do not nest attr helper under a new context
paul0403 Jul 23, 2025
5d1a132
another pytest needs a manual ir.Context (the test for pass options)
paul0403 Jul 23, 2025
3a63e65
try suppressing nanobind warnings in CI for `make dialects`
paul0403 Jul 23, 2025
ca34f47
revert makefile nb suppress warnings
paul0403 Jul 23, 2025
bb4db28
ignore nanobind warnings in the dialects' python bindings' cmake
paul0403 Jul 23, 2025
edaf215
Merge remote-tracking branch 'origin/main' into llvm_update_2025Q3
paul0403 Jul 24, 2025
114eb0d
patch mlir bufferization segfault
paul0403 Jul 24, 2025
29069d6
changelog
paul0403 Jul 24, 2025
ec19eb9
update llvm patch to deal with nullptr deference instead of revert
paul0403 Jul 24, 2025
513f644
is CI not patching??
paul0403 Jul 24, 2025
8205845
revert patch
paul0403 Jul 25, 2025
eaff36b
use python match-case
paul0403 Jul 25, 2025
eb82dc4
Merge remote-tracking branch 'origin/main' into llvm_update_2025Q3
paul0403 Jul 25, 2025
9938f42
Merge remote-tracking branch 'origin/main' into llvm_update_2025Q3
paul0403 Jul 26, 2025
b93b059
codefactor complex method
paul0403 Jul 26, 2025
60814fd
codecov
paul0403 Jul 26, 2025
89ad730
pylint again
paul0403 Jul 26, 2025
cb48513
patch in wheels scripts
paul0403 Jul 26, 2025
96906d7
forgot a patch in wheels
paul0403 Jul 26, 2025
2038103
unify enzyme patch to also use git apply
paul0403 Jul 28, 2025
522bb02
Update frontend/catalyst/passes/pass_api.py
paul0403 Jul 28, 2025
d82d7ac
check int inside non overflowing range
paul0403 Jul 28, 2025
0b17c6c
remove unitattr from util
paul0403 Jul 28, 2025
dc072e1
remove patch checks in wheels scripts
paul0403 Jul 28, 2025
81b0ab6
mlir::sort
paul0403 Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
enzyme=v0.0.180
mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
enzyme=v0.0.186

# Always remove custom PL/LQ versions before release.

Expand Down
2 changes: 1 addition & 1 deletion mlir/Enzyme
Submodule Enzyme updated 82 files
+34 −0 enzyme/BUILD
+8 −2 enzyme/Enzyme/ActivityAnalysis.cpp
+32 −9 enzyme/Enzyme/AdjointGenerator.h
+37 −9 enzyme/Enzyme/CApi.cpp
+7 −5 enzyme/Enzyme/CApi.h
+21 −27 enzyme/Enzyme/CallDerivatives.cpp
+3 −1 enzyme/Enzyme/Clang/EnzymeClang.cpp
+7 −6 enzyme/Enzyme/DiffeGradientUtils.cpp
+3 −2 enzyme/Enzyme/DiffeGradientUtils.h
+13 −2 enzyme/Enzyme/DifferentialUseAnalysis.h
+22 −7 enzyme/Enzyme/Enzyme.cpp
+242 −67 enzyme/Enzyme/EnzymeLogic.cpp
+32 −10 enzyme/Enzyme/EnzymeLogic.h
+2 −2 enzyme/Enzyme/FunctionUtils.cpp
+223 −138 enzyme/Enzyme/GradientUtils.cpp
+11 −7 enzyme/Enzyme/GradientUtils.h
+4 −3 enzyme/Enzyme/InstructionDerivatives.td
+5 −1 enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
+1,279 −0 enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
+244 −0 enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h
+1 −0 enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt
+1 −1 enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
+2 −0 enzyme/Enzyme/MLIR/CMakeLists.txt
+1 −0 enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt
+2 −0 enzyme/Enzyme/MLIR/Dialect/Dialect.td
+56 −2 enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+270 −1 enzyme/Enzyme/MLIR/Dialect/Ops.cpp
+1 −0 enzyme/Enzyme/MLIR/Dialect/Ops.h
+32 −2 enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
+6 −0 enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
+24 −3 enzyme/Enzyme/MLIR/Implementations/Common.td
+1 −0 enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp
+1 −0 enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h
+37 −0 enzyme/Enzyme/MLIR/Implementations/EnzymeAutoDiffOpInterfaceImpl.cpp
+3 −0 enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td
+2 −2 enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
+1 −0 enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td
+9 −5 enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
+14 −0 enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
+3 −3 enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
+23 −9 enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
+4 −2 enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
+4 −4 enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
+15 −13 enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
+5 −4 enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
+4 −2 enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
+40 −8 enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
+3 −2 enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
+7 −0 enzyme/Enzyme/MLIR/Passes/Passes.td
+19 −13 enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
+523 −0 enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
+5 −0 enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
+14 −3 enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
+97 −0 enzyme/Enzyme/MLIR/Passes/Utils.cpp
+10 −1 enzyme/Enzyme/MLIR/Passes/Utils.h
+0 −1 enzyme/Enzyme/MLIR/enzymemlir-translate/CMakeLists.txt
+1 −1 enzyme/Enzyme/PreserveNVVM.cpp
+2 −0 enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
+11 −0 enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h
+5 −3 enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp
+10 −4 enzyme/Enzyme/Utils.cpp
+81 −35 enzyme/Enzyme/Utils.h
+18 −0 enzyme/WORKSPACE
+6 −0 enzyme/test/Enzyme/ForwardMode/frexp.ll
+82 −0 enzyme/test/Enzyme/ReverseMode/custom-nouse.ll
+57 −0 enzyme/test/Enzyme/ReverseMode/doubleunreachable.ll
+6 −0 enzyme/test/Enzyme/ReverseMode/frexp.ll
+431 −0 enzyme/test/Enzyme/ReverseMode/gcloaded.ll
+3 −3 enzyme/test/Enzyme/ReverseMode/mul_checked.ll
+168 −0 enzyme/test/Enzyme/ReverseMode/rematprimal.ll
+131 −0 enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir
+26 −0 enzyme/test/MLIR/ForwardMode/canonicalize.mlir
+31 −0 enzyme/test/MLIR/ForwardMode/mul_strongzero.mlir
+44 −0 enzyme/test/MLIR/ReverseMode/canonicalize.mlir
+40 −0 enzyme/test/MLIR/ReverseMode/drop_gradients.mlir
+22 −0 enzyme/test/MLIR/ReverseMode/mul_strongzero.mlir
+3 −3 enzyme/test/MLIR/ReverseMode/pow.mlir
+6 −6 enzyme/test/MLIR/ReverseMode/scf_for.mlir
+45 −0 enzyme/test/MLIR/ReverseMode/scf_for_mincut.mlir
+15 −1 enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
+4 −0 enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
+46 −25 enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
5 changes: 5 additions & 0 deletions mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ mhlo:
@if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \
git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \
fi

# Patch a MHLO bug with std::sort
@if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-rename-sort.patch; then \
git apply $(MK_DIR)/patches/mhlo-rename-sort.patch; \
fi
cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_ENABLE_ASSERTIONS=ON \
Expand Down
16 changes: 8 additions & 8 deletions mlir/include/Catalyst/Transforms/AsyncUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ bool hasAbortInBlock(Block *block);
bool hasPutsInBlock(Block *block);

// Helper function for creating function declarations
LLVM::LLVMFuncOp lookupOrCreatePersonality(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateAbort(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(ModuleOp);
LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(ModuleOp);
LLVM::LLVMFuncOp lookupOrCreateDropRef(ModuleOp);
LLVM::LLVMFuncOp lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b, ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b, ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp);
LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp);
LLVM::LLVMFuncOp lookupOrCreateDropRef(OpBuilder &b, ModuleOp);

}; // namespace AsyncUtils
39 changes: 21 additions & 18 deletions mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,81 +128,84 @@ LLVM::LLVMFuncOp AsyncUtils::getCaller(LLVM::CallOp callOp)
return callOp->getParentOfType<LLVM::LLVMFuncOp>();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto i32Ty = IntegerType::get(ctx, 32);
bool isVarArg = true;
return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::personalityName, {}, i32Ty,
isVarArg)
return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::personalityName, {},
i32Ty, isVarArg)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
Type llvmInt64Type = IntegerType::get(moduleOp.getContext(), 64);
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
return mlir::LLVM::lookupOrCreateFn(b, moduleOp,
AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
{ptrTy, llvmInt64Type}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b,
ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b,
ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(
moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
.value();
}

LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp)
LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp)
{
MLIRContext *ctx = moduleOp.getContext();
auto voidTy = LLVM::LLVMVoidType::get(ctx);
return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::unrecoverableErrorName, {},
voidTy)
return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::unrecoverableErrorName,
{}, voidTy)
.value();
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ struct CustomCallOpInterface
MemRefType memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
auto newBuffer =
rewriter.create<bufferization::ToMemrefOp>(op->getLoc(), memrefType, *tensorAlloc);
rewriter.create<bufferization::ToBufferOp>(op->getLoc(), memrefType, *tensorAlloc);
bufferArgs.push_back(newBuffer);
}

Expand Down Expand Up @@ -314,8 +314,8 @@ struct CallbackCallOpInterface
auto shape = tensorTy.getShape();
auto elementTy = tensorTy.getElementType();
auto memrefType = MemRefType::get(shape, elementTy);
auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
auto memref = toMemrefOp.getResult();
auto toBufferOp = rewriter.create<bufferization::ToBufferOp>(loc, memrefType, tensor);
auto memref = toBufferOp.getResult();
outmemrefs.push_back(memref);
newInputs.push_back(memref);
}
Expand Down
19 changes: 8 additions & 11 deletions mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp
auto moduleOp = callOp->getParentOfType<ModuleOp>();
// Here, we are adding a reference to the personality declaration.
// From the documentation: https://llvm.org/docs/ExceptionHandling.html#exception-tables
auto personality = AsyncUtils::lookupOrCreatePersonality(moduleOp);
auto personality = AsyncUtils::lookupOrCreatePersonality(rewriter, moduleOp);

// We annotate the body of the function containing the callop to have a reference
// to the personality.
Expand Down Expand Up @@ -294,7 +294,7 @@ RemoveAbortAndPutsInsertCallTransform::matchAndRewrite(LLVM::CallOp callOp,
// Here, we are declaring an external function which is available in the Catalyst runtime.
// llvm.func @__catalyst__host__rt__unrecoverable_error()
auto moduleOp = callOp->getParentOfType<ModuleOp>();
auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(moduleOp);
auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(rewriter, moduleOp);

auto callee = maybeCallee.value();
rewriter.modifyOpInPlace(callee, [&] { callee.setLinkage(LLVM::Linkage::Internal); });
Expand Down Expand Up @@ -516,8 +516,8 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink,
// llvm.func @mlirAsyncRuntimeAwaitValue(!llvm.ptr)
// llvm.func @mlirAsyncRuntimeAwaitToken(!llvm.ptr)
// llvm.func @mlirAsyncRuntimeDropRef(!llvm.ptr, i64)
auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(moduleOp);
auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(moduleOp);
auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(rewriter, moduleOp);
auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(rewriter, moduleOp);

Type llvmInt64Type = IntegerType::get(sink->getContext(), 64);
auto one = rewriter.getIntegerAttr(llvmInt64Type, 1);
Expand Down Expand Up @@ -871,9 +871,9 @@ void insertErrorCalls(std::vector<Value> tokens, std::vector<Value> values, Bloc
auto moduleOp = landingPad->getParentOfType<ModuleOp>();

LLVM::LLVMFuncOp setTokenError =
AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(moduleOp);
AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(rewriter, moduleOp);
LLVM::LLVMFuncOp setValueError =
AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(moduleOp);
AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(rewriter, moduleOp);
for (auto token : tokens) {
insertCallToMlirAsyncRuntimeErrorFunction(token, setTokenError, failBlock, rewriter);
}
Expand Down Expand Up @@ -918,11 +918,8 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase<AddExceptio
patterns1.add<DetectCallsInAsyncRegionsTransform>(context);

GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
// TODO: Update to the following lines the next time we update llvm
// config.setStrictness(GreedyRewriteStrictness::ExistingOps);
// config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns1), config))) {
signalPassFailure();
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ struct GEPOpRewritePattern : public mlir::OpRewritePattern<LLVM::GEPOp> {
mlir::PatternRewriter &rewriter) const override
{
auto defOp = op.getBase().getDefiningOp();
if (op.getInbounds() || (defOp && isa<LLVM::ZeroOp>(defOp))) {
if (op.getNoWrapFlags() == LLVM::GEPNoWrapFlags::inbounds() ||
(defOp && isa<LLVM::ZeroOp>(defOp))) {
return failure();
}
rewriter.startOpModification(op);
op.setInbounds(true);
op.setNoWrapFlags(LLVM::GEPNoWrapFlags::inbounds());
rewriter.finalizeOpModification(op);
return success();
}
Expand Down
14 changes: 4 additions & 10 deletions mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,8 @@ struct AnnotateWithFullyQualifiedNamePass
{
MLIRContext *context = &getContext();
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
// TODO: Update to the following lines the next time we update llvm
// config.setStrictness(GreedyRewriteStrictness::ExistingOps);
// config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);

RewritePatternSet annotate(context);
auto root = getOperation();
Expand All @@ -409,11 +406,8 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
MLIRContext *context = &getContext();

GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
// TODO: Update to the following lines the next time we update llvm
// config.setStrictness(GreedyRewriteStrictness::ExistingOps);
// config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
RewritePatternSet renameFunctions(context);

// Get all symbol tables in current symbol table. Will be useful for making sure that
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ struct CustomCallOpPattern : public OpConversionPattern<CustomCallOp> {
rewriter.setInsertionPointToStart(mod.getBody());

LLVM::LLVMFuncOp customCallFnOp =
mlir::LLVM::lookupOrCreateFn(mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr},
mlir::LLVM::lookupOrCreateFn(rewriter, mod, op.getCallTargetName(),
{/*args=*/ptr, /*rets=*/ptr},
/*ret_type=*/voidType)
.value();
customCallFnOp.setPrivate();
Expand Down Expand Up @@ -467,7 +468,7 @@ struct DefineCallbackOpPattern : public OpConversionPattern<CallbackOp> {
ModuleOp mod = op->getParentOfType<ModuleOp>();
auto typeConverter = getTypeConverter();
LLVM::LLVMFuncOp customCallFnOp =
mlir::LLVM::lookupOrCreateFn(mod, "__catalyst_inactive_callback",
mlir::LLVM::lookupOrCreateFn(rewriter, mod, "__catalyst_inactive_callback",
{/*args=*/i64, i64, i64},
/*ret_type=*/voidType, isVarArg)
.value();
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,9 @@ struct ForwardOpInterface
// away.
BaseMemRefType resultType = options.unknownTypeConverterFn(
returnVal, *options.defaultMemorySpaceFn(tensorType), options);
Value toMemrefOp =
rewriter.create<bufferization::ToMemrefOp>(loc, resultType, returnVal);
returnValues.push_back(toMemrefOp);
Value toBufferOp =
rewriter.create<bufferization::ToBufferOp>(loc, resultType, returnVal);
returnValues.push_back(toBufferOp);
}

// 3. Rewrite the terminator.
Expand Down Expand Up @@ -579,9 +579,9 @@ struct ReverseOpInterface
// away.
BaseMemRefType resultType = options.unknownTypeConverterFn(
returnVal, *options.defaultMemorySpaceFn(tensorType), options);
Value toMemrefOp =
rewriter.create<bufferization::ToMemrefOp>(loc, resultType, returnVal);
returnValues.push_back(toMemrefOp);
Value toBufferOp =
rewriter.create<bufferization::ToBufferOp>(loc, resultType, returnVal);
returnValues.push_back(toBufferOp);
}

// 3. Rewrite the terminator.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l
auto tensorTy = diffArg.getType();
auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout(
cast<TensorType>(tensorTy));
auto toMemrefOp =
rewriter.create<bufferization::ToMemrefOp>(loc, memrefTy, diffArg);
auto toBufferOp =
rewriter.create<bufferization::ToBufferOp>(loc, memrefTy, diffArg);

auto cloneOp = rewriter.create<bufferization::CloneOp>(loc, toMemrefOp);
auto cloneOp = rewriter.create<bufferization::CloneOp>(loc, toBufferOp);

auto toTensorOp =
rewriter.create<bufferization::ToTensorOp>(loc, cloneOp, true);
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va
: activeResultType);

Value zero = builder.create<arith::ConstantFloatOp>(
loc, APFloat(elementType.getFloatSemantics(), 0), elementType);
Value one = builder.create<arith::ConstantFloatOp>(
loc, APFloat(elementType.getFloatSemantics(), 1), elementType);
loc, elementType, APFloat(elementType.getFloatSemantics(), 0));
Value one = builder.create<arith::ConstantFloatOp>(loc, elementType,
APFloat(elementType.getFloatSemantics(), 1));

Value zeroTensor;
if (auto activeResultTensor = dyn_cast<RankedTensorType>(activeResultType)) {
Expand Down Expand Up @@ -397,7 +397,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc,
}
else {
jacobians.push_back(rewriter.create<arith::ConstantFloatOp>(
loc, APFloat(0.0), cast<FloatType>(jacobianType)));
loc, cast<FloatType>(jacobianType), APFloat(0.0)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ static Value genSelectiveShift(PatternRewriter &rewriter, Location loc, Value pa
}

// Make sure all active iteration variables match the selectors.
Value shiftCondition = rewriter.create<arith::ConstantIntOp>(loc, true, 1);
Value shiftCondition = rewriter.create<arith::ConstantIntOp>(loc, 1, true);
for (auto &[iteration, selector] : selectors) {
Value iterationMatch =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, iteration, selector);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Value buildTensorLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper
// Initialize the result tensor
FloatType elementType = cast<FloatType>(resultType.getElementType());
Value zero = builder.create<arith::ConstantFloatOp>(
loc, APFloat::getZero(elementType.getFloatSemantics()), elementType);
loc, elementType, APFloat::getZero(elementType.getFloatSemantics()));
Value result =
builder.create<tensor::EmptyOp>(loc, resultType.getShape(), resultType.getElementType());
result = builder.create<linalg::FillOp>(loc, zero, result).getResult(0);
Expand Down
Loading
Loading