From 5671ce28601477266bc6f08ba5c26ab6df0238cb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 18:04:41 -0500 Subject: [PATCH 01/12] feat: get_dimension_size batch interface --- .../jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 02e4a0fcb..4ab47c218 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1692,8 +1692,9 @@ struct SHLOGetDimensionSizeOpBatchInterface auto bcastOp = BroadcastInDimOp::create( builder, src->getLoc(), RankedTensorType::get( - batchSizes, cast(newOp->getResult(0).getType()) - .getElementType()), + batchSizes, + cast(newOp->getResult(0).getType()) + .getElementType()), newOp->getResult(0), builder.getDenseI64ArrayAttr({})); mapper.map(src->getResult(0), bcastOp->getResult(0)); return success(); From 81b9bd0ffce6bdcd2e5ca2f4f2771ae062a40081 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 19:21:03 -0500 Subject: [PATCH 02/12] feat: implement jitcall batching with shlo_generic_batch_op_interface --- .../jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 4ab47c218..02e4a0fcb 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1692,9 +1692,8 @@ struct SHLOGetDimensionSizeOpBatchInterface auto bcastOp = BroadcastInDimOp::create( builder, src->getLoc(), RankedTensorType::get( - batchSizes, - cast(newOp->getResult(0).getType()) - .getElementType()), + batchSizes, cast(newOp->getResult(0).getType()) + .getElementType()), newOp->getResult(0), builder.getDenseI64ArrayAttr({})); mapper.map(src->getResult(0), bcastOp->getResult(0)); return success(); From 9759ffd4ba35f98bfa9a81472c4b8ab19baf8c25 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 17:57:38 -0500 Subject: [PATCH 03/12] refactor: reuse batching interface for LU factorization --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 13 +- src/enzyme_ad/jax/Passes/LinalgUtils.cpp | 2 +- src/enzyme_ad/jax/Passes/LinalgUtils.h | 2 +- .../jax/Passes/LowerEnzymeXLALapack.cpp | 10 +- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 394 +++++++----------- src/enzyme_ad/jax/Passes/Passes.td | 1 + 6 files changed, 165 insertions(+), 257 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 02e4a0fcb..aa0e87445 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -32,6 +32,8 @@ #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Dialect/Ops.h" #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Utils.h" @@ -1692,8 +1694,9 @@ struct SHLOGetDimensionSizeOpBatchInterface auto bcastOp = 
BroadcastInDimOp::create( builder, src->getLoc(), RankedTensorType::get( - batchSizes, cast(newOp->getResult(0).getType()) - .getElementType()), + batchSizes, + cast(newOp->getResult(0).getType()) + .getElementType()), newOp->getResult(0), builder.getDenseI64ArrayAttr({})); mapper.map(src->getResult(0), bcastOp->getResult(0)); return success(); @@ -3946,6 +3949,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( *context); ConstantOp::attachInterface(*context); + GetDimensionSizeOp::attachInterface( + *context); TransposeOp::attachInterface(*context); IfOp::attachInterface>(*context); WhileOp::attachInterface>(*context); @@ -3975,5 +3980,9 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( AddOp::attachInterface(*context); SubtractOp::attachInterface(*context); + + // TODO: move into its own file + enzymexla::JITCallOp::attachInterface< + SHLOGenericBatchOpInterface>(*context); }); } diff --git a/src/enzyme_ad/jax/Passes/LinalgUtils.cpp b/src/enzyme_ad/jax/Passes/LinalgUtils.cpp index bdae610fe..8623f37a5 100644 --- a/src/enzyme_ad/jax/Passes/LinalgUtils.cpp +++ b/src/enzyme_ad/jax/Passes/LinalgUtils.cpp @@ -38,7 +38,7 @@ mlir::ArrayAttr getSHLOLayout(PatternRewriter &rewriter, return rewriter.getArrayAttr(attrs); } -std::optional lapack_precision_prefix(Type elementType) { +std::optional lapackPrecisionPrefix(Type elementType) { // single-precision float if (elementType.isF32()) { diff --git a/src/enzyme_ad/jax/Passes/LinalgUtils.h b/src/enzyme_ad/jax/Passes/LinalgUtils.h index 46a330cb4..b618f0781 100644 --- a/src/enzyme_ad/jax/Passes/LinalgUtils.h +++ b/src/enzyme_ad/jax/Passes/LinalgUtils.h @@ -17,6 +17,6 @@ mlir::ArrayAttr getSHLOLayout(mlir::PatternRewriter &rewriter, llvm::SmallVector isColMajorArr, int64_t maxNumDims); -std::optional lapack_precision_prefix(mlir::Type elementType); +std::optional lapackPrecisionPrefix(mlir::Type elementType); #endif // ENZYMEXLA_LINALGUTILS_H diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp index f9d5a2f16..0e84af01b 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp @@ -80,7 +80,7 @@ struct GeqrfOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "geqrf_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << inputElementType; @@ -355,7 +355,7 @@ struct GeqrtOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "geqrt_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << inputElementType; @@ -567,7 +567,7 @@ struct OrgqrOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "gqr_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { if (prefix == "s" || prefix == "d") fn = *prefix + "or" + fn; else @@ -873,7 +873,7 @@ struct OrmqrOpLowering : public OpRewritePattern { auto type_llvm_char = rewriter.getIntegerType(8); std::string fn = "mqr_"; - if (auto prefix = lapack_precision_prefix(A_eltype)) { + if (auto prefix = 
lapackPrecisionPrefix(A_eltype)) { if (prefix == "s" || prefix == "d") fn = *prefix + "or" + fn; else @@ -1141,7 +1141,7 @@ struct GemqrtOpLowering : public OpRewritePattern { auto type_llvm_char = rewriter.getIntegerType(8); std::string fn = "gemqrt_"; - if (auto prefix = lapack_precision_prefix(C_eltype)) { + if (auto prefix = lapackPrecisionPrefix(C_eltype)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << C_eltype; diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 43f6ed349..2fb76ee17 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -1,3 +1,5 @@ +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Passes/EnzymeBatchPass.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -48,95 +50,46 @@ struct LUFactorizationOpLowering auto inputShape = cast(input.getType()).getShape(); auto inputRank = static_cast(inputShape.size()); auto inputType = cast(input.getType()); + auto unbatchedInputType = RankedTensorType::get( + SmallVector(inputType.getShape().end() - 2, + inputType.getShape().end()), + inputType.getElementType()); auto inputElementType = inputType.getElementType(); - const int64_t m = inputShape[inputRank - 2]; - const int64_t n = inputShape[inputRank - 1]; + const int64_t m = inputShape[inputRank - 2]; // TODO: use get_dimension_size + const int64_t n = inputShape[inputRank - 1]; // TODO: use get_dimension_size const int64_t numBatchDims = inputRank - 2; auto pivotType = cast(op.getResult(1).getType()); auto pivotRank = pivotType.getRank(); + auto unbatchedPivotType = RankedTensorType::get( + SmallVector(pivotType.getShape().end() - 1, + pivotType.getShape().end()), + pivotType.getElementType()); + auto infoType = cast(op.getResult(3).getType()); auto infoRank = infoType.getRank(); + auto unbatchedInfoType = + RankedTensorType::get({}, infoType.getElementType()); if (backend == "cpu") { auto moduleOp = op->getParentOfType(); - static int64_t fnNum = 0; auto blasIntType = rewriter.getIntegerType(blasIntWidth); - auto llvmBlasIntType = typeConverter.convertType(blasIntType); + auto intType = RankedTensorType::get({}, blasIntType); auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); auto llvmVoidPtrType = LLVM::LLVMVoidType::get(ctx); std::string lapackFn; - if (inputElementType.isF32()) { - lapackFn = "sgetrf_"; // single-precision float - } else if (inputElementType.isF64()) { - lapackFn = "dgetrf_"; // double-precision float - } else if (auto complexType = dyn_cast(inputElementType)) { - auto elem = complexType.getElementType(); - if (elem.isF32()) { - lapackFn = "cgetrf_"; // single-precision complex - } else if (elem.isF64()) { - lapackFn = "zgetrf_"; // double-precision complex - } else { - op->emitOpError() << "Unsupported complex element type: " << elem; - return rewriter.notifyMatchFailure( - op, "unsupported complex element type"); - } + auto prefix = lapackPrecisionPrefix(inputElementType); + if (prefix) { + lapackFn = "enzymexla_lapack_" + *prefix + "getrf_"; } else { op->emitOpError() << "Unsupported input element type: " << inputElementType; return rewriter.notifyMatchFailure(op, "unsupported input element type"); } - lapackFn = "enzymexla_lapack_" + lapackFn; - - // Generate the LLVM function body - std::string fnName = lapackFn + "wrapper_" + std::to_string(fnNum); - fnNum++; - { - 
OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto funcType = LLVM::LLVMFunctionType::get( - llvmVoidPtrType, {llvmPtrType, llvmPtrType, llvmPtrType}, false); - - auto func = - LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), fnName, funcType); - rewriter.setInsertionPointToStart(func.addEntryBlock(rewriter)); - - auto ptrSize = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, 1)); - auto mPtr = LLVM::AllocaOp::create(rewriter, op.getLoc(), llvmPtrType, - llvmBlasIntType, ptrSize, 0); - auto nPtr = LLVM::AllocaOp::create(rewriter, op.getLoc(), llvmPtrType, - llvmBlasIntType, ptrSize, 0); - - auto mVal = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, m)); - auto nVal = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, n)); - - LLVM::StoreOp::create(rewriter, op.getLoc(), mVal, mPtr); - LLVM::StoreOp::create(rewriter, op.getLoc(), nVal, nPtr); - - LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{}, - SymbolRefAttr::get(ctx, lapackFn), - ValueRange{ - mPtr, - nPtr, - func.getArgument(0), - mPtr, - func.getArgument(1), - func.getArgument(2), - }); - - LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{}); - } // Insert function declaration if not already present if (!moduleOp.lookupSymbol(lapackFn)) { @@ -155,197 +108,77 @@ struct LUFactorizationOpLowering // Call the LLVM function with enzymexla.jit_call SmallVector aliases; - for (int i = 0; i < 3; ++i) { - aliases.push_back(stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{i}, i, std::vector{})); - } - + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{0}, 2, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{1}, 4, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{2}, 5, std::vector{})); + + auto unbatchedBLASPivotType = RankedTensorType::get( + unbatchedPivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); auto blasPivotType = RankedTensorType::get( pivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto unbatchedBLASInfoType = RankedTensorType::get( + unbatchedInfoType.getShape(), rewriter.getIntegerType(blasIntWidth)); auto blasInfoType = RankedTensorType::get( infoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - SmallVector isColMajorArr = {true, true, true}; - SmallVector operandRanks = {2, 1, 0}; - SmallVector outputRanks = {2, 1, 0}; auto operandLayouts = - getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2); + getSHLOLayout(rewriter, SmallVector{0, 0, 2, 0, 1, 0}, + SmallVector(6, true), 2); auto resultLayouts = - getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2); - - auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); - auto iter = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 0))); - auto zeroConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 0))); + getSHLOLayout(rewriter, SmallVector{2, 1, 0}, + SmallVector(3, true), 2); Value factorizedResult, pivotResult, infoResult; + static int64_t fnNum = 0; + std::string wrapperFnName = lapackFn + std::to_string(fnNum++); + + func::FuncOp func = createWrapperFuncOpCPULapack( + rewriter, lapackFn, unbatchedInputType, unbatchedBLASPivotType, + unbatchedBLASInfoType, blasIntType, 
wrapperFnName, op, operandLayouts, + resultLayouts, rewriter.getArrayAttr(aliases)); + if (!func) + return rewriter.notifyMatchFailure(op, + "failed to create wrapper function"); + + SmallVector batchOps; + SmallVector batchFunctions; if (numBatchDims > 0) { // TODO: Implement batched LU factorizations by directly calling MKL // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-0/getrf-batch-strided.html. + SmallVector batchShape(inputShape.begin(), + inputShape.begin() + numBatchDims); - int64_t batchSize = 1; - for (int i = 0; i < numBatchDims; i++) { - batchSize *= inputShape[i]; - } - SmallVector flattenedInput = {batchSize, m, n}; - - auto flatInputType = - RankedTensorType::get(flattenedInput, inputElementType); - auto flatInput = stablehlo::ReshapeOp::create(rewriter, op.getLoc(), - flatInputType, input); - - auto flatPivotType = RankedTensorType::get( - {batchSize, pivotType.getShape()[pivotRank - 1]}, blasIntType); - auto flatPivot = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), flatPivotType, - cast(makeAttr(flatPivotType, -1))); - - auto flatInfoType = RankedTensorType::get({batchSize}, blasIntType); - auto flatInfo = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), flatInfoType, - cast(makeAttr(flatInfoType, -1))); - - auto whileReturnTypes = {iterType, flatInputType, flatPivotType, - flatInfoType}; - auto whileOp = stablehlo::WhileOp::create( - rewriter, op.getLoc(), - TypeRange{iterType, flatInputType, flatPivotType, flatInfoType}, - ValueRange{iter, flatInput, flatPivot, flatInfo}); - - { - OpBuilder::InsertionGuard guard(rewriter); - - Block *block = rewriter.createBlock(&whileOp.getCond()); - rewriter.setInsertionPointToStart(block); - - for (auto type : whileReturnTypes) { - block->addArgument(type, whileOp.getLoc()); - } - - auto batchSizeConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, batchSize))); - - auto comparison = stablehlo::CompareOp::create( - rewriter, op.getLoc(), block->getArgument(0), batchSizeConst, - stablehlo::ComparisonDirection::LT); - - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{comparison.getResult()}); - } - - { - OpBuilder::InsertionGuard guard(rewriter); - - Block *block = rewriter.createBlock(&whileOp.getBody()); - rewriter.setInsertionPointToStart(block); - - for (auto type : whileReturnTypes) { - block->addArgument(type, whileOp.getLoc()); - } + auto batchOp = rewriter.create( + op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(op.getContext(), wrapperFnName), + ValueRange{input}, rewriter.getDenseI64ArrayAttr(batchShape)); - auto iterArg = block->getArgument(0); + factorizedResult = batchOp.getResult(0); + pivotResult = batchOp.getResult(1); + infoResult = batchOp.getResult(2); - auto inputSliceType = RankedTensorType::get({m, n}, inputElementType); - auto inputSlice = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), inputSliceType, - stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - ValueRange{iterArg, zeroConst, zeroConst}, - rewriter.getDenseI64ArrayAttr({1, m, n}))); - - auto pivotSliceType = - RankedTensorType::get({std::min(m, n)}, blasIntType); - auto pivotSlice = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotSliceType, - cast(makeAttr(pivotSliceType, -1))); - - auto infoSliceType = RankedTensorType::get({}, blasIntType); - auto infoSlice = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), infoSliceType, - 
cast(makeAttr(infoSliceType, -1))); - - auto jitCall = enzymexla::JITCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputSliceType, pivotSliceType, infoSliceType}, - mlir::FlatSymbolRefAttr::get(ctx, fnName), - ValueRange{inputSlice, pivotSlice, infoSlice}, - rewriter.getStringAttr(""), - /*operand_layouts=*/operandLayouts, - /*result_layouts=*/resultLayouts, - /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr, - /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), - /*xla_side_effect_free=*/rewriter.getUnitAttr()); - - auto inputUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1, m, n}, inputElementType), - jitCall.getResult(0)), - ValueRange{iterArg, zeroConst, zeroConst}); - auto pivotUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(2), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1, std::min(m, n)}, blasIntType), - jitCall.getResult(1)), - ValueRange{iterArg, zeroConst}); - auto infoUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(3), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1}, blasIntType), - jitCall.getResult(2)), - ValueRange{iterArg}); - - auto updatedIter = stablehlo::AddOp::create( - rewriter, op.getLoc(), block->getArgument(0), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 1)))); - - stablehlo::ReturnOp::create( - rewriter, op.getLoc(), - ValueRange{updatedIter, inputUpdated, pivotUpdated, infoUpdated}); - } - - factorizedResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), inputType, whileOp.getResult(1)); - pivotResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), blasPivotType, whileOp.getResult(2)); - infoResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), blasInfoType, whileOp.getResult(3)); + batchOps.push_back(batchOp); + batchFunctions.push_back( + cast(func.getOperation())); } else { - auto pivot = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, -1))); - auto info = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasInfoType, - cast(makeAttr(blasInfoType, -1))); - - auto jitCall = enzymexla::JITCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, blasPivotType, blasInfoType}, - mlir::FlatSymbolRefAttr::get(ctx, fnName), - ValueRange{input, pivot, info}, rewriter.getStringAttr(""), - /*operand_layouts=*/operandLayouts, - /*result_layouts=*/resultLayouts, - /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr, - /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), - /*xla_side_effect_free=*/rewriter.getUnitAttr()); - - factorizedResult = jitCall.getResult(0); - pivotResult = jitCall.getResult(1); - infoResult = jitCall.getResult(2); + auto callOp = + rewriter.create(op.getLoc(), func, ValueRange{input}); + + factorizedResult = callOp.getResult(0); + pivotResult = callOp.getResult(1); + infoResult = callOp.getResult(2); } + auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); + auto iter = rewriter.create( + op.getLoc(), iterType, cast(makeAttr(iterType, 0))); + auto zeroConst = rewriter.create( + op.getLoc(), iterType, cast(makeAttr(iterType, 0))); + auto pivots0indexed = stablehlo::SubtractOp::create( rewriter, op.getLoc(), pivotResult, stablehlo::ConstantOp::create( @@ -494,15 +327,20 @@ struct 
LUFactorizationOpLowering rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); rewriter.replaceAllUsesWith( - op.getResult(1), stablehlo::ConvertOp::create( - rewriter, op.getLoc(), pivotType, pivotResult)); - rewriter.replaceAllUsesWith( - op.getResult(2), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - finalPermutation)); - rewriter.replaceAllUsesWith( - op.getResult(3), stablehlo::ConvertOp::create(rewriter, op.getLoc(), - infoType, infoResult)); + op.getResult(2), rewriter.create( + op.getLoc(), pivotType, finalPermutation)); + rewriter.replaceAllUsesWith(op.getResult(3), + rewriter.create( + op.getLoc(), infoType, infoResult)); + + std::map + batchedFunctionCache; + for (auto [batchOp, func] : llvm::zip(batchOps, batchFunctions)) { + if (failed(enzyme::batchutils::batchOperation(rewriter, batchOp, func, + batchedFunctionCache))) { + return rewriter.notifyMatchFailure(op, "failed to batch operation"); + } + } return success(); } else if (backend == "cuda") { @@ -695,6 +533,66 @@ struct LUFactorizationOpLowering return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); } } + +private: + func::FuncOp createWrapperFuncOpCPULapack( + PatternRewriter &rewriter, const std::string &lapackFn, + RankedTensorType inputType, RankedTensorType blasPivotType, + RankedTensorType blasInfoType, Type blasIntType, + const std::string &fnName, enzymexla::LUFactorizationOp op, + ArrayAttr operandLayouts, ArrayAttr resultLayouts, + ArrayAttr outputOperandAliases) const { + auto ctx = op->getContext(); + + OpBuilder::InsertionGuard guard(rewriter); + auto moduleOp = op->getParentOfType(); + if (!moduleOp) + return nullptr; + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + SmallVector argTypes = {inputType}; + SmallVector retTypes = {inputType, blasPivotType, blasInfoType}; + + FunctionType calleeType = rewriter.getFunctionType(argTypes, retTypes); + func::FuncOp func = + rewriter.create(op.getLoc(), fnName, calleeType); + func.setPrivate(); + + auto &entryBlock = *func.addEntryBlock(); + rewriter.setInsertionPointToStart(&entryBlock); + + auto input = entryBlock.getArgument(0); + auto mSize = rewriter.create( + op.getLoc(), RankedTensorType::get({}, blasIntType), + rewriter.create(op.getLoc(), input, 0)); + auto nSize = rewriter.create( + op.getLoc(), RankedTensorType::get({}, blasIntType), + rewriter.create(op.getLoc(), input, 1)); + auto pivot = rewriter.create( + op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, -1))); + auto info = rewriter.create( + op.getLoc(), blasInfoType, + cast(makeAttr(blasInfoType, -1))); + + auto jitCall = rewriter.create( + op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(ctx, lapackFn), + ValueRange{mSize, nSize, input, mSize, pivot, info}, + rewriter.getStringAttr(""), + /*operand_layouts=*/operandLayouts, + /*result_layouts=*/resultLayouts, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/outputOperandAliases, + /*xla_side_effect_free=*/rewriter.getUnitAttr()); + + rewriter.create( + op.getLoc(), ValueRange{jitCall.getResult(0), jitCall.getResult(1), + jitCall.getResult(2)}); + + return func; + } }; struct SVDFactorizationOpLowering @@ -755,7 +653,7 @@ struct SVDFactorizationOpLowering // TODO change SVD method with attributes std::string fn = "gesvd_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported 
complex element type: " diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 0ce38ff15..42622da99 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -416,6 +416,7 @@ def LowerEnzymeXLALinalgPass : Pass<"lower-enzymexla-linalg"> { "stablehlo::StablehloDialect", "enzymexla::EnzymeXLADialect", "LLVM::LLVMDialect", + "enzyme::EnzymeDialect", ]; let options = [ From 8f54d70220aca72e0f9347d3ad50f183722257ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 19:35:58 -0500 Subject: [PATCH 04/12] fix: remove old changes --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index aa0e87445..02e4a0fcb 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -32,8 +32,6 @@ #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" -#include "src/enzyme_ad/jax/Dialect/Dialect.h" -#include "src/enzyme_ad/jax/Dialect/Ops.h" #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Utils.h" @@ -1694,9 +1692,8 @@ struct SHLOGetDimensionSizeOpBatchInterface auto bcastOp = BroadcastInDimOp::create( builder, src->getLoc(), RankedTensorType::get( - batchSizes, - cast(newOp->getResult(0).getType()) - .getElementType()), + batchSizes, cast(newOp->getResult(0).getType()) + .getElementType()), newOp->getResult(0), builder.getDenseI64ArrayAttr({})); mapper.map(src->getResult(0), bcastOp->getResult(0)); return success(); @@ -3949,8 +3946,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( *context); ConstantOp::attachInterface(*context); - GetDimensionSizeOp::attachInterface( - *context); TransposeOp::attachInterface(*context); IfOp::attachInterface>(*context); WhileOp::attachInterface>(*context); @@ -3980,9 +3975,5 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( AddOp::attachInterface(*context); SubtractOp::attachInterface(*context); - - // TODO: move into its own file - enzymexla::JITCallOp::attachInterface< - SHLOGenericBatchOpInterface>(*context); }); } From b789363324edc5c1177f665c0d5cafa4f9d7fb9d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 20:12:28 -0500 Subject: [PATCH 05/12] refactor: move into separate functions --- .../jax/Passes/LowerEnzymeXLALapack.cpp | 81 +- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 989 +++++++++--------- 2 files changed, 533 insertions(+), 537 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp index 0e84af01b..78c869294 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp @@ -41,14 +41,11 @@ struct GeqrfOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GeqrfOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); else if (backend == "cuda") - return this->matchAndRewrite_cuda(op, rewriter); - + return matchAndRewriteCUDA(op, rewriter); else if (backend == "tpu") - return this->matchAndRewrite_tpu(op, rewriter); - + return 
matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -56,8 +53,8 @@ struct GeqrfOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -209,8 +206,8 @@ struct GeqrfOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_cuda(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCUDA(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -265,8 +262,8 @@ struct GeqrfOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_tpu(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteTPU(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -316,14 +313,11 @@ struct GeqrtOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GeqrtOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -331,8 +325,8 @@ struct GeqrtOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GeqrtOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GeqrtOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -523,14 +517,11 @@ struct OrgqrOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::OrgqrOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); else if (backend == "cuda") - return this->matchAndRewrite_cuda(op, rewriter); - + return matchAndRewriteCUDA(op, rewriter); else if (backend == "tpu") - return this->matchAndRewrite_tpu(op, rewriter); - + return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -538,8 +529,8 @@ struct OrgqrOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::OrgqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -688,8 +679,8 @@ struct OrgqrOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_cuda(enzymexla::OrgqrOp op, - 
PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCUDA(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -734,8 +725,8 @@ struct OrgqrOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_tpu(enzymexla::OrgqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteTPU(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -772,14 +763,11 @@ struct OrmqrOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::OrmqrOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -787,8 +775,8 @@ struct OrmqrOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::OrmqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::OrmqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -1031,14 +1019,11 @@ struct GemqrtOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GemqrtOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -1046,8 +1031,8 @@ struct GemqrtOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GemqrtOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GemqrtOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 2fb76ee17..fafa7ea18 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -43,8 +43,81 @@ struct LUFactorizationOpLowering LogicalResult matchAndRewrite(enzymexla::LUFactorizationOp op, PatternRewriter &rewriter) const override { + + if (backend == "cpu") { + return matchAndRewriteCPU(op, rewriter); + } else if (backend == "cuda") { + return matchAndRewriteCUDA(op, rewriter); + } else if (backend == "tpu") { + return matchAndRewriteTPU(op, rewriter); + } else { + return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); + } + } + +private: + func::FuncOp createWrapperFuncOpCPULapack( + PatternRewriter &rewriter, const std::string 
&lapackFn, + RankedTensorType inputType, RankedTensorType blasPivotType, + RankedTensorType blasInfoType, Type blasIntType, + const std::string &fnName, enzymexla::LUFactorizationOp op, + ArrayAttr operandLayouts, ArrayAttr resultLayouts, + ArrayAttr outputOperandAliases) const { + auto ctx = op->getContext(); + + OpBuilder::InsertionGuard guard(rewriter); + auto moduleOp = op->getParentOfType(); + if (!moduleOp) + return nullptr; + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + SmallVector argTypes = {inputType}; + SmallVector retTypes = {inputType, blasPivotType, blasInfoType}; + + FunctionType calleeType = rewriter.getFunctionType(argTypes, retTypes); + func::FuncOp func = + rewriter.create(op.getLoc(), fnName, calleeType); + func.setPrivate(); + + auto &entryBlock = *func.addEntryBlock(); + rewriter.setInsertionPointToStart(&entryBlock); + + auto input = entryBlock.getArgument(0); + auto mSize = rewriter.create( + op.getLoc(), RankedTensorType::get({}, blasIntType), + rewriter.create(op.getLoc(), input, 0)); + auto nSize = rewriter.create( + op.getLoc(), RankedTensorType::get({}, blasIntType), + rewriter.create(op.getLoc(), input, 1)); + auto pivot = rewriter.create( + op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, -1))); + auto info = rewriter.create( + op.getLoc(), blasInfoType, + cast(makeAttr(blasInfoType, -1))); + + auto jitCall = rewriter.create( + op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(ctx, lapackFn), + ValueRange{mSize, nSize, input, mSize, pivot, info}, + rewriter.getStringAttr(""), + /*operand_layouts=*/operandLayouts, + /*result_layouts=*/resultLayouts, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/outputOperandAliases, + /*xla_side_effect_free=*/rewriter.getUnitAttr()); + + rewriter.create( + op.getLoc(), ValueRange{jitCall.getResult(0), jitCall.getResult(1), + jitCall.getResult(2)}); + + return func; + } + + LogicalResult matchAndRewriteCPU(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); - LLVMTypeConverter typeConverter(ctx); auto input = op.getOperand(); auto inputShape = cast(input.getType()).getShape(); @@ -56,8 +129,6 @@ struct LUFactorizationOpLowering inputType.getElementType()); auto inputElementType = inputType.getElementType(); - const int64_t m = inputShape[inputRank - 2]; // TODO: use get_dimension_size - const int64_t n = inputShape[inputRank - 1]; // TODO: use get_dimension_size const int64_t numBatchDims = inputRank - 2; auto pivotType = cast(op.getResult(1).getType()); @@ -72,295 +143,334 @@ struct LUFactorizationOpLowering auto unbatchedInfoType = RankedTensorType::get({}, infoType.getElementType()); - if (backend == "cpu") { - auto moduleOp = op->getParentOfType(); - - auto blasIntType = rewriter.getIntegerType(blasIntWidth); - auto intType = RankedTensorType::get({}, blasIntType); - auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); - auto llvmVoidPtrType = LLVM::LLVMVoidType::get(ctx); - - std::string lapackFn; - auto prefix = lapackPrecisionPrefix(inputElementType); - if (prefix) { - lapackFn = "enzymexla_lapack_" + *prefix + "getrf_"; - } else { - op->emitOpError() << "Unsupported input element type: " - << inputElementType; - return rewriter.notifyMatchFailure(op, - "unsupported input element type"); - } + auto moduleOp = op->getParentOfType(); - // Insert function declaration if not already present - if (!moduleOp.lookupSymbol(lapackFn)) { - OpBuilder::InsertionGuard guard(rewriter); - 
rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto blasIntType = rewriter.getIntegerType(blasIntWidth); + auto intType = RankedTensorType::get({}, blasIntType); + auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); + auto llvmVoidPtrType = LLVM::LLVMVoidType::get(ctx); - auto funcType = - LLVM::LLVMFunctionType::get(llvmVoidPtrType, - {llvmPtrType, llvmPtrType, llvmPtrType, - llvmPtrType, llvmPtrType, llvmPtrType}, - false); + std::string lapackFn; + auto prefix = lapackPrecisionPrefix(inputElementType); + if (prefix) { + lapackFn = "enzymexla_lapack_" + *prefix + "getrf_"; + } else { + op->emitOpError() << "Unsupported input element type: " + << inputElementType; + return rewriter.notifyMatchFailure(op, "unsupported input element type"); + } - LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFn, funcType, - LLVM::Linkage::External); - } + // Insert function declaration if not already present + if (!moduleOp.lookupSymbol(lapackFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); - // Call the LLVM function with enzymexla.jit_call - SmallVector aliases; - aliases.push_back(stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{0}, 2, std::vector{})); - aliases.push_back(stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{1}, 4, std::vector{})); - aliases.push_back(stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{2}, 5, std::vector{})); - - auto unbatchedBLASPivotType = RankedTensorType::get( - unbatchedPivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); - auto blasPivotType = RankedTensorType::get( - pivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); - auto unbatchedBLASInfoType = RankedTensorType::get( - unbatchedInfoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - auto blasInfoType = RankedTensorType::get( - infoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - - auto operandLayouts = - getSHLOLayout(rewriter, SmallVector{0, 0, 2, 0, 1, 0}, - SmallVector(6, true), 2); - auto resultLayouts = - getSHLOLayout(rewriter, SmallVector{2, 1, 0}, - SmallVector(3, true), 2); - - Value factorizedResult, pivotResult, infoResult; - static int64_t fnNum = 0; - std::string wrapperFnName = lapackFn + std::to_string(fnNum++); - - func::FuncOp func = createWrapperFuncOpCPULapack( - rewriter, lapackFn, unbatchedInputType, unbatchedBLASPivotType, - unbatchedBLASInfoType, blasIntType, wrapperFnName, op, operandLayouts, - resultLayouts, rewriter.getArrayAttr(aliases)); - if (!func) - return rewriter.notifyMatchFailure(op, - "failed to create wrapper function"); - - SmallVector batchOps; - SmallVector batchFunctions; - - if (numBatchDims > 0) { - // TODO: Implement batched LU factorizations by directly calling MKL - // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-0/getrf-batch-strided.html. 
- SmallVector batchShape(inputShape.begin(), - inputShape.begin() + numBatchDims); - - auto batchOp = rewriter.create( - op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, - mlir::FlatSymbolRefAttr::get(op.getContext(), wrapperFnName), - ValueRange{input}, rewriter.getDenseI64ArrayAttr(batchShape)); - - factorizedResult = batchOp.getResult(0); - pivotResult = batchOp.getResult(1); - infoResult = batchOp.getResult(2); - - batchOps.push_back(batchOp); - batchFunctions.push_back( - cast(func.getOperation())); - } else { - auto callOp = - rewriter.create(op.getLoc(), func, ValueRange{input}); - - factorizedResult = callOp.getResult(0); - pivotResult = callOp.getResult(1); - infoResult = callOp.getResult(2); - } + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidPtrType, + {llvmPtrType, llvmPtrType, llvmPtrType, + llvmPtrType, llvmPtrType, llvmPtrType}, + false); - auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); - auto iter = rewriter.create( - op.getLoc(), iterType, cast(makeAttr(iterType, 0))); - auto zeroConst = rewriter.create( - op.getLoc(), iterType, cast(makeAttr(iterType, 0))); + rewriter.create(op.getLoc(), lapackFn, funcType, + LLVM::Linkage::External); + } - auto pivots0indexed = stablehlo::SubtractOp::create( - rewriter, op.getLoc(), pivotResult, - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, 1)))); + // Call the LLVM function with enzymexla.jit_call + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{0}, 2, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{1}, 4, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{2}, 5, std::vector{})); - auto permutation = stablehlo::IotaOp::create( - rewriter, op.getLoc(), blasPivotType, - rewriter.getI64IntegerAttr(blasPivotType.getRank() - 1)); + auto unbatchedBLASPivotType = RankedTensorType::get( + unbatchedPivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto blasPivotType = RankedTensorType::get( + pivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto unbatchedBLASInfoType = RankedTensorType::get( + unbatchedInfoType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto blasInfoType = RankedTensorType::get( + infoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - auto pivotToPermReturnTypes = {iterType, blasPivotType}; - auto pivotToPermWhileOp = stablehlo::WhileOp::create( - rewriter, op.getLoc(), TypeRange{iterType, blasPivotType}, - ValueRange{iter, permutation}); + auto operandLayouts = + getSHLOLayout(rewriter, SmallVector{0, 0, 2, 0, 1, 0}, + SmallVector(6, true), 2); + auto resultLayouts = + getSHLOLayout(rewriter, SmallVector{2, 1, 0}, + SmallVector(3, true), 2); + + Value factorizedResult, pivotResult, infoResult; + static int64_t fnNum = 0; + std::string wrapperFnName = lapackFn + std::to_string(fnNum++); + + func::FuncOp func = createWrapperFuncOpCPULapack( + rewriter, lapackFn, unbatchedInputType, unbatchedBLASPivotType, + unbatchedBLASInfoType, blasIntType, wrapperFnName, op, operandLayouts, + resultLayouts, rewriter.getArrayAttr(aliases)); + if (!func) + return rewriter.notifyMatchFailure(op, + "failed to create wrapper function"); - { - OpBuilder::InsertionGuard guard(rewriter); + SmallVector batchOps; + SmallVector batchFunctions; - Block *block = rewriter.createBlock(&pivotToPermWhileOp.getCond()); - rewriter.setInsertionPointToStart(block); + if 
(numBatchDims > 0) { + // TODO: Implement batched LU factorizations by directly calling MKL + // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-0/getrf-batch-strided.html. + SmallVector batchShape(inputShape.begin(), + inputShape.begin() + numBatchDims); + + auto batchOp = rewriter.create( + op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(op.getContext(), wrapperFnName), + ValueRange{input}, rewriter.getDenseI64ArrayAttr(batchShape)); + + factorizedResult = batchOp.getResult(0); + pivotResult = batchOp.getResult(1); + infoResult = batchOp.getResult(2); + + batchOps.push_back(batchOp); + batchFunctions.push_back(cast(func.getOperation())); + } else { + auto callOp = + rewriter.create(op.getLoc(), func, ValueRange{input}); - for (auto type : pivotToPermReturnTypes) - block->addArgument(type, pivotToPermWhileOp.getLoc()); + factorizedResult = callOp.getResult(0); + pivotResult = callOp.getResult(1); + infoResult = callOp.getResult(2); + } - auto pivotShapeConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr( - iterType, pivotType.getShape()[pivotType.getRank() - 1]))); + auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); + auto iter = rewriter.create( + op.getLoc(), iterType, cast(makeAttr(iterType, 0))); + auto zeroConst = rewriter.create( + op.getLoc(), iterType, cast(makeAttr(iterType, 0))); - auto comparison = stablehlo::CompareOp::create( - rewriter, op.getLoc(), block->getArgument(0), pivotShapeConst, - stablehlo::ComparisonDirection::LT); + auto pivots0indexed = rewriter.create( + op.getLoc(), pivotResult, + rewriter.create( + op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, 1)))); - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{comparison.getResult()}); + auto permutation = rewriter.create( + op.getLoc(), blasPivotType, + rewriter.getI64IntegerAttr(blasPivotType.getRank() - 1)); + + auto pivotToPermReturnTypes = {iterType, blasPivotType}; + auto pivotToPermWhileOp = rewriter.create( + op.getLoc(), TypeRange{iterType, blasPivotType}, + ValueRange{iter, permutation}); + + { + OpBuilder::InsertionGuard guard(rewriter); + + Block *block = rewriter.createBlock(&pivotToPermWhileOp.getCond()); + rewriter.setInsertionPointToStart(block); + + for (auto type : pivotToPermReturnTypes) + block->addArgument(type, pivotToPermWhileOp.getLoc()); + + auto pivotShapeConst = rewriter.create( + op.getLoc(), iterType, + cast(makeAttr( + iterType, pivotType.getShape()[pivotType.getRank() - 1]))); + + auto comparison = rewriter.create( + op.getLoc(), block->getArgument(0), pivotShapeConst, + stablehlo::ComparisonDirection::LT); + + rewriter.create(op.getLoc(), + ValueRange{comparison.getResult()}); + } + + { + OpBuilder::InsertionGuard guard(rewriter); + + Block *block = rewriter.createBlock(&pivotToPermWhileOp.getBody()); + rewriter.setInsertionPointToStart(block); + + for (auto type : pivotToPermReturnTypes) + block->addArgument(type, pivotToPermWhileOp.getLoc()); + + auto iterArg = block->getArgument(0); + + auto updatedIter = rewriter.create( + op.getLoc(), iterArg, + rewriter.create( + op.getLoc(), iterType, + cast(makeAttr(iterType, 1)))); + + /* + for i in range(pivot.shape[-1]): + j = pivot[..., i] # dynamic slice + x = permutation[..., i] # dynamic slice + y = permutation[j] # gather + permutation[..., i] = y # dynamic update slice + permutation[j] = x # scatter + */ + + SmallVector indices; + SmallVector sliceShape, batchDims; + for 
(int i = 0; i < numBatchDims; i++) { + indices.push_back(zeroConst); + sliceShape.push_back(pivotType.getShape()[i]); + batchDims.push_back(i); } + indices.push_back(iterArg); + sliceShape.push_back(1); + SmallVector gatherSliceSizes(numBatchDims + 1, 1); + + auto pivotJ = rewriter.create( + op.getLoc(), pivots0indexed, indices, sliceShape); + auto permutationX = rewriter.create( + op.getLoc(), block->getArgument(1), indices, sliceShape); + + auto gatherDims = stablehlo::GatherDimensionNumbersAttr::get( + op.getContext(), + /*offsetDims=*/{numBatchDims}, + /*collapsedSliceDims=*/{}, + /*operandBatchingDims=*/batchDims, + /*startIndicesBatchingDims=*/batchDims, + /*startIndexMap=*/{numBatchDims}, + /*indexVectorDim=*/numBatchDims); + auto permutationY = rewriter.create( + op.getLoc(), + RankedTensorType::get(sliceShape, cast( + block->getArgument(1).getType()) + .getElementType()), + block->getArgument(1), pivotJ.getResult(), gatherDims, + gatherSliceSizes); + + auto permutationUpdate1 = + rewriter.create( + op.getLoc(), block->getArgument(1), permutationY->getResult(0), + indices); + + auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get( + op.getContext(), + /*updateWindowDims=*/{}, + /*insertedWindowDims=*/{numBatchDims}, + /*inputBatchingDims=*/batchDims, + /*scatterIndicesBatchingDims=*/batchDims, + /*scatterDimsToOperandDims=*/{numBatchDims}, + /*indexVectorDim=*/numBatchDims); + SmallVector scatterShape(sliceShape.begin(), + sliceShape.end() - 1); + auto permutationUpdate2 = rewriter.create( + op.getLoc(), TypeRange{permutationUpdate1->getResult(0).getType()}, + ValueRange(permutationUpdate1->getResult(0)), pivotJ, + ValueRange(rewriter.create( + op.getLoc(), + RankedTensorType::get(scatterShape, + permutationX.getType().getElementType()), + permutationX)), + scatterDims); { OpBuilder::InsertionGuard guard(rewriter); - - Block *block = rewriter.createBlock(&pivotToPermWhileOp.getBody()); + auto *block = + rewriter.createBlock(&permutationUpdate2.getUpdateComputation()); + block->addArgument(RankedTensorType::get({}, blasIntType), op.getLoc()); + block->addArgument(RankedTensorType::get({}, blasIntType), op.getLoc()); rewriter.setInsertionPointToStart(block); - for (auto type : pivotToPermReturnTypes) - block->addArgument(type, pivotToPermWhileOp.getLoc()); - - auto iterArg = block->getArgument(0); - - auto updatedIter = stablehlo::AddOp::create( - rewriter, op.getLoc(), iterArg, - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 1)))); - - /* - for i in range(pivot.shape[-1]): - j = pivot[..., i] # dynamic slice - x = permutation[..., i] # dynamic slice - y = permutation[j] # gather - permutation[..., i] = y # dynamic update slice - permutation[j] = x # scatter - */ - - SmallVector indices; - SmallVector sliceShape, batchDims; - for (int i = 0; i < numBatchDims; i++) { - indices.push_back(zeroConst); - sliceShape.push_back(pivotType.getShape()[i]); - batchDims.push_back(i); - } - indices.push_back(iterArg); - sliceShape.push_back(1); - SmallVector gatherSliceSizes(numBatchDims + 1, 1); - - auto pivotJ = stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), pivots0indexed, indices, sliceShape); - auto permutationX = stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), indices, sliceShape); - - auto gatherDims = stablehlo::GatherDimensionNumbersAttr::get( - op.getContext(), - /*offsetDims=*/{numBatchDims}, - /*collapsedSliceDims=*/{}, - /*operandBatchingDims=*/batchDims, - 
/*startIndicesBatchingDims=*/batchDims, - /*startIndexMap=*/{numBatchDims}, - /*indexVectorDim=*/numBatchDims); - auto permutationY = stablehlo::GatherOp::create( - rewriter, op.getLoc(), - RankedTensorType::get( - sliceShape, - cast(block->getArgument(1).getType()) - .getElementType()), - block->getArgument(1), pivotJ.getResult(), gatherDims, - gatherSliceSizes); - - auto permutationUpdate1 = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - permutationY->getResult(0), indices); - - auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get( - op.getContext(), - /*updateWindowDims=*/{}, - /*insertedWindowDims=*/{numBatchDims}, - /*inputBatchingDims=*/batchDims, - /*scatterIndicesBatchingDims=*/batchDims, - /*scatterDimsToOperandDims=*/{numBatchDims}, - /*indexVectorDim=*/numBatchDims); - SmallVector scatterShape(sliceShape.begin(), - sliceShape.end() - 1); - auto permutationUpdate2 = stablehlo::ScatterOp::create( - rewriter, op.getLoc(), - TypeRange{permutationUpdate1->getResult(0).getType()}, - ValueRange(permutationUpdate1->getResult(0)), pivotJ, - ValueRange(stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get(scatterShape, - permutationX.getType().getElementType()), - permutationX)), - scatterDims); - - { - OpBuilder::InsertionGuard guard(rewriter); - auto *block = - rewriter.createBlock(&permutationUpdate2.getUpdateComputation()); - block->addArgument(RankedTensorType::get({}, blasIntType), - op.getLoc()); - block->addArgument(RankedTensorType::get({}, blasIntType), - op.getLoc()); - rewriter.setInsertionPointToStart(block); - - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{block->getArgument(1)}); - } - - stablehlo::ReturnOp::create( - rewriter, op.getLoc(), - ValueRange{updatedIter, permutationUpdate2->getResult(0)}); + rewriter.create(op.getLoc(), + ValueRange{block->getArgument(1)}); } - auto finalPermutation = stablehlo::AddOp::create( - rewriter, op.getLoc(), pivotToPermWhileOp.getResult(1), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, 1)))); + rewriter.create( + op.getLoc(), + ValueRange{updatedIter, permutationUpdate2->getResult(0)}); + } - rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); - rewriter.replaceAllUsesWith( - op.getResult(2), rewriter.create( - op.getLoc(), pivotType, finalPermutation)); - rewriter.replaceAllUsesWith(op.getResult(3), - rewriter.create( - op.getLoc(), infoType, infoResult)); - - std::map - batchedFunctionCache; - for (auto [batchOp, func] : llvm::zip(batchOps, batchFunctions)) { - if (failed(enzyme::batchutils::batchOperation(rewriter, batchOp, func, - batchedFunctionCache))) { - return rewriter.notifyMatchFailure(op, "failed to batch operation"); - } + auto finalPermutation = rewriter.create( + op.getLoc(), pivotToPermWhileOp.getResult(1), + rewriter.create( + op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, 1)))); + + rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); + rewriter.replaceAllUsesWith(op.getResult(1), + rewriter.create( + op.getLoc(), pivotType, pivotResult)); + rewriter.replaceAllUsesWith(op.getResult(2), + rewriter.create( + op.getLoc(), pivotType, finalPermutation)); + rewriter.replaceAllUsesWith(op.getResult(3), + rewriter.create( + op.getLoc(), infoType, infoResult)); + + std::map + batchedFunctionCache; + for (auto [batchOp, func] : llvm::zip(batchOps, batchFunctions)) { + if (failed(enzyme::batchutils::batchOperation(rewriter, batchOp, func, + 
batchedFunctionCache))) { + return rewriter.notifyMatchFailure(op, "failed to batch operation"); } + } - return success(); - } else if (backend == "cuda") { - SmallVector aliases = {stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{0}, 0, std::vector{})}; - - SmallVector isColMajorArrOperands = {true}; - SmallVector operandRanks = {inputRank}; - SmallVector isColMajorArrOutputs = {true, true, true}; - SmallVector outputRanks = {inputRank, pivotRank, infoRank}; - - auto pivotCuSolverType = - RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); - auto infoCuSolverType = - RankedTensorType::get(infoType.getShape(), rewriter.getI32Type()); - - auto cusolverffi = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, - ValueRange{input}, rewriter.getStringAttr("cusolver_getrf_ffi"), + return success(); + } + + LogicalResult matchAndRewriteCUDA(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { + auto ctx = op->getContext(); + + auto input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + auto inputRank = inputType.getRank(); + auto numBatchDims = inputRank - 2; + + auto pivotType = cast(op.getResult(1).getType()); + auto pivotRank = pivotType.getRank(); + auto infoType = cast(op.getResult(3).getType()); + auto infoRank = infoType.getRank(); + + SmallVector aliases = {stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{0}, 0, std::vector{})}; + + SmallVector isColMajorArrOperands = {true}; + SmallVector operandRanks = {inputRank}; + SmallVector isColMajorArrOutputs = {true, true, true}; + SmallVector outputRanks = {inputRank, pivotRank, infoRank}; + + auto pivotCuSolverType = + RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); + auto infoCuSolverType = + RankedTensorType::get(infoType.getShape(), rewriter.getI32Type()); + + auto cusolverffi = rewriter.create( + op.getLoc(), TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, + ValueRange{input}, rewriter.getStringAttr("cusolver_getrf_ffi"), + /*has_side_effect*/ nullptr, + /*backend_config*/ nullptr, + /*api_version*/ + stablehlo::CustomCallApiVersionAttr::get( + rewriter.getContext(), + mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), + /*calledcomputations*/ nullptr, + /*operand_layouts*/ + getSHLOLayout(rewriter, operandRanks, isColMajorArrOperands, inputRank), + /*result_layouts*/ + getSHLOLayout(rewriter, outputRanks, isColMajorArrOutputs, inputRank), + /*output_operand_aliases*/ rewriter.getArrayAttr(aliases)); + + // unused custom call not getting optimized away. 
so adding a manual check + if (!op.getResult(2).getUses().empty()) { + auto pivots0indexed = rewriter.create( + op.getLoc(), cusolverffi.getResult(1), + rewriter.create( + op.getLoc(), pivotCuSolverType, + cast(makeAttr(pivotCuSolverType, 1)))); + + SmallVector outputRanksPermutation = {pivotRank}; + + auto permutation = rewriter.create( + op.getLoc(), TypeRange{pivotCuSolverType}, + ValueRange{pivots0indexed.getResult()}, + rewriter.getStringAttr("cu_lu_pivots_to_permutation"), /*has_side_effect*/ nullptr, /*backend_config*/ nullptr, /*api_version*/ @@ -369,229 +479,136 @@ struct LUFactorizationOpLowering mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), /*calledcomputations*/ nullptr, /*operand_layouts*/ - getSHLOLayout(rewriter, operandRanks, isColMajorArrOperands, - inputRank), + getSHLOLayout(rewriter, SmallVector{pivotRank}, + SmallVector{true}, inputRank), /*result_layouts*/ - getSHLOLayout(rewriter, outputRanks, isColMajorArrOutputs, inputRank), - /*output_operand_aliases*/ rewriter.getArrayAttr(aliases)); - - // unused custom call not getting optimized away. so adding a manual check - if (!op.getResult(2).getUses().empty()) { - auto pivots0indexed = stablehlo::SubtractOp::create( - rewriter, op.getLoc(), cusolverffi.getResult(1), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotCuSolverType, - cast(makeAttr(pivotCuSolverType, 1)))); - - SmallVector isColMajorArrOperandsPermutation = {true}; - SmallVector operandRanksPermutation = {pivotRank}; - SmallVector isColMajorArrOutputsPermutation = {true}; - SmallVector outputRanksPermutation = {pivotRank}; - - auto permutation = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), TypeRange{pivotCuSolverType}, - ValueRange{pivots0indexed.getResult()}, - rewriter.getStringAttr("cu_lu_pivots_to_permutation"), - /*has_side_effect*/ nullptr, - /*backend_config*/ nullptr, - /*api_version*/ - stablehlo::CustomCallApiVersionAttr::get( - rewriter.getContext(), - mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), - /*calledcomputations*/ nullptr, - /*operand_layouts*/ - getSHLOLayout(rewriter, operandRanksPermutation, - isColMajorArrOperandsPermutation, inputRank), - /*result_layouts*/ - getSHLOLayout(rewriter, outputRanksPermutation, - isColMajorArrOutputsPermutation, inputRank), - /*output_operand_aliases*/ nullptr); - auto permutation1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotCuSolverType, - cast(makeAttr(pivotCuSolverType, 1))), - permutation.getResult(0)); - - rewriter.replaceAllUsesWith( - op.getResult(2), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - permutation1Indexed)); - } - - rewriter.replaceAllUsesWith(op.getResult(0), cusolverffi.getResult(0)); - rewriter.replaceAllUsesWith( - op.getResult(1), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - cusolverffi.getResult(1))); - rewriter.replaceAllUsesWith( - op.getResult(3), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), infoType, - cusolverffi.getResult(2))); - - return success(); - } else if (backend == "tpu") { - SmallVector permutationShape; - for (int i = 0; i < numBatchDims; i++) { - permutationShape.push_back(inputShape[i]); - } - permutationShape.push_back(m); - auto permutationType = - RankedTensorType::get(permutationShape, rewriter.getI32Type()); - - auto pivotTPUType = - RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); - - // TPU returns (LU, pivots, permutation). info isn't returned. 
based on - // how JAX operates, I am assuming info != 0 when there is a nan in the - // output. - auto customCall = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, pivotTPUType, permutationType}, - ValueRange{input}, rewriter.getStringAttr("LuDecomposition"), - /*has_side_effect*/ nullptr, - /*backend_config*/ nullptr, - /*api_version*/ nullptr, - /*calledcomputations*/ nullptr, - /*operand_layouts*/ nullptr, - /*result_layouts*/ nullptr, + getSHLOLayout(rewriter, SmallVector{pivotRank}, + SmallVector{true}, inputRank), /*output_operand_aliases*/ nullptr); + auto permutation1Indexed = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), pivotCuSolverType, + cast(makeAttr(pivotCuSolverType, 1))), + permutation.getResult(0)); - // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots. We - // make it consistent with LAPACK by adding 1 to the pivots. - auto pivots1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotType, - cast(makeAttr(pivotType, 1))), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - customCall.getResult(1))); - - auto permutation1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), permutationType, - cast(makeAttr(permutationType, 1))), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), permutationType, - customCall.getResult(2))); - - auto isFinite = stablehlo::AndOp::create( - rewriter, op.getLoc(), - stablehlo::IsFiniteOp::create( - rewriter, op.getLoc(), - stablehlo::RealOp::create(rewriter, op.getLoc(), - customCall.getResult(0))), - stablehlo::IsFiniteOp::create( - rewriter, op.getLoc(), - stablehlo::ImagOp::create(rewriter, op.getLoc(), - customCall.getResult(0)))); - - SmallVector reductionDims; - for (int i = numBatchDims; i < inputRank; i++) - reductionDims.push_back(i); - auto initValType = RankedTensorType::get({}, rewriter.getI1Type()); - auto initVal = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), initValType, - cast(makeAttr(initValType, 1))); - - auto allFinite = stablehlo::ReduceOp::create( - rewriter, op.getLoc(), - RankedTensorType::get(infoType.getShape(), rewriter.getI1Type()), - ValueRange{isFinite.getResult()}, ValueRange{initVal}, - rewriter.getDenseI64ArrayAttr(reductionDims)); - - { - OpBuilder::InsertionGuard guard(rewriter); - auto ®ion = allFinite.getBody(); - auto *block = - rewriter.createBlock(®ion, {}, {initValType, initValType}, - {op.getLoc(), op.getLoc()}); - - rewriter.setInsertionPointToStart(block); - auto lhs = block->getArgument(0); - auto rhs = block->getArgument(1); - auto andOp = stablehlo::AndOp::create(rewriter, op.getLoc(), lhs, rhs); + rewriter.replaceAllUsesWith( + op.getResult(2), rewriter.create( + op.getLoc(), pivotType, permutation1Indexed)); + } - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{andOp.getResult()}); - } + rewriter.replaceAllUsesWith(op.getResult(0), cusolverffi.getResult(0)); + rewriter.replaceAllUsesWith( + op.getResult(1), rewriter.create( + op.getLoc(), pivotType, cusolverffi.getResult(1))); + rewriter.replaceAllUsesWith( + op.getResult(3), rewriter.create( + op.getLoc(), infoType, cusolverffi.getResult(2))); - // info == 0 if all finite (success) - auto info = stablehlo::ConvertOp::create( - rewriter, op.getLoc(), infoType, - stablehlo::NotOp::create(rewriter, op.getLoc(), - allFinite.getResult(0))); + return success(); + } - 
rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0)); - rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed); - rewriter.replaceAllUsesWith(op.getResult(2), permutation1Indexed); - rewriter.replaceAllUsesWith(op.getResult(3), info); - rewriter.eraseOp(op); + LogicalResult matchAndRewriteTPU(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { + auto input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); - return success(); - } else { - return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); - } - } + auto inputRank = inputType.getRank(); + auto numBatchDims = inputRank - 2; -private: - func::FuncOp createWrapperFuncOpCPULapack( - PatternRewriter &rewriter, const std::string &lapackFn, - RankedTensorType inputType, RankedTensorType blasPivotType, - RankedTensorType blasInfoType, Type blasIntType, - const std::string &fnName, enzymexla::LUFactorizationOp op, - ArrayAttr operandLayouts, ArrayAttr resultLayouts, - ArrayAttr outputOperandAliases) const { - auto ctx = op->getContext(); + auto pivotType = cast(op.getResult(1).getType()); + auto infoType = cast(op.getResult(3).getType()); - OpBuilder::InsertionGuard guard(rewriter); - auto moduleOp = op->getParentOfType(); - if (!moduleOp) - return nullptr; - rewriter.setInsertionPointToStart(moduleOp.getBody()); + SmallVector permutationShape(inputShape.begin(), + inputShape.end() - 2); + permutationShape.push_back(inputShape[0]); + auto permutationType = + RankedTensorType::get(permutationShape, rewriter.getI32Type()); + + auto pivotTPUType = + RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); + + // TPU returns (LU, pivots, permutation). info isn't returned. based on + // how JAX operates, I am assuming info != 0 when there is a nan in the + // output. + auto customCall = rewriter.create( + op.getLoc(), TypeRange{inputType, pivotTPUType, permutationType}, + ValueRange{input}, rewriter.getStringAttr("LuDecomposition"), + /*has_side_effect*/ nullptr, + /*backend_config*/ nullptr, + /*api_version*/ nullptr, + /*calledcomputations*/ nullptr, + /*operand_layouts*/ nullptr, + /*result_layouts*/ nullptr, + /*output_operand_aliases*/ nullptr); - SmallVector argTypes = {inputType}; - SmallVector retTypes = {inputType, blasPivotType, blasInfoType}; + // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots. We + // make it consistent with LAPACK by adding 1 to the pivots. 
+ auto pivots1Indexed = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), pivotType, cast(makeAttr(pivotType, 1))), + rewriter.create(op.getLoc(), pivotType, + customCall.getResult(1))); + + auto permutation1Indexed = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), permutationType, + cast(makeAttr(permutationType, 1))), + rewriter.create(op.getLoc(), permutationType, + customCall.getResult(2))); + + auto isFinite = rewriter.create( + op.getLoc(), + rewriter.create( + op.getLoc(), rewriter.create( + op.getLoc(), customCall.getResult(0))), + rewriter.create( + op.getLoc(), rewriter.create( + op.getLoc(), customCall.getResult(0)))); + + SmallVector reductionDims; + for (int i = numBatchDims; i < inputRank; i++) + reductionDims.push_back(i); + auto initValType = RankedTensorType::get({}, rewriter.getI1Type()); + auto initVal = rewriter.create( + op.getLoc(), initValType, cast(makeAttr(initValType, 1))); + + auto allFinite = rewriter.create( + op.getLoc(), + RankedTensorType::get(infoType.getShape(), rewriter.getI1Type()), + ValueRange{isFinite.getResult()}, ValueRange{initVal}, + rewriter.getDenseI64ArrayAttr(reductionDims)); - FunctionType calleeType = rewriter.getFunctionType(argTypes, retTypes); - func::FuncOp func = - rewriter.create(op.getLoc(), fnName, calleeType); - func.setPrivate(); + { + OpBuilder::InsertionGuard guard(rewriter); + auto ®ion = allFinite.getBody(); + auto *block = rewriter.createBlock( + ®ion, {}, {initValType, initValType}, {op.getLoc(), op.getLoc()}); - auto &entryBlock = *func.addEntryBlock(); - rewriter.setInsertionPointToStart(&entryBlock); + rewriter.setInsertionPointToStart(block); + auto lhs = block->getArgument(0); + auto rhs = block->getArgument(1); + auto andOp = rewriter.create(op.getLoc(), lhs, rhs); - auto input = entryBlock.getArgument(0); - auto mSize = rewriter.create( - op.getLoc(), RankedTensorType::get({}, blasIntType), - rewriter.create(op.getLoc(), input, 0)); - auto nSize = rewriter.create( - op.getLoc(), RankedTensorType::get({}, blasIntType), - rewriter.create(op.getLoc(), input, 1)); - auto pivot = rewriter.create( - op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, -1))); - auto info = rewriter.create( - op.getLoc(), blasInfoType, - cast(makeAttr(blasInfoType, -1))); + rewriter.create(op.getLoc(), + ValueRange{andOp.getResult()}); + } - auto jitCall = rewriter.create( - op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, - mlir::FlatSymbolRefAttr::get(ctx, lapackFn), - ValueRange{mSize, nSize, input, mSize, pivot, info}, - rewriter.getStringAttr(""), - /*operand_layouts=*/operandLayouts, - /*result_layouts=*/resultLayouts, - /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr, - /*output_operand_aliases=*/outputOperandAliases, - /*xla_side_effect_free=*/rewriter.getUnitAttr()); + // info == 0 if all finite (success) + auto info = rewriter.create( + op.getLoc(), infoType, + rewriter.create(op.getLoc(), allFinite.getResult(0))); - rewriter.create( - op.getLoc(), ValueRange{jitCall.getResult(0), jitCall.getResult(1), - jitCall.getResult(2)}); + rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0)); + rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed); + rewriter.replaceAllUsesWith(op.getResult(2), permutation1Indexed); + rewriter.replaceAllUsesWith(op.getResult(3), info); + rewriter.eraseOp(op); - return func; + return success(); } }; @@ -608,14 +625,11 @@ struct SVDFactorizationOpLowering LogicalResult matchAndRewrite(enzymexla::SVDFactorizationOp op, PatternRewriter 
&rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); else if (backend == "cuda") - return this->matchAndRewrite_cuda(op, rewriter); - + return matchAndRewriteCUDA(op, rewriter); else if (backend == "tpu") - return this->matchAndRewrite_tpu(op, rewriter); - + return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -624,8 +638,8 @@ struct SVDFactorizationOpLowering // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance // TODO support more SVD algorithms (e.g. `gesdd`, `gesvj`) - LogicalResult matchAndRewrite_cpu(enzymexla::SVDFactorizationOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::SVDFactorizationOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -838,10 +852,9 @@ struct SVDFactorizationOpLowering return success(); } - LogicalResult matchAndRewrite_cuda(enzymexla::SVDFactorizationOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCUDA(enzymexla::SVDFactorizationOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); - LLVMTypeConverter typeConverter(ctx); auto type_lapack_int = rewriter.getIntegerType(blasIntWidth); @@ -918,14 +931,12 @@ struct SVDFactorizationOpLowering } // TODO find registered TPU kernel - LogicalResult matchAndRewrite_tpu(enzymexla::SVDFactorizationOp op, - PatternRewriter &rewriter) const { - + LogicalResult matchAndRewriteTPU(enzymexla::SVDFactorizationOp op, + PatternRewriter &rewriter) const { return rewriter.notifyMatchFailure( op, "We don't know yet to which SVD TPU kernel to lower to :_("); auto ctx = op->getContext(); - LLVMTypeConverter typeConverter(ctx); auto input = op.getOperand(); auto type_input = cast(input.getType()); From bcb7d563b0db3f7935889ed739e8c07d4ceb584f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 20:41:55 -0500 Subject: [PATCH 06/12] test: update LU tests --- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 2 +- test/lit_tests/linalg/lu.mlir | 37 +++--- test/lit_tests/linalg/lu_batched.mlir | 121 +++++++++--------- 3 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index fafa7ea18..3e5c65200 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -522,7 +522,7 @@ struct LUFactorizationOpLowering SmallVector permutationShape(inputShape.begin(), inputShape.end() - 2); - permutationShape.push_back(inputShape[0]); + permutationShape.push_back(inputShape[inputRank - 2]); auto permutationType = RankedTensorType::get(permutationShape, rewriter.getI32Type()); diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir index 9ce775d76..c31183154 100644 --- a/test/lit_tests/linalg/lu.mlir +++ b/test/lit_tests/linalg/lu.mlir @@ -9,37 +9,32 @@ module { } } -// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(64 : i64) : i64 -// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64 -// CPU-NEXT: %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr -// CPU-NEXT: %3 = 
llvm.alloca %1 x i64 : (i64) -> !llvm.ptr -// CPU-NEXT: llvm.store %0, %2 : i64, !llvm.ptr -// CPU-NEXT: llvm.store %0, %3 : i64, !llvm.ptr -// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%2, %3, %arg0, %2, %arg1, %arg2) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () -// CPU-NEXT: llvm.return +// CPU: func.func private @enzymexla_lapack_sgetrf_[[WRAPPER_ID:[0-9]+]](%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) { +// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor +// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64> +// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: stablehlo.return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor // CPU-NEXT: } +// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> // CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor // CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor // CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<64xi64> // CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor -// CPU-NEXT: %c_4 = stablehlo.constant dense<-1> : tensor<64xi64> -// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID]] (%arg0, %c_4, %c_5) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: %0:3 = call @enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0) : (tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) // CPU-NEXT: %1 = stablehlo.subtract %0#1, %c_2 : tensor<64xi64> -// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_6 = %c) : tensor, tensor<64xi64> -// CPU-NEXT: cond { +// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_4 = %c) : tensor, tensor<64xi64> +// CPU-NEXT: cond { // CPU-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor // CPU-NEXT: stablehlo.return %7 : tensor // CPU-NEXT: } do { // CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 : tensor // CPU-NEXT: 
%8 = stablehlo.dynamic_slice %1, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64> -// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_6, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64> -// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_6, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64> -// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_6, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor) -> tensor<64xi64> +// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64> +// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64> +// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor) -> tensor<64xi64> // CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor // CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CPU-NEXT: ^bb0(%arg1: tensor, %arg2: tensor): @@ -77,27 +72,27 @@ module { // TPU-NEXT: } module { + // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_ // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { - // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_wrapper_[[WRAPPER_ID:[0-9]+]] %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor) return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor } } module { + // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_ // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_wrapper_[[WRAPPER_ID:[0-9]+]] %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor } } module { + // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_ // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_wrapper_[[WRAPPER_ID:[0-9]+]] %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor } diff --git a/test/lit_tests/linalg/lu_batched.mlir b/test/lit_tests/linalg/lu_batched.mlir index 1af0820a8..ce100f3db 100644 --- a/test/lit_tests/linalg/lu_batched.mlir +++ b/test/lit_tests/linalg/lu_batched.mlir @@ -10,70 +10,75 @@ module { } // CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_wrapper_0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { -// CPU-NEXT: %0 = llvm.mlir.constant(64 : i64) : i64 -// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64 -// 
CPU-NEXT: %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr -// CPU-NEXT: %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr -// CPU-NEXT: llvm.store %0, %2 : i64, !llvm.ptr -// CPU-NEXT: llvm.store %0, %3 : i64, !llvm.ptr -// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%2, %3, %arg0, %2, %arg1, %arg2) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () -// CPU-NEXT: llvm.return -// CPU-NEXT: } // CPU-NEXT: func.func @main(%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>) { -// CPU: %c_0 = stablehlo.constant dense<64> : tensor -// CPU-NEXT: %c_1 = stablehlo.constant dense<1> : tensor<4x3x64xi64> -// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor -// CPU-NEXT: %c_3 = stablehlo.constant dense<-1> : tensor -// CPU-NEXT: %c_4 = stablehlo.constant dense<-1> : tensor<64xi64> -// CPU-NEXT: %c_5 = stablehlo.constant dense<12> : tensor -// CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<12xi64> -// CPU-NEXT: %c_7 = stablehlo.constant dense<-1> : tensor<12x64xi64> -// CPU-NEXT: %c_8 = stablehlo.constant dense<0> : tensor -// CPU-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<4x3x64x64xf32>) -> tensor<12x64x64xf32> -// CPU-NEXT: %1:4 = stablehlo.while(%iterArg = %c_8, %iterArg_9 = %0, %iterArg_10 = %c_7, %iterArg_11 = %c_6) : tensor, tensor<12x64x64xf32>, tensor<12x64xi64>, tensor<12xi64> -// CPU-NEXT: cond { -// CPU-NEXT: %11 = stablehlo.compare LT, %iterArg, %c_5 : (tensor, tensor) -> tensor -// CPU-NEXT: stablehlo.return %11 : tensor +// CPU: %c_0 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor +// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<4x3x64xi64> +// CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %0:3 = call @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID:[0-9]+]](%arg0) : (tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) +// CPU-NEXT: %1 = stablehlo.subtract %0#1, %c_2 : tensor<4x3x64xi64> +// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_4 = %c) : tensor, tensor<4x3x64xi64> +// CPU-NEXT: cond { +// CPU-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor +// CPU-NEXT: stablehlo.return %7 : tensor // CPU-NEXT: } do { -// CPU-NEXT: %11 = stablehlo.dynamic_slice %iterArg_9, %iterArg, %c_8, %c_8, sizes = [1, 64, 64] : (tensor<12x64x64xf32>, tensor, tensor, tensor) -> tensor<1x64x64xf32> -// CPU-NEXT: %12 = stablehlo.reshape %11 : (tensor<1x64x64xf32>) -> tensor<64x64xf32> -// CPU-NEXT: %13:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper_0 (%12, %c_4, %c_3) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) -// CPU-NEXT: %14 = stablehlo.reshape %13#0 : (tensor<64x64xf32>) -> tensor<1x64x64xf32> -// CPU-NEXT: %15 = stablehlo.dynamic_update_slice %iterArg_9, %14, %iterArg, %c_8, %c_8 : (tensor<12x64x64xf32>, tensor<1x64x64xf32>, tensor, tensor, tensor) -> tensor<12x64x64xf32> -// CPU-NEXT: %16 = stablehlo.reshape %13#1 : (tensor<64xi64>) -> tensor<1x64xi64> -// CPU-NEXT: %17 = stablehlo.dynamic_update_slice %iterArg_10, %16, %iterArg, %c_8 : 
(tensor<12x64xi64>, tensor<1x64xi64>, tensor, tensor) -> tensor<12x64xi64> -// CPU-NEXT: %18 = stablehlo.reshape %13#2 : (tensor) -> tensor<1xi64> -// CPU-NEXT: %19 = stablehlo.dynamic_update_slice %iterArg_11, %18, %iterArg : (tensor<12xi64>, tensor<1xi64>, tensor) -> tensor<12xi64> -// CPU-NEXT: %20 = stablehlo.add %iterArg, %c_2 : tensor -// CPU-NEXT: stablehlo.return %20, %15, %17, %19 : tensor, tensor<12x64x64xf32>, tensor<12x64xi64>, tensor<12xi64> -// CPU-NEXT: } -// CPU-NEXT: %2 = stablehlo.reshape %1#1 : (tensor<12x64x64xf32>) -> tensor<4x3x64x64xf32> -// CPU-NEXT: %3 = stablehlo.reshape %1#2 : (tensor<12x64xi64>) -> tensor<4x3x64xi64> -// CPU-NEXT: %4 = stablehlo.convert %1#3 : (tensor<12xi64>) -> tensor<12xi32> -// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<12xi32>) -> tensor<4x3xi32> -// CPU-NEXT: %6 = stablehlo.subtract %3, %c_1 : tensor<4x3x64xi64> -// CPU-NEXT: %7:2 = stablehlo.while(%iterArg = %c_8, %iterArg_9 = %c) : tensor, tensor<4x3x64xi64> -// CPU-NEXT: cond { -// CPU-NEXT: %11 = stablehlo.compare LT, %iterArg, %c_0 : (tensor, tensor) -> tensor -// CPU-NEXT: stablehlo.return %11 : tensor -// CPU-NEXT: } do { -// CPU-NEXT: %11 = stablehlo.add %iterArg, %c_2 : tensor -// CPU-NEXT: %12 = stablehlo.dynamic_slice %6, %c_8, %c_8, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64> -// CPU-NEXT: %13 = stablehlo.dynamic_slice %iterArg_9, %c_8, %c_8, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64> -// CPU-NEXT: %14 = "stablehlo.gather"(%iterArg_9, %12) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<4x3x64xi64>, tensor<4x3x1xi64>) -> tensor<4x3x1xi64> -// CPU-NEXT: %15 = stablehlo.dynamic_update_slice %iterArg_9, %14, %c_8, %c_8, %iterArg : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64> -// CPU-NEXT: %16 = stablehlo.reshape %13 : (tensor<4x3x1xi64>) -> tensor<4x3xi64> -// CPU-NEXT: %17 = "stablehlo.scatter"(%15, %12, %16) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ +// CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 : tensor +// CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %c_3, %c_3, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64> +// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %c_3, %c_3, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64> +// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<4x3x64xi64>, tensor<4x3x1xi64>) -> tensor<4x3x1xi64> +// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %c_3, %c_3, %iterArg : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64> +// CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<4x3x1xi64>) -> tensor<4x3xi64> +// CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CPU-NEXT: ^bb0(%arg1: tensor, %arg2: tensor): // CPU-NEXT: stablehlo.return %arg2 : tensor // CPU-NEXT: }) : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor<4x3xi64>) -> tensor<4x3x64xi64> -// CPU-NEXT: stablehlo.return %11, %17 : tensor, tensor<4x3x64xi64> +// CPU-NEXT: stablehlo.return %7, %13 : tensor, tensor<4x3x64xi64> +// CPU-NEXT: } +// CPU-NEXT: %3 = stablehlo.add %2#1, %c_2 : 
tensor<4x3x64xi64> +// CPU-NEXT: %4 = stablehlo.convert %0#1 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32> +// CPU-NEXT: %5 = stablehlo.convert %3 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32> +// CPU-NEXT: %6 = stablehlo.convert %0#2 : (tensor<4x3xi64>) -> tensor<4x3xi32> +// CPU-NEXT: return %0#0, %4, %5, %6 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32> +// CPU-NEXT: } +// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) { +// CPU-NEXT: %c = stablehlo.constant dense<4> : tensor +// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %c_1 = stablehlo.constant dense<12> : tensor +// CPU-NEXT: %cst = arith.constant dense<0> : tensor<4x3xi64> +// CPU-NEXT: %cst_2 = arith.constant dense<0> : tensor<4x3x64xi64> +// CPU-NEXT: %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32> +// CPU-NEXT: %c_4 = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor<4x3xi64> +// CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<4x3x64xi64> +// CPU-NEXT: %c_7 = stablehlo.constant dense<64> : tensor<4x3xi64> +// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_4, %iterArg_8 = %cst_3, %iterArg_9 = %cst_2, %iterArg_10 = %cst) : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> +// CPU-NEXT: cond { +// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor +// CPU-NEXT: stablehlo.return %1 : tensor +// CPU-NEXT: } do { +// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_0 : tensor +// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c : tensor +// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c : tensor +// CPU-NEXT: %4 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> +// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1xi64>) -> tensor +// CPU-NEXT: %6 = stablehlo.dynamic_slice %arg0, %2, %3, %c_4, %c_4, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x1x64x64xf32> +// CPU-NEXT: %7 = stablehlo.reshape %6 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32> +// CPU-NEXT: %8 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> +// CPU-NEXT: %9 = stablehlo.reshape %8 : (tensor<1x1xi64>) -> tensor +// CPU-NEXT: %10 = stablehlo.dynamic_slice %c_6, %2, %3, %c_4, sizes = [1, 1, 64] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<1x1x64xi64> +// CPU-NEXT: %11 = stablehlo.reshape %10 : (tensor<1x1x64xi64>) -> tensor<64xi64> +// CPU-NEXT: %12 = stablehlo.dynamic_slice %c_5, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> +// CPU-NEXT: %13 = stablehlo.reshape %12 : (tensor<1x1xi64>) -> tensor +// CPU-NEXT: %14:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%9, %5, %7, %9, %11, %13) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: %15 = stablehlo.reshape 
%14#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32> +// CPU-NEXT: %16 = stablehlo.dynamic_update_slice %iterArg_8, %15, %2, %3, %c_4, %c_4 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<4x3x64x64xf32> +// CPU-NEXT: %17 = stablehlo.reshape %14#1 : (tensor<64xi64>) -> tensor<1x1x64xi64> +// CPU-NEXT: %18 = stablehlo.dynamic_update_slice %iterArg_9, %17, %2, %3, %c_4 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64> +// CPU-NEXT: %19 = stablehlo.reshape %14#2 : (tensor) -> tensor<1x1xi64> +// CPU-NEXT: %20 = stablehlo.dynamic_update_slice %iterArg_10, %19, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor, tensor) -> tensor<4x3xi64> +// CPU-NEXT: stablehlo.return %1, %16, %18, %20 : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: } -// CPU-NEXT: %8 = stablehlo.add %7#1, %c_1 : tensor<4x3x64xi64> -// CPU-NEXT: %9 = stablehlo.convert %3 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32> -// CPU-NEXT: %10 = stablehlo.convert %8 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32> -// CPU-NEXT: return %2, %9, %10, %5 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32> +// CPU-NEXT: stablehlo.return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: } From af0d25de3649aa304b157be16df083abc99ea678 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 21:08:14 -0500 Subject: [PATCH 07/12] feat: dynamic slice simplify --- test/lit_tests/linalg/lu_batched.mlir | 58 ++++++++++++--------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/test/lit_tests/linalg/lu_batched.mlir b/test/lit_tests/linalg/lu_batched.mlir index ce100f3db..ac519f0fe 100644 --- a/test/lit_tests/linalg/lu_batched.mlir +++ b/test/lit_tests/linalg/lu_batched.mlir @@ -40,43 +40,35 @@ module { // CPU-NEXT: %6 = stablehlo.convert %0#2 : (tensor<4x3xi64>) -> tensor<4x3xi32> // CPU-NEXT: return %0#0, %4, %5, %6 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32> // CPU-NEXT: } -// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) { -// CPU-NEXT: %c = stablehlo.constant dense<4> : tensor -// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor -// CPU-NEXT: %c_1 = stablehlo.constant dense<12> : tensor +// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) { +// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor +// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64> +// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor +// CPU-NEXT: %c_2 = stablehlo.constant dense<4> : tensor +// CPU-NEXT: %c_3 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %c_4 = stablehlo.constant dense<12> : tensor // CPU-NEXT: %cst = arith.constant dense<0> : tensor<4x3xi64> -// CPU-NEXT: %cst_2 = arith.constant dense<0> : tensor<4x3x64xi64> -// CPU-NEXT: %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32> -// CPU-NEXT: %c_4 = stablehlo.constant dense<0> : tensor -// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor<4x3xi64> -// CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<4x3x64xi64> -// CPU-NEXT: %c_7 = stablehlo.constant dense<64> : tensor<4x3xi64> -// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_4, %iterArg_8 = %cst_3, %iterArg_9 = %cst_2, 
%iterArg_10 = %cst) : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> +// CPU-NEXT: %cst_5 = arith.constant dense<0> : tensor<4x3x64xi64> +// CPU-NEXT: %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32> +// CPU-NEXT: %c_7 = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_7, %iterArg_8 = %cst_6, %iterArg_9 = %cst_5, %iterArg_10 = %cst) : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: cond { -// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor +// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor, tensor) -> tensor // CPU-NEXT: stablehlo.return %1 : tensor // CPU-NEXT: } do { -// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_0 : tensor -// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c : tensor -// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c : tensor -// CPU-NEXT: %4 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> -// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1xi64>) -> tensor -// CPU-NEXT: %6 = stablehlo.dynamic_slice %arg0, %2, %3, %c_4, %c_4, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x1x64x64xf32> -// CPU-NEXT: %7 = stablehlo.reshape %6 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32> -// CPU-NEXT: %8 = stablehlo.dynamic_slice %c_7, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> -// CPU-NEXT: %9 = stablehlo.reshape %8 : (tensor<1x1xi64>) -> tensor -// CPU-NEXT: %10 = stablehlo.dynamic_slice %c_6, %2, %3, %c_4, sizes = [1, 1, 64] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<1x1x64xi64> -// CPU-NEXT: %11 = stablehlo.reshape %10 : (tensor<1x1x64xi64>) -> tensor<64xi64> -// CPU-NEXT: %12 = stablehlo.dynamic_slice %c_5, %2, %3, sizes = [1, 1] : (tensor<4x3xi64>, tensor, tensor) -> tensor<1x1xi64> -// CPU-NEXT: %13 = stablehlo.reshape %12 : (tensor<1x1xi64>) -> tensor -// CPU-NEXT: %14:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%9, %5, %7, %9, %11, %13) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) -// CPU-NEXT: %15 = stablehlo.reshape %14#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32> -// CPU-NEXT: %16 = stablehlo.dynamic_update_slice %iterArg_8, %15, %2, %3, %c_4, %c_4 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<4x3x64x64xf32> -// CPU-NEXT: %17 = stablehlo.reshape %14#1 : (tensor<64xi64>) -> tensor<1x1x64xi64> -// CPU-NEXT: %18 = stablehlo.dynamic_update_slice %iterArg_9, %17, %2, %3, %c_4 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64> -// CPU-NEXT: %19 = stablehlo.reshape %14#2 : (tensor) -> tensor<1x1xi64> -// CPU-NEXT: %20 = stablehlo.dynamic_update_slice %iterArg_10, %19, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor, tensor) -> tensor<4x3xi64> -// CPU-NEXT: stablehlo.return %1, %16, %18, %20 : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> +// CPU-NEXT: %1 
= stablehlo.add %iterArg, %c_3 : tensor +// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c_2 : tensor +// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor +// CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x1x64x64xf32> +// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32> +// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32> +// CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<4x3x64x64xf32> +// CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64> +// CPU-NEXT: %10 = stablehlo.dynamic_update_slice %iterArg_9, %9, %2, %3, %c_7 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64> +// CPU-NEXT: %11 = stablehlo.reshape %6#2 : (tensor) -> tensor<1x1xi64> +// CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_10, %11, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor, tensor) -> tensor<4x3xi64> +// CPU-NEXT: stablehlo.return %1, %8, %10, %12 : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: } // CPU-NEXT: stablehlo.return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: } From 4b61728c49367a55c50395d7b00de8378c46ba8c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Oct 2025 21:48:33 -0500 Subject: [PATCH 08/12] feat: mark memory effects --- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 46 ++++++++++++++++++- test/lit_tests/linalg/lu.mlir | 6 ++- test/lit_tests/linalg/lu_batched.mlir | 6 ++- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 3e5c65200..72028fc61 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -159,6 +159,7 @@ struct LUFactorizationOpLowering << inputElementType; return rewriter.notifyMatchFailure(op, "unsupported input element type"); } + std::string lapackFnWrapper = lapackFn + "wrapper"; // Insert function declaration if not already present if (!moduleOp.lookupSymbol(lapackFn)) { @@ -175,6 +176,49 @@ struct LUFactorizationOpLowering LLVM::Linkage::External); } + if (!moduleOp.lookupSymbol(lapackFnWrapper)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidPtrType, + {llvmPtrType, llvmPtrType, llvmPtrType, + llvmPtrType, llvmPtrType, llvmPtrType}, + false); + + auto funcOp = rewriter.create( + op.getLoc(), lapackFnWrapper, funcType, LLVM::Linkage::Private); + 
rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter)); + + funcOp.setArgAttr(0, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(1, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + // 2 is read + write + funcOp.setArgAttr(3, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(4, LLVM::LLVMDialect::getWriteOnlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(5, LLVM::LLVMDialect::getWriteOnlyAttrName(), + rewriter.getUnitAttr()); + for (int i = 0; i < 6; i++) { + funcOp.setArgAttr(i, LLVM::LLVMDialect::getNoFreeAttrName(), + rewriter.getUnitAttr()); + } + + auto callOp = rewriter.create( + op.getLoc(), TypeRange{}, SymbolRefAttr::get(ctx, lapackFn), + ValueRange{ + funcOp.getArgument(0), + funcOp.getArgument(1), + funcOp.getArgument(2), + funcOp.getArgument(3), + funcOp.getArgument(4), + funcOp.getArgument(5), + }); + rewriter.create(op.getLoc(), ValueRange{}); + } + // Call the LLVM function with enzymexla.jit_call SmallVector aliases; aliases.push_back(stablehlo::OutputOperandAliasAttr::get( @@ -205,7 +249,7 @@ struct LUFactorizationOpLowering std::string wrapperFnName = lapackFn + std::to_string(fnNum++); func::FuncOp func = createWrapperFuncOpCPULapack( - rewriter, lapackFn, unbatchedInputType, unbatchedBLASPivotType, + rewriter, lapackFnWrapper, unbatchedInputType, unbatchedBLASPivotType, unbatchedBLASInfoType, blasIntType, wrapperFnName, op, operandLayouts, resultLayouts, rewriter.getArrayAttr(aliases)); if (!func) diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir index c31183154..948d483f7 100644 --- a/test/lit_tests/linalg/lu.mlir +++ b/test/lit_tests/linalg/lu.mlir @@ -13,9 +13,13 @@ module { // CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor // CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64> // CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor -// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) // CPU-NEXT: stablehlo.return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor // CPU-NEXT: } +// CPU-NEXT: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, 
%arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) { +// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () +// CPU-NEXT: llvm.return +// CPU-NEXT: } // CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor) { // CPU-NEXT: %c = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> diff --git a/test/lit_tests/linalg/lu_batched.mlir b/test/lit_tests/linalg/lu_batched.mlir index ac519f0fe..3e706a43a 100644 --- a/test/lit_tests/linalg/lu_batched.mlir +++ b/test/lit_tests/linalg/lu_batched.mlir @@ -9,6 +9,10 @@ module { } } +// CPU: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) { +// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () +// CPU-NEXT: llvm.return +// CPU-NEXT: } // CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) // CPU-NEXT: func.func @main(%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>) { // CPU: %c_0 = stablehlo.constant dense<1> : tensor @@ -61,7 +65,7 @@ module { // CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor // CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x1x64x64xf32> // CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32> -// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) +// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : 
tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) // CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32> // CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<4x3x64x64xf32> // CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64> From 51e772ce0ec17385dde7264491b3690f5a14c527 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 17:06:09 -0600 Subject: [PATCH 09/12] fix: update to new API --- src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp | 11 +- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 310 +++++++++--------- test/lit_tests/linalg/lu.mlir | 48 +-- 3 files changed, 191 insertions(+), 178 deletions(-) diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp index 5287fee54..6c19dd516 100644 --- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp @@ -15,11 +15,12 @@ template <> triton_ext::TritonCallOp ReadOnlyArg::create( PatternRewriter &rewriter, triton_ext::TritonCallOp launchOp, ArrayRef resTys, ArrayAttr outputAliases) const { - return rewriter.create( - launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(), - launchOp.getGridy(), launchOp.getGridz(), launchOp.getClusterx(), - launchOp.getClustery(), launchOp.getClusterz(), launchOp.getInputs(), - launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(), + return triton_ext::TritonCallOp::create( + rewriter, launchOp.getLoc(), resTys, launchOp.getFn(), + launchOp.getGridx(), launchOp.getGridy(), launchOp.getGridz(), + launchOp.getClusterx(), launchOp.getClustery(), launchOp.getClusterz(), + launchOp.getInputs(), launchOp.getBackendConfigAttr(), + launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr, launchOp.getArgAttrsAttr(), launchOp.getResAttrsAttr(), outputAliases, launchOp.getXlaSideEffectFreeAttr()); diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 72028fc61..5a6019f80 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -43,7 +43,6 @@ struct LUFactorizationOpLowering LogicalResult matchAndRewrite(enzymexla::LUFactorizationOp op, PatternRewriter &rewriter) const override { - if (backend == "cpu") { return matchAndRewriteCPU(op, rewriter); } else if (backend == "cuda") { @@ -76,28 +75,29 @@ struct LUFactorizationOpLowering FunctionType calleeType = rewriter.getFunctionType(argTypes, retTypes); func::FuncOp func = - rewriter.create(op.getLoc(), fnName, calleeType); + func::FuncOp::create(rewriter, op.getLoc(), fnName, calleeType); func.setPrivate(); auto &entryBlock = *func.addEntryBlock(); rewriter.setInsertionPointToStart(&entryBlock); auto input = entryBlock.getArgument(0); - auto mSize = rewriter.create( - op.getLoc(), RankedTensorType::get({}, blasIntType), - rewriter.create(op.getLoc(), input, 0)); - auto nSize = rewriter.create( - op.getLoc(), RankedTensorType::get({}, blasIntType), - rewriter.create(op.getLoc(), input, 1)); - auto pivot = rewriter.create( - op.getLoc(), blasPivotType, + auto mSize = stablehlo::ConvertOp::create( + rewriter, op.getLoc(), RankedTensorType::get({}, blasIntType), + stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), input, 0)); + auto nSize = 
stablehlo::ConvertOp::create( + rewriter, op.getLoc(), RankedTensorType::get({}, blasIntType), + stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), input, 1)); + auto pivot = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasPivotType, cast(makeAttr(blasPivotType, -1))); - auto info = rewriter.create( - op.getLoc(), blasInfoType, + auto info = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasInfoType, cast(makeAttr(blasInfoType, -1))); - auto jitCall = rewriter.create( - op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + auto jitCall = enzymexla::JITCallOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, blasPivotType, blasInfoType}, mlir::FlatSymbolRefAttr::get(ctx, lapackFn), ValueRange{mSize, nSize, input, mSize, pivot, info}, rewriter.getStringAttr(""), @@ -108,9 +108,10 @@ struct LUFactorizationOpLowering /*output_operand_aliases=*/outputOperandAliases, /*xla_side_effect_free=*/rewriter.getUnitAttr()); - rewriter.create( - op.getLoc(), ValueRange{jitCall.getResult(0), jitCall.getResult(1), - jitCall.getResult(2)}); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{jitCall.getResult(0), + jitCall.getResult(1), + jitCall.getResult(2)}); return func; } @@ -172,8 +173,8 @@ struct LUFactorizationOpLowering llvmPtrType, llvmPtrType, llvmPtrType}, false); - rewriter.create(op.getLoc(), lapackFn, funcType, - LLVM::Linkage::External); + LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFn, funcType, + LLVM::Linkage::External); } if (!moduleOp.lookupSymbol(lapackFnWrapper)) { @@ -186,8 +187,9 @@ struct LUFactorizationOpLowering llvmPtrType, llvmPtrType, llvmPtrType}, false); - auto funcOp = rewriter.create( - op.getLoc(), lapackFnWrapper, funcType, LLVM::Linkage::Private); + auto funcOp = + LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFnWrapper, + funcType, LLVM::Linkage::Private); rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter)); funcOp.setArgAttr(0, LLVM::LLVMDialect::getReadonlyAttrName(), @@ -206,17 +208,17 @@ struct LUFactorizationOpLowering rewriter.getUnitAttr()); } - auto callOp = rewriter.create( - op.getLoc(), TypeRange{}, SymbolRefAttr::get(ctx, lapackFn), - ValueRange{ - funcOp.getArgument(0), - funcOp.getArgument(1), - funcOp.getArgument(2), - funcOp.getArgument(3), - funcOp.getArgument(4), - funcOp.getArgument(5), - }); - rewriter.create(op.getLoc(), ValueRange{}); + auto callOp = LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{}, + SymbolRefAttr::get(ctx, lapackFn), + ValueRange{ + funcOp.getArgument(0), + funcOp.getArgument(1), + funcOp.getArgument(2), + funcOp.getArgument(3), + funcOp.getArgument(4), + funcOp.getArgument(5), + }); + LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{}); } // Call the LLVM function with enzymexla.jit_call @@ -265,8 +267,9 @@ struct LUFactorizationOpLowering SmallVector batchShape(inputShape.begin(), inputShape.begin() + numBatchDims); - auto batchOp = rewriter.create( - op.getLoc(), TypeRange{inputType, blasPivotType, blasInfoType}, + auto batchOp = enzyme::BatchOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, blasPivotType, blasInfoType}, mlir::FlatSymbolRefAttr::get(op.getContext(), wrapperFnName), ValueRange{input}, rewriter.getDenseI64ArrayAttr(batchShape)); @@ -278,7 +281,7 @@ struct LUFactorizationOpLowering batchFunctions.push_back(cast(func.getOperation())); } else { auto callOp = - rewriter.create(op.getLoc(), func, ValueRange{input}); + func::CallOp::create(rewriter, op.getLoc(), func, ValueRange{input}); 
factorizedResult = callOp.getResult(0); pivotResult = callOp.getResult(1); @@ -286,24 +289,26 @@ struct LUFactorizationOpLowering } auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); - auto iter = rewriter.create( - op.getLoc(), iterType, cast(makeAttr(iterType, 0))); - auto zeroConst = rewriter.create( - op.getLoc(), iterType, cast(makeAttr(iterType, 0))); - - auto pivots0indexed = rewriter.create( - op.getLoc(), pivotResult, - rewriter.create( - op.getLoc(), blasPivotType, + auto iter = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr(iterType, 0))); + auto zeroConst = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr(iterType, 0))); + + auto pivots0indexed = stablehlo::SubtractOp::create( + rewriter, op.getLoc(), pivotResult, + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasPivotType, cast(makeAttr(blasPivotType, 1)))); - auto permutation = rewriter.create( - op.getLoc(), blasPivotType, + auto permutation = stablehlo::IotaOp::create( + rewriter, op.getLoc(), blasPivotType, rewriter.getI64IntegerAttr(blasPivotType.getRank() - 1)); auto pivotToPermReturnTypes = {iterType, blasPivotType}; - auto pivotToPermWhileOp = rewriter.create( - op.getLoc(), TypeRange{iterType, blasPivotType}, + auto pivotToPermWhileOp = stablehlo::WhileOp::create( + rewriter, op.getLoc(), TypeRange{iterType, blasPivotType}, ValueRange{iter, permutation}); { @@ -315,17 +320,17 @@ struct LUFactorizationOpLowering for (auto type : pivotToPermReturnTypes) block->addArgument(type, pivotToPermWhileOp.getLoc()); - auto pivotShapeConst = rewriter.create( - op.getLoc(), iterType, + auto pivotShapeConst = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, cast(makeAttr( iterType, pivotType.getShape()[pivotType.getRank() - 1]))); - auto comparison = rewriter.create( - op.getLoc(), block->getArgument(0), pivotShapeConst, + auto comparison = stablehlo::CompareOp::create( + rewriter, op.getLoc(), block->getArgument(0), pivotShapeConst, stablehlo::ComparisonDirection::LT); - rewriter.create(op.getLoc(), - ValueRange{comparison.getResult()}); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{comparison.getResult()}); } { @@ -339,10 +344,10 @@ struct LUFactorizationOpLowering auto iterArg = block->getArgument(0); - auto updatedIter = rewriter.create( - op.getLoc(), iterArg, - rewriter.create( - op.getLoc(), iterType, + auto updatedIter = stablehlo::AddOp::create( + rewriter, op.getLoc(), iterArg, + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, cast(makeAttr(iterType, 1)))); /* @@ -365,10 +370,10 @@ struct LUFactorizationOpLowering sliceShape.push_back(1); SmallVector gatherSliceSizes(numBatchDims + 1, 1); - auto pivotJ = rewriter.create( - op.getLoc(), pivots0indexed, indices, sliceShape); - auto permutationX = rewriter.create( - op.getLoc(), block->getArgument(1), indices, sliceShape); + auto pivotJ = stablehlo::DynamicSliceOp::create( + rewriter, op.getLoc(), pivots0indexed, indices, sliceShape); + auto permutationX = stablehlo::DynamicSliceOp::create( + rewriter, op.getLoc(), block->getArgument(1), indices, sliceShape); auto gatherDims = stablehlo::GatherDimensionNumbersAttr::get( op.getContext(), @@ -378,18 +383,17 @@ struct LUFactorizationOpLowering /*startIndicesBatchingDims=*/batchDims, /*startIndexMap=*/{numBatchDims}, /*indexVectorDim=*/numBatchDims); - auto permutationY = rewriter.create( - op.getLoc(), + auto permutationY = stablehlo::GatherOp::create( + rewriter, op.getLoc(), 
RankedTensorType::get(sliceShape, cast( block->getArgument(1).getType()) .getElementType()), block->getArgument(1), pivotJ.getResult(), gatherDims, gatherSliceSizes); - auto permutationUpdate1 = - rewriter.create( - op.getLoc(), block->getArgument(1), permutationY->getResult(0), - indices); + auto permutationUpdate1 = stablehlo::DynamicUpdateSliceOp::create( + rewriter, op.getLoc(), block->getArgument(1), + permutationY->getResult(0), indices); auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get( op.getContext(), @@ -401,11 +405,12 @@ struct LUFactorizationOpLowering /*indexVectorDim=*/numBatchDims); SmallVector scatterShape(sliceShape.begin(), sliceShape.end() - 1); - auto permutationUpdate2 = rewriter.create( - op.getLoc(), TypeRange{permutationUpdate1->getResult(0).getType()}, + auto permutationUpdate2 = stablehlo::ScatterOp::create( + rewriter, op.getLoc(), + TypeRange{permutationUpdate1->getResult(0).getType()}, ValueRange(permutationUpdate1->getResult(0)), pivotJ, - ValueRange(rewriter.create( - op.getLoc(), + ValueRange(stablehlo::ReshapeOp::create( + rewriter, op.getLoc(), RankedTensorType::get(scatterShape, permutationX.getType().getElementType()), permutationX)), @@ -419,31 +424,32 @@ struct LUFactorizationOpLowering block->addArgument(RankedTensorType::get({}, blasIntType), op.getLoc()); rewriter.setInsertionPointToStart(block); - rewriter.create(op.getLoc(), - ValueRange{block->getArgument(1)}); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{block->getArgument(1)}); } - rewriter.create( - op.getLoc(), + stablehlo::ReturnOp::create( + rewriter, op.getLoc(), ValueRange{updatedIter, permutationUpdate2->getResult(0)}); } - auto finalPermutation = rewriter.create( - op.getLoc(), pivotToPermWhileOp.getResult(1), - rewriter.create( - op.getLoc(), blasPivotType, + auto finalPermutation = stablehlo::AddOp::create( + rewriter, op.getLoc(), pivotToPermWhileOp.getResult(1), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasPivotType, cast(makeAttr(blasPivotType, 1)))); rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); - rewriter.replaceAllUsesWith(op.getResult(1), - rewriter.create( - op.getLoc(), pivotType, pivotResult)); - rewriter.replaceAllUsesWith(op.getResult(2), - rewriter.create( - op.getLoc(), pivotType, finalPermutation)); - rewriter.replaceAllUsesWith(op.getResult(3), - rewriter.create( - op.getLoc(), infoType, infoResult)); + rewriter.replaceAllUsesWith( + op.getResult(1), stablehlo::ConvertOp::create(rewriter, op.getLoc(), + pivotType, pivotResult)); + rewriter.replaceAllUsesWith( + op.getResult(2), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + finalPermutation)); + rewriter.replaceAllUsesWith( + op.getResult(3), stablehlo::ConvertOp::create(rewriter, op.getLoc(), + infoType, infoResult)); std::map batchedFunctionCache; @@ -485,8 +491,9 @@ struct LUFactorizationOpLowering auto infoCuSolverType = RankedTensorType::get(infoType.getShape(), rewriter.getI32Type()); - auto cusolverffi = rewriter.create( - op.getLoc(), TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, + auto cusolverffi = stablehlo::CustomCallOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, ValueRange{input}, rewriter.getStringAttr("cusolver_getrf_ffi"), /*has_side_effect*/ nullptr, /*backend_config*/ nullptr, @@ -503,16 +510,16 @@ struct LUFactorizationOpLowering // unused custom call not getting optimized away. 
so adding a manual check if (!op.getResult(2).getUses().empty()) { - auto pivots0indexed = rewriter.create( - op.getLoc(), cusolverffi.getResult(1), - rewriter.create( - op.getLoc(), pivotCuSolverType, + auto pivots0indexed = stablehlo::SubtractOp::create( + rewriter, op.getLoc(), cusolverffi.getResult(1), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), pivotCuSolverType, cast(makeAttr(pivotCuSolverType, 1)))); SmallVector outputRanksPermutation = {pivotRank}; - auto permutation = rewriter.create( - op.getLoc(), TypeRange{pivotCuSolverType}, + auto permutation = stablehlo::CustomCallOp::create( + rewriter, op.getLoc(), TypeRange{pivotCuSolverType}, ValueRange{pivots0indexed.getResult()}, rewriter.getStringAttr("cu_lu_pivots_to_permutation"), /*has_side_effect*/ nullptr, @@ -529,25 +536,28 @@ struct LUFactorizationOpLowering getSHLOLayout(rewriter, SmallVector{pivotRank}, SmallVector{true}, inputRank), /*output_operand_aliases*/ nullptr); - auto permutation1Indexed = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), pivotCuSolverType, + auto permutation1Indexed = stablehlo::AddOp::create( + rewriter, op.getLoc(), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), pivotCuSolverType, cast(makeAttr(pivotCuSolverType, 1))), permutation.getResult(0)); rewriter.replaceAllUsesWith( - op.getResult(2), rewriter.create( - op.getLoc(), pivotType, permutation1Indexed)); + op.getResult(2), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + permutation1Indexed)); } rewriter.replaceAllUsesWith(op.getResult(0), cusolverffi.getResult(0)); rewriter.replaceAllUsesWith( - op.getResult(1), rewriter.create( - op.getLoc(), pivotType, cusolverffi.getResult(1))); + op.getResult(1), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + cusolverffi.getResult(1))); rewriter.replaceAllUsesWith( - op.getResult(3), rewriter.create( - op.getLoc(), infoType, cusolverffi.getResult(2))); + op.getResult(3), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), infoType, + cusolverffi.getResult(2))); return success(); } @@ -576,9 +586,10 @@ struct LUFactorizationOpLowering // TPU returns (LU, pivots, permutation). info isn't returned. based on // how JAX operates, I am assuming info != 0 when there is a nan in the // output. - auto customCall = rewriter.create( - op.getLoc(), TypeRange{inputType, pivotTPUType, permutationType}, - ValueRange{input}, rewriter.getStringAttr("LuDecomposition"), + auto customCall = stablehlo::CustomCallOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, pivotTPUType, permutationType}, ValueRange{input}, + rewriter.getStringAttr("LuDecomposition"), /*has_side_effect*/ nullptr, /*backend_config*/ nullptr, /*api_version*/ nullptr, @@ -589,39 +600,43 @@ struct LUFactorizationOpLowering // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots. We // make it consistent with LAPACK by adding 1 to the pivots. 
- auto pivots1Indexed = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), pivotType, cast(makeAttr(pivotType, 1))), - rewriter.create(op.getLoc(), pivotType, - customCall.getResult(1))); - - auto permutation1Indexed = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), permutationType, + auto pivots1Indexed = stablehlo::AddOp::create( + rewriter, op.getLoc(), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), pivotType, + cast(makeAttr(pivotType, 1))), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + customCall.getResult(1))); + + auto permutation1Indexed = stablehlo::AddOp::create( + rewriter, op.getLoc(), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), permutationType, cast(makeAttr(permutationType, 1))), - rewriter.create(op.getLoc(), permutationType, - customCall.getResult(2))); - - auto isFinite = rewriter.create( - op.getLoc(), - rewriter.create( - op.getLoc(), rewriter.create( - op.getLoc(), customCall.getResult(0))), - rewriter.create( - op.getLoc(), rewriter.create( - op.getLoc(), customCall.getResult(0)))); + stablehlo::ConvertOp::create(rewriter, op.getLoc(), permutationType, + customCall.getResult(2))); + + auto isFinite = stablehlo::AndOp::create( + rewriter, op.getLoc(), + stablehlo::IsFiniteOp::create( + rewriter, op.getLoc(), + stablehlo::RealOp::create(rewriter, op.getLoc(), + customCall.getResult(0))), + stablehlo::IsFiniteOp::create( + rewriter, op.getLoc(), + stablehlo::ImagOp::create(rewriter, op.getLoc(), + customCall.getResult(0)))); SmallVector reductionDims; for (int i = numBatchDims; i < inputRank; i++) reductionDims.push_back(i); auto initValType = RankedTensorType::get({}, rewriter.getI1Type()); - auto initVal = rewriter.create( - op.getLoc(), initValType, cast(makeAttr(initValType, 1))); + auto initVal = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), initValType, + cast(makeAttr(initValType, 1))); - auto allFinite = rewriter.create( - op.getLoc(), + auto allFinite = stablehlo::ReduceOp::create( + rewriter, op.getLoc(), RankedTensorType::get(infoType.getShape(), rewriter.getI1Type()), ValueRange{isFinite.getResult()}, ValueRange{initVal}, rewriter.getDenseI64ArrayAttr(reductionDims)); @@ -635,16 +650,17 @@ struct LUFactorizationOpLowering rewriter.setInsertionPointToStart(block); auto lhs = block->getArgument(0); auto rhs = block->getArgument(1); - auto andOp = rewriter.create(op.getLoc(), lhs, rhs); + auto andOp = stablehlo::AndOp::create(rewriter, op.getLoc(), lhs, rhs); - rewriter.create(op.getLoc(), - ValueRange{andOp.getResult()}); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{andOp.getResult()}); } // info == 0 if all finite (success) - auto info = rewriter.create( - op.getLoc(), infoType, - rewriter.create(op.getLoc(), allFinite.getResult(0))); + auto info = stablehlo::ConvertOp::create( + rewriter, op.getLoc(), infoType, + stablehlo::NotOp::create(rewriter, op.getLoc(), + allFinite.getResult(0))); rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0)); rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed); @@ -974,12 +990,8 @@ struct SVDFactorizationOpLowering return success(); } - // TODO find registered TPU kernel LogicalResult matchAndRewriteTPU(enzymexla::SVDFactorizationOp op, PatternRewriter &rewriter) const { - return rewriter.notifyMatchFailure( - op, "We don't know yet to which SVD TPU kernel to lower to :_("); - auto ctx = op->getContext(); auto input = op.getOperand(); @@ -996,7 +1008,7 @@ struct SVDFactorizationOpLowering auto 
custom_call_op = stablehlo::CustomCallOp::create( rewriter, op.getLoc(), TypeRange{type_input, type_tau}, - ValueRange{input}, rewriter.getStringAttr("Svd"), + ValueRange{input}, rewriter.getStringAttr("SVD"), /*has_side_effect*/ nullptr, /*backend_config*/ nullptr, /*api_version*/ nullptr, @@ -1029,8 +1041,8 @@ struct LowerEnzymeXLALinalgPass auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(backend, blasIntWidth, context); - patterns.add(backend, blasIntWidth, context); + patterns.add( + backend, blasIntWidth, context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir index 948d483f7..b3f6903d8 100644 --- a/test/lit_tests/linalg/lu.mlir +++ b/test/lit_tests/linalg/lu.mlir @@ -75,29 +75,29 @@ module { // TPU-NEXT: return %0#0, %1, %2, %6 : tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor // TPU-NEXT: } -module { - // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_ - // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { - func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { - %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor) - return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor - } -} +// module { +// // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_ +// // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { +// func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { +// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor) +// return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor +// } +// } -module { - // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_ - // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) - return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor - } -} +// module { +// // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_ +// // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { +// func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { +// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) +// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor +// } +// } -module { - // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_ - // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { - %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) - return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor - } -} +// module { +// // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_ +// // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> 
(tensor<64x64xcomplex>, tensor<64xi32>, tensor) { +// func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { +// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) +// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor +// } +// } From 302bd57f1423881c7feeed8b9f3636371702986c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 17:31:18 -0600 Subject: [PATCH 10/12] fix: use correct return --- .../jax/Passes/LowerEnzymeXLALinalg.cpp | 2 +- test/lit_tests/linalg/lu.mlir | 48 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 5a6019f80..6d2bddd9c 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -108,7 +108,7 @@ struct LUFactorizationOpLowering /*output_operand_aliases=*/outputOperandAliases, /*xla_side_effect_free=*/rewriter.getUnitAttr()); - stablehlo::ReturnOp::create(rewriter, op.getLoc(), + func::ReturnOp::create(rewriter, op.getLoc(), ValueRange{jitCall.getResult(0), jitCall.getResult(1), jitCall.getResult(2)}); diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir index b3f6903d8..948d483f7 100644 --- a/test/lit_tests/linalg/lu.mlir +++ b/test/lit_tests/linalg/lu.mlir @@ -75,29 +75,29 @@ module { // TPU-NEXT: return %0#0, %1, %2, %6 : tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor // TPU-NEXT: } -// module { -// // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_ -// // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { -// func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { -// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor) -// return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor -// } -// } +module { + // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_ + // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { + func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) { + %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor) + return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor + } +} -// module { -// // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_ -// // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { -// func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { -// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) -// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor -// } -// } +module { + // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_ + // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { + func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { + %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) + return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor + } +} -// module { -// // CPU: 
enzymexla.jit_call @enzymexla_lapack_cgetrf_ -// // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { -// func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { -// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) -// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor -// } -// } +module { + // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_ + // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { + func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) { + %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor) + return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor + } +} From 0237be3e959dedc5bdf7393e73a3738457a29f9f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 17:40:52 -0600 Subject: [PATCH 11/12] chore: run fmt --- src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 6d2bddd9c..b5bc67b87 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -109,9 +109,9 @@ struct LUFactorizationOpLowering /*xla_side_effect_free=*/rewriter.getUnitAttr()); func::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{jitCall.getResult(0), - jitCall.getResult(1), - jitCall.getResult(2)}); + ValueRange{jitCall.getResult(0), + jitCall.getResult(1), + jitCall.getResult(2)}); return func; } From 61c99a23460242561c8f5abb3daa4c5097399319 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Nov 2025 18:15:53 -0600 Subject: [PATCH 12/12] test: fix --- test/lit_tests/linalg/lu.mlir | 2 +- test/lit_tests/linalg/lu_batched.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir index 948d483f7..821abc11d 100644 --- a/test/lit_tests/linalg/lu.mlir +++ b/test/lit_tests/linalg/lu.mlir @@ -14,7 +14,7 @@ module { // CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64> // CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor // CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) -// CPU-NEXT: stablehlo.return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor +// CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor // CPU-NEXT: } // CPU-NEXT: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, 
%arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) { // CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () diff --git a/test/lit_tests/linalg/lu_batched.mlir b/test/lit_tests/linalg/lu_batched.mlir index 3e706a43a..39e56fa06 100644 --- a/test/lit_tests/linalg/lu_batched.mlir +++ b/test/lit_tests/linalg/lu_batched.mlir @@ -74,7 +74,7 @@ module { // CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_10, %11, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor, tensor) -> tensor<4x3xi64> // CPU-NEXT: stablehlo.return %1, %8, %10, %12 : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: } -// CPU-NEXT: stablehlo.return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> +// CPU-NEXT: return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64> // CPU-NEXT: }
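
Note on the builder API that the "fix: update to new API" patch above tracks: every construction site in LowerEnzymeXLALinalg.cpp and TritonExt/Ops.cpp moves from rewriter.create<Op>(loc, ...) to the op's static Op::create(rewriter, loc, ...) form, with the remaining arguments unchanged. A minimal sketch of that pattern follows; rewriter, loc, ty, and attr are illustrative stand-ins for values a real rewrite pattern already has in scope, not names taken from the diffs above.

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/PatternMatch.h"
    #include "stablehlo/dialect/StablehloOps.h"

    // Builds a StableHLO constant using the static create form shown in the
    // patches above (builder first, then location, result type, and value).
    static mlir::Value makeConstant(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc,
                                    mlir::RankedTensorType ty,
                                    mlir::ElementsAttr attr) {
      // Old form: rewriter.create<mlir::stablehlo::ConstantOp>(loc, ty, attr);
      // New form: the op class takes the builder as its first argument.
      auto cst = mlir::stablehlo::ConstantOp::create(rewriter, loc, ty, attr);
      return cst.getResult();
    }

The same mechanical rewrite applies to the other ops touched in the pass (WhileOp, GatherOp, ScatterOp, CustomCallOp, JITCallOp, and the func/LLVM ops); only the per-op trailing arguments differ.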