diff --git a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp index 5287fee54..6c19dd516 100644 --- a/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp +++ b/src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp @@ -15,11 +15,12 @@ template <> triton_ext::TritonCallOp ReadOnlyArg::create( PatternRewriter &rewriter, triton_ext::TritonCallOp launchOp, ArrayRef resTys, ArrayAttr outputAliases) const { - return rewriter.create( - launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(), - launchOp.getGridy(), launchOp.getGridz(), launchOp.getClusterx(), - launchOp.getClustery(), launchOp.getClusterz(), launchOp.getInputs(), - launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(), + return triton_ext::TritonCallOp::create( + rewriter, launchOp.getLoc(), resTys, launchOp.getFn(), + launchOp.getGridx(), launchOp.getGridy(), launchOp.getGridz(), + launchOp.getClusterx(), launchOp.getClustery(), launchOp.getClusterz(), + launchOp.getInputs(), launchOp.getBackendConfigAttr(), + launchOp.getOperandLayoutsAttr(), /*resultLayouts*/ nullptr, launchOp.getArgAttrsAttr(), launchOp.getResAttrsAttr(), outputAliases, launchOp.getXlaSideEffectFreeAttr()); diff --git a/src/enzyme_ad/jax/Passes/LinalgUtils.cpp b/src/enzyme_ad/jax/Passes/LinalgUtils.cpp index bdae610fe..8623f37a5 100644 --- a/src/enzyme_ad/jax/Passes/LinalgUtils.cpp +++ b/src/enzyme_ad/jax/Passes/LinalgUtils.cpp @@ -38,7 +38,7 @@ mlir::ArrayAttr getSHLOLayout(PatternRewriter &rewriter, return rewriter.getArrayAttr(attrs); } -std::optional lapack_precision_prefix(Type elementType) { +std::optional lapackPrecisionPrefix(Type elementType) { // single-precision float if (elementType.isF32()) { diff --git a/src/enzyme_ad/jax/Passes/LinalgUtils.h b/src/enzyme_ad/jax/Passes/LinalgUtils.h index 46a330cb4..b618f0781 100644 --- a/src/enzyme_ad/jax/Passes/LinalgUtils.h +++ b/src/enzyme_ad/jax/Passes/LinalgUtils.h @@ -17,6 +17,6 @@ mlir::ArrayAttr getSHLOLayout(mlir::PatternRewriter &rewriter, llvm::SmallVector isColMajorArr, int64_t maxNumDims); -std::optional lapack_precision_prefix(mlir::Type elementType); +std::optional lapackPrecisionPrefix(mlir::Type elementType); #endif // ENZYMEXLA_LINALGUTILS_H diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp index f9d5a2f16..78c869294 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp @@ -41,14 +41,11 @@ struct GeqrfOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GeqrfOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); else if (backend == "cuda") - return this->matchAndRewrite_cuda(op, rewriter); - + return matchAndRewriteCUDA(op, rewriter); else if (backend == "tpu") - return this->matchAndRewrite_tpu(op, rewriter); - + return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -56,8 +53,8 @@ struct GeqrfOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -80,7 +77,7 @@ 
struct GeqrfOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "geqrf_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << inputElementType; @@ -209,8 +206,8 @@ struct GeqrfOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_cuda(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCUDA(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -265,8 +262,8 @@ struct GeqrfOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_tpu(enzymexla::GeqrfOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteTPU(enzymexla::GeqrfOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -316,14 +313,11 @@ struct GeqrtOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GeqrtOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -331,8 +325,8 @@ struct GeqrtOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GeqrtOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GeqrtOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -355,7 +349,7 @@ struct GeqrtOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "geqrt_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << inputElementType; @@ -523,14 +517,11 @@ struct OrgqrOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::OrgqrOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); else if (backend == "cuda") - return this->matchAndRewrite_cuda(op, rewriter); - + return matchAndRewriteCUDA(op, rewriter); else if (backend == "tpu") - return this->matchAndRewrite_tpu(op, rewriter); - + return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -538,8 +529,8 @@ struct OrgqrOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::OrgqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter 
typeConverter(ctx); @@ -567,7 +558,7 @@ struct OrgqrOpLowering : public OpRewritePattern { auto type_llvm_void = LLVM::LLVMVoidType::get(ctx); std::string fn = "gqr_"; - if (auto prefix = lapack_precision_prefix(inputElementType)) { + if (auto prefix = lapackPrecisionPrefix(inputElementType)) { if (prefix == "s" || prefix == "d") fn = *prefix + "or" + fn; else @@ -688,8 +679,8 @@ struct OrgqrOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_cuda(enzymexla::OrgqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCUDA(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -734,8 +725,8 @@ struct OrgqrOpLowering : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite_tpu(enzymexla::OrgqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteTPU(enzymexla::OrgqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -772,14 +763,11 @@ struct OrmqrOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::OrmqrOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -787,8 +775,8 @@ struct OrmqrOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::OrmqrOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::OrmqrOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); @@ -873,7 +861,7 @@ struct OrmqrOpLowering : public OpRewritePattern { auto type_llvm_char = rewriter.getIntegerType(8); std::string fn = "mqr_"; - if (auto prefix = lapack_precision_prefix(A_eltype)) { + if (auto prefix = lapackPrecisionPrefix(A_eltype)) { if (prefix == "s" || prefix == "d") fn = *prefix + "or" + fn; else @@ -1031,14 +1019,11 @@ struct GemqrtOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(enzymexla::GemqrtOp op, PatternRewriter &rewriter) const override { if (backend == "cpu") - return this->matchAndRewrite_cpu(op, rewriter); - + return matchAndRewriteCPU(op, rewriter); // else if (backend == "cuda") - // return this->matchAndRewrite_cuda(op, rewriter); - + // return matchAndRewriteCUDA(op, rewriter); // else if (backend == "tpu") - // return this->matchAndRewrite_tpu(op, rewriter); - + // return matchAndRewriteTPU(op, rewriter); else return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend + "\""); @@ -1046,8 +1031,8 @@ struct GemqrtOpLowering : public OpRewritePattern { // TODO get matrix sizes dynamically so that we don't need to create a // function wrapper for each op instance - LogicalResult matchAndRewrite_cpu(enzymexla::GemqrtOp op, - PatternRewriter &rewriter) const { + LogicalResult matchAndRewriteCPU(enzymexla::GemqrtOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); LLVMTypeConverter typeConverter(ctx); 
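// For reference: the lowerings in this file build LAPACK symbol names by
// prepending the precision letter returned from the renamed helper
// lapackPrecisionPrefix (declared in LinalgUtils.h, defined in
// LinalgUtils.cpp), e.g. "geqrf_" becomes "sgeqrf_" for f32 inputs and
// "gqr_" becomes "sorgqr_" or the "un"-prefixed complex variant. The hunks
// above also migrate from rewriter.create(...) to the equivalent static
// Op::create(rewriter, loc, ...) builder form. Below is a minimal sketch of
// the helper, assuming a std::optional<std::string> return type and the
// standard LAPACK s/d/c/z precision convention; it is illustrative, not the
// verbatim definition from LinalgUtils.cpp.
#include <optional>
#include <string>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

std::optional<std::string> lapackPrecisionPrefix(mlir::Type elementType) {
  if (elementType.isF32())
    return "s"; // single-precision real
  if (elementType.isF64())
    return "d"; // double-precision real
  if (auto complexType = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
    if (complexType.getElementType().isF32())
      return "c"; // single-precision complex
    if (complexType.getElementType().isF64())
      return "z"; // double-precision complex
  }
  return std::nullopt; // unsupported element type; callers emit an op error
}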
@@ -1141,7 +1126,7 @@ struct GemqrtOpLowering : public OpRewritePattern { auto type_llvm_char = rewriter.getIntegerType(8); std::string fn = "gemqrt_"; - if (auto prefix = lapack_precision_prefix(C_eltype)) { + if (auto prefix = lapackPrecisionPrefix(C_eltype)) { fn = *prefix + fn; } else { op->emitOpError() << "Unsupported element type: " << C_eltype; diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp index 43f6ed349..b5bc67b87 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp @@ -1,3 +1,5 @@ +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Passes/EnzymeBatchPass.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -41,488 +43,485 @@ struct LUFactorizationOpLowering LogicalResult matchAndRewrite(enzymexla::LUFactorizationOp op, PatternRewriter &rewriter) const override { + if (backend == "cpu") { + return matchAndRewriteCPU(op, rewriter); + } else if (backend == "cuda") { + return matchAndRewriteCUDA(op, rewriter); + } else if (backend == "tpu") { + return matchAndRewriteTPU(op, rewriter); + } else { + return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); + } + } + +private: + func::FuncOp createWrapperFuncOpCPULapack( + PatternRewriter &rewriter, const std::string &lapackFn, + RankedTensorType inputType, RankedTensorType blasPivotType, + RankedTensorType blasInfoType, Type blasIntType, + const std::string &fnName, enzymexla::LUFactorizationOp op, + ArrayAttr operandLayouts, ArrayAttr resultLayouts, + ArrayAttr outputOperandAliases) const { + auto ctx = op->getContext(); + + OpBuilder::InsertionGuard guard(rewriter); + auto moduleOp = op->getParentOfType(); + if (!moduleOp) + return nullptr; + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + SmallVector argTypes = {inputType}; + SmallVector retTypes = {inputType, blasPivotType, blasInfoType}; + + FunctionType calleeType = rewriter.getFunctionType(argTypes, retTypes); + func::FuncOp func = + func::FuncOp::create(rewriter, op.getLoc(), fnName, calleeType); + func.setPrivate(); + + auto &entryBlock = *func.addEntryBlock(); + rewriter.setInsertionPointToStart(&entryBlock); + + auto input = entryBlock.getArgument(0); + auto mSize = stablehlo::ConvertOp::create( + rewriter, op.getLoc(), RankedTensorType::get({}, blasIntType), + stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), input, 0)); + auto nSize = stablehlo::ConvertOp::create( + rewriter, op.getLoc(), RankedTensorType::get({}, blasIntType), + stablehlo::GetDimensionSizeOp::create(rewriter, op.getLoc(), input, 1)); + auto pivot = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, -1))); + auto info = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasInfoType, + cast(makeAttr(blasInfoType, -1))); + + auto jitCall = enzymexla::JITCallOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(ctx, lapackFn), + ValueRange{mSize, nSize, input, mSize, pivot, info}, + rewriter.getStringAttr(""), + /*operand_layouts=*/operandLayouts, + /*result_layouts=*/resultLayouts, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/outputOperandAliases, + /*xla_side_effect_free=*/rewriter.getUnitAttr()); + + func::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{jitCall.getResult(0), + 
jitCall.getResult(1), + jitCall.getResult(2)}); + + return func; + } + + LogicalResult matchAndRewriteCPU(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { auto ctx = op->getContext(); - LLVMTypeConverter typeConverter(ctx); auto input = op.getOperand(); auto inputShape = cast(input.getType()).getShape(); auto inputRank = static_cast(inputShape.size()); auto inputType = cast(input.getType()); + auto unbatchedInputType = RankedTensorType::get( + SmallVector(inputType.getShape().end() - 2, + inputType.getShape().end()), + inputType.getElementType()); auto inputElementType = inputType.getElementType(); - const int64_t m = inputShape[inputRank - 2]; - const int64_t n = inputShape[inputRank - 1]; const int64_t numBatchDims = inputRank - 2; auto pivotType = cast(op.getResult(1).getType()); auto pivotRank = pivotType.getRank(); + auto unbatchedPivotType = RankedTensorType::get( + SmallVector(pivotType.getShape().end() - 1, + pivotType.getShape().end()), + pivotType.getElementType()); + auto infoType = cast(op.getResult(3).getType()); auto infoRank = infoType.getRank(); + auto unbatchedInfoType = + RankedTensorType::get({}, infoType.getElementType()); - if (backend == "cpu") { - auto moduleOp = op->getParentOfType(); - static int64_t fnNum = 0; - - auto blasIntType = rewriter.getIntegerType(blasIntWidth); - auto llvmBlasIntType = typeConverter.convertType(blasIntType); - auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); - auto llvmVoidPtrType = LLVM::LLVMVoidType::get(ctx); - - std::string lapackFn; - if (inputElementType.isF32()) { - lapackFn = "sgetrf_"; // single-precision float - } else if (inputElementType.isF64()) { - lapackFn = "dgetrf_"; // double-precision float - } else if (auto complexType = dyn_cast(inputElementType)) { - auto elem = complexType.getElementType(); - if (elem.isF32()) { - lapackFn = "cgetrf_"; // single-precision complex - } else if (elem.isF64()) { - lapackFn = "zgetrf_"; // double-precision complex - } else { - op->emitOpError() << "Unsupported complex element type: " << elem; - return rewriter.notifyMatchFailure( - op, "unsupported complex element type"); - } - } else { - op->emitOpError() << "Unsupported input element type: " - << inputElementType; - return rewriter.notifyMatchFailure(op, - "unsupported input element type"); - } - lapackFn = "enzymexla_lapack_" + lapackFn; - - // Generate the LLVM function body - std::string fnName = lapackFn + "wrapper_" + std::to_string(fnNum); - fnNum++; - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto funcType = LLVM::LLVMFunctionType::get( - llvmVoidPtrType, {llvmPtrType, llvmPtrType, llvmPtrType}, false); - - auto func = - LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), fnName, funcType); - rewriter.setInsertionPointToStart(func.addEntryBlock(rewriter)); - - auto ptrSize = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, 1)); - auto mPtr = LLVM::AllocaOp::create(rewriter, op.getLoc(), llvmPtrType, - llvmBlasIntType, ptrSize, 0); - auto nPtr = LLVM::AllocaOp::create(rewriter, op.getLoc(), llvmPtrType, - llvmBlasIntType, ptrSize, 0); - - auto mVal = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, m)); - auto nVal = - LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmBlasIntType, - rewriter.getIntegerAttr(blasIntType, n)); - - LLVM::StoreOp::create(rewriter, op.getLoc(), mVal, mPtr); - LLVM::StoreOp::create(rewriter, op.getLoc(), 
nVal, nPtr); - - LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{}, - SymbolRefAttr::get(ctx, lapackFn), - ValueRange{ - mPtr, - nPtr, - func.getArgument(0), - mPtr, - func.getArgument(1), - func.getArgument(2), - }); - - LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{}); - } + auto moduleOp = op->getParentOfType(); - // Insert function declaration if not already present - if (!moduleOp.lookupSymbol(lapackFn)) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto blasIntType = rewriter.getIntegerType(blasIntWidth); + auto intType = RankedTensorType::get({}, blasIntType); + auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); + auto llvmVoidPtrType = LLVM::LLVMVoidType::get(ctx); - auto funcType = - LLVM::LLVMFunctionType::get(llvmVoidPtrType, - {llvmPtrType, llvmPtrType, llvmPtrType, - llvmPtrType, llvmPtrType, llvmPtrType}, - false); + std::string lapackFn; + auto prefix = lapackPrecisionPrefix(inputElementType); + if (prefix) { + lapackFn = "enzymexla_lapack_" + *prefix + "getrf_"; + } else { + op->emitOpError() << "Unsupported input element type: " + << inputElementType; + return rewriter.notifyMatchFailure(op, "unsupported input element type"); + } + std::string lapackFnWrapper = lapackFn + "wrapper"; - LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFn, funcType, - LLVM::Linkage::External); - } + // Insert function declaration if not already present + if (!moduleOp.lookupSymbol(lapackFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); - // Call the LLVM function with enzymexla.jit_call - SmallVector aliases; - for (int i = 0; i < 3; ++i) { - aliases.push_back(stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{i}, i, std::vector{})); - } + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidPtrType, + {llvmPtrType, llvmPtrType, llvmPtrType, + llvmPtrType, llvmPtrType, llvmPtrType}, + false); - auto blasPivotType = RankedTensorType::get( - pivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); - auto blasInfoType = RankedTensorType::get( - infoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - - SmallVector isColMajorArr = {true, true, true}; - SmallVector operandRanks = {2, 1, 0}; - SmallVector outputRanks = {2, 1, 0}; - auto operandLayouts = - getSHLOLayout(rewriter, operandRanks, isColMajorArr, 2); - auto resultLayouts = - getSHLOLayout(rewriter, outputRanks, isColMajorArr, 2); - - auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); - auto iter = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 0))); - auto zeroConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 0))); - - Value factorizedResult, pivotResult, infoResult; - - if (numBatchDims > 0) { - // TODO: Implement batched LU factorizations by directly calling MKL - // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-0/getrf-batch-strided.html. 
- - int64_t batchSize = 1; - for (int i = 0; i < numBatchDims; i++) { - batchSize *= inputShape[i]; - } - SmallVector flattenedInput = {batchSize, m, n}; - - auto flatInputType = - RankedTensorType::get(flattenedInput, inputElementType); - auto flatInput = stablehlo::ReshapeOp::create(rewriter, op.getLoc(), - flatInputType, input); - - auto flatPivotType = RankedTensorType::get( - {batchSize, pivotType.getShape()[pivotRank - 1]}, blasIntType); - auto flatPivot = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), flatPivotType, - cast(makeAttr(flatPivotType, -1))); - - auto flatInfoType = RankedTensorType::get({batchSize}, blasIntType); - auto flatInfo = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), flatInfoType, - cast(makeAttr(flatInfoType, -1))); - - auto whileReturnTypes = {iterType, flatInputType, flatPivotType, - flatInfoType}; - auto whileOp = stablehlo::WhileOp::create( - rewriter, op.getLoc(), - TypeRange{iterType, flatInputType, flatPivotType, flatInfoType}, - ValueRange{iter, flatInput, flatPivot, flatInfo}); + LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFn, funcType, + LLVM::Linkage::External); + } - { - OpBuilder::InsertionGuard guard(rewriter); + if (!moduleOp.lookupSymbol(lapackFnWrapper)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); - Block *block = rewriter.createBlock(&whileOp.getCond()); - rewriter.setInsertionPointToStart(block); + auto funcType = + LLVM::LLVMFunctionType::get(llvmVoidPtrType, + {llvmPtrType, llvmPtrType, llvmPtrType, + llvmPtrType, llvmPtrType, llvmPtrType}, + false); - for (auto type : whileReturnTypes) { - block->addArgument(type, whileOp.getLoc()); - } + auto funcOp = + LLVM::LLVMFuncOp::create(rewriter, op.getLoc(), lapackFnWrapper, + funcType, LLVM::Linkage::Private); + rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter)); + + funcOp.setArgAttr(0, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(1, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + // 2 is read + write + funcOp.setArgAttr(3, LLVM::LLVMDialect::getReadonlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(4, LLVM::LLVMDialect::getWriteOnlyAttrName(), + rewriter.getUnitAttr()); + funcOp.setArgAttr(5, LLVM::LLVMDialect::getWriteOnlyAttrName(), + rewriter.getUnitAttr()); + for (int i = 0; i < 6; i++) { + funcOp.setArgAttr(i, LLVM::LLVMDialect::getNoFreeAttrName(), + rewriter.getUnitAttr()); + } - auto batchSizeConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, batchSize))); + auto callOp = LLVM::CallOp::create(rewriter, op.getLoc(), TypeRange{}, + SymbolRefAttr::get(ctx, lapackFn), + ValueRange{ + funcOp.getArgument(0), + funcOp.getArgument(1), + funcOp.getArgument(2), + funcOp.getArgument(3), + funcOp.getArgument(4), + funcOp.getArgument(5), + }); + LLVM::ReturnOp::create(rewriter, op.getLoc(), ValueRange{}); + } - auto comparison = stablehlo::CompareOp::create( - rewriter, op.getLoc(), block->getArgument(0), batchSizeConst, - stablehlo::ComparisonDirection::LT); + // Call the LLVM function with enzymexla.jit_call + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{0}, 2, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{1}, 4, std::vector{})); + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{2}, 5, std::vector{})); - 
stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{comparison.getResult()}); - } + auto unbatchedBLASPivotType = RankedTensorType::get( + unbatchedPivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto blasPivotType = RankedTensorType::get( + pivotType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto unbatchedBLASInfoType = RankedTensorType::get( + unbatchedInfoType.getShape(), rewriter.getIntegerType(blasIntWidth)); + auto blasInfoType = RankedTensorType::get( + infoType.getShape(), rewriter.getIntegerType(blasIntWidth)); - { - OpBuilder::InsertionGuard guard(rewriter); + auto operandLayouts = + getSHLOLayout(rewriter, SmallVector{0, 0, 2, 0, 1, 0}, + SmallVector(6, true), 2); + auto resultLayouts = + getSHLOLayout(rewriter, SmallVector{2, 1, 0}, + SmallVector(3, true), 2); + + Value factorizedResult, pivotResult, infoResult; + static int64_t fnNum = 0; + std::string wrapperFnName = lapackFn + std::to_string(fnNum++); + + func::FuncOp func = createWrapperFuncOpCPULapack( + rewriter, lapackFnWrapper, unbatchedInputType, unbatchedBLASPivotType, + unbatchedBLASInfoType, blasIntType, wrapperFnName, op, operandLayouts, + resultLayouts, rewriter.getArrayAttr(aliases)); + if (!func) + return rewriter.notifyMatchFailure(op, + "failed to create wrapper function"); - Block *block = rewriter.createBlock(&whileOp.getBody()); - rewriter.setInsertionPointToStart(block); + SmallVector batchOps; + SmallVector batchFunctions; - for (auto type : whileReturnTypes) { - block->addArgument(type, whileOp.getLoc()); - } + if (numBatchDims > 0) { + // TODO: Implement batched LU factorizations by directly calling MKL + // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-0/getrf-batch-strided.html. + SmallVector batchShape(inputShape.begin(), + inputShape.begin() + numBatchDims); - auto iterArg = block->getArgument(0); + auto batchOp = enzyme::BatchOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, blasPivotType, blasInfoType}, + mlir::FlatSymbolRefAttr::get(op.getContext(), wrapperFnName), + ValueRange{input}, rewriter.getDenseI64ArrayAttr(batchShape)); - auto inputSliceType = RankedTensorType::get({m, n}, inputElementType); - auto inputSlice = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), inputSliceType, - stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - ValueRange{iterArg, zeroConst, zeroConst}, - rewriter.getDenseI64ArrayAttr({1, m, n}))); + factorizedResult = batchOp.getResult(0); + pivotResult = batchOp.getResult(1); + infoResult = batchOp.getResult(2); - auto pivotSliceType = - RankedTensorType::get({std::min(m, n)}, blasIntType); - auto pivotSlice = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotSliceType, - cast(makeAttr(pivotSliceType, -1))); + batchOps.push_back(batchOp); + batchFunctions.push_back(cast(func.getOperation())); + } else { + auto callOp = + func::CallOp::create(rewriter, op.getLoc(), func, ValueRange{input}); - auto infoSliceType = RankedTensorType::get({}, blasIntType); - auto infoSlice = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), infoSliceType, - cast(makeAttr(infoSliceType, -1))); + factorizedResult = callOp.getResult(0); + pivotResult = callOp.getResult(1); + infoResult = callOp.getResult(2); + } - auto jitCall = enzymexla::JITCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputSliceType, pivotSliceType, infoSliceType}, - mlir::FlatSymbolRefAttr::get(ctx, fnName), - ValueRange{inputSlice, pivotSlice, infoSlice}, - 
rewriter.getStringAttr(""), - /*operand_layouts=*/operandLayouts, - /*result_layouts=*/resultLayouts, - /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr, - /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), - /*xla_side_effect_free=*/rewriter.getUnitAttr()); - - auto inputUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1, m, n}, inputElementType), - jitCall.getResult(0)), - ValueRange{iterArg, zeroConst, zeroConst}); - auto pivotUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(2), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1, std::min(m, n)}, blasIntType), - jitCall.getResult(1)), - ValueRange{iterArg, zeroConst}); - auto infoUpdated = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(3), - stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get({1}, blasIntType), - jitCall.getResult(2)), - ValueRange{iterArg}); - - auto updatedIter = stablehlo::AddOp::create( - rewriter, op.getLoc(), block->getArgument(0), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 1)))); - - stablehlo::ReturnOp::create( - rewriter, op.getLoc(), - ValueRange{updatedIter, inputUpdated, pivotUpdated, infoUpdated}); - } - - factorizedResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), inputType, whileOp.getResult(1)); - pivotResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), blasPivotType, whileOp.getResult(2)); - infoResult = stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), blasInfoType, whileOp.getResult(3)); - } else { - auto pivot = stablehlo::ConstantOp::create( + auto iterType = RankedTensorType::get({}, rewriter.getI32Type()); + auto iter = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr(iterType, 0))); + auto zeroConst = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr(iterType, 0))); + + auto pivots0indexed = stablehlo::SubtractOp::create( + rewriter, op.getLoc(), pivotResult, + stablehlo::ConstantOp::create( rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, -1))); - auto info = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasInfoType, - cast(makeAttr(blasInfoType, -1))); + cast(makeAttr(blasPivotType, 1)))); - auto jitCall = enzymexla::JITCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, blasPivotType, blasInfoType}, - mlir::FlatSymbolRefAttr::get(ctx, fnName), - ValueRange{input, pivot, info}, rewriter.getStringAttr(""), - /*operand_layouts=*/operandLayouts, - /*result_layouts=*/resultLayouts, - /*arg_attrs=*/nullptr, - /*res_attrs=*/nullptr, - /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), - /*xla_side_effect_free=*/rewriter.getUnitAttr()); - - factorizedResult = jitCall.getResult(0); - pivotResult = jitCall.getResult(1); - infoResult = jitCall.getResult(2); - } + auto permutation = stablehlo::IotaOp::create( + rewriter, op.getLoc(), blasPivotType, + rewriter.getI64IntegerAttr(blasPivotType.getRank() - 1)); - auto pivots0indexed = stablehlo::SubtractOp::create( - rewriter, op.getLoc(), pivotResult, - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, 1)))); + auto pivotToPermReturnTypes = {iterType, blasPivotType}; + auto pivotToPermWhileOp = stablehlo::WhileOp::create( + rewriter, 
op.getLoc(), TypeRange{iterType, blasPivotType}, + ValueRange{iter, permutation}); - auto permutation = stablehlo::IotaOp::create( - rewriter, op.getLoc(), blasPivotType, - rewriter.getI64IntegerAttr(blasPivotType.getRank() - 1)); + { + OpBuilder::InsertionGuard guard(rewriter); - auto pivotToPermReturnTypes = {iterType, blasPivotType}; - auto pivotToPermWhileOp = stablehlo::WhileOp::create( - rewriter, op.getLoc(), TypeRange{iterType, blasPivotType}, - ValueRange{iter, permutation}); + Block *block = rewriter.createBlock(&pivotToPermWhileOp.getCond()); + rewriter.setInsertionPointToStart(block); - { - OpBuilder::InsertionGuard guard(rewriter); + for (auto type : pivotToPermReturnTypes) + block->addArgument(type, pivotToPermWhileOp.getLoc()); - Block *block = rewriter.createBlock(&pivotToPermWhileOp.getCond()); - rewriter.setInsertionPointToStart(block); + auto pivotShapeConst = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr( + iterType, pivotType.getShape()[pivotType.getRank() - 1]))); + + auto comparison = stablehlo::CompareOp::create( + rewriter, op.getLoc(), block->getArgument(0), pivotShapeConst, + stablehlo::ComparisonDirection::LT); - for (auto type : pivotToPermReturnTypes) - block->addArgument(type, pivotToPermWhileOp.getLoc()); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{comparison.getResult()}); + } - auto pivotShapeConst = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr( - iterType, pivotType.getShape()[pivotType.getRank() - 1]))); + { + OpBuilder::InsertionGuard guard(rewriter); - auto comparison = stablehlo::CompareOp::create( - rewriter, op.getLoc(), block->getArgument(0), pivotShapeConst, - stablehlo::ComparisonDirection::LT); + Block *block = rewriter.createBlock(&pivotToPermWhileOp.getBody()); + rewriter.setInsertionPointToStart(block); - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{comparison.getResult()}); + for (auto type : pivotToPermReturnTypes) + block->addArgument(type, pivotToPermWhileOp.getLoc()); + + auto iterArg = block->getArgument(0); + + auto updatedIter = stablehlo::AddOp::create( + rewriter, op.getLoc(), iterArg, + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), iterType, + cast(makeAttr(iterType, 1)))); + + /* + for i in range(pivot.shape[-1]): + j = pivot[..., i] # dynamic slice + x = permutation[..., i] # dynamic slice + y = permutation[j] # gather + permutation[..., i] = y # dynamic update slice + permutation[j] = x # scatter + */ + + SmallVector indices; + SmallVector sliceShape, batchDims; + for (int i = 0; i < numBatchDims; i++) { + indices.push_back(zeroConst); + sliceShape.push_back(pivotType.getShape()[i]); + batchDims.push_back(i); } + indices.push_back(iterArg); + sliceShape.push_back(1); + SmallVector gatherSliceSizes(numBatchDims + 1, 1); + + auto pivotJ = stablehlo::DynamicSliceOp::create( + rewriter, op.getLoc(), pivots0indexed, indices, sliceShape); + auto permutationX = stablehlo::DynamicSliceOp::create( + rewriter, op.getLoc(), block->getArgument(1), indices, sliceShape); + + auto gatherDims = stablehlo::GatherDimensionNumbersAttr::get( + op.getContext(), + /*offsetDims=*/{numBatchDims}, + /*collapsedSliceDims=*/{}, + /*operandBatchingDims=*/batchDims, + /*startIndicesBatchingDims=*/batchDims, + /*startIndexMap=*/{numBatchDims}, + /*indexVectorDim=*/numBatchDims); + auto permutationY = stablehlo::GatherOp::create( + rewriter, op.getLoc(), + RankedTensorType::get(sliceShape, cast( + block->getArgument(1).getType()) + 
.getElementType()), + block->getArgument(1), pivotJ.getResult(), gatherDims, + gatherSliceSizes); + + auto permutationUpdate1 = stablehlo::DynamicUpdateSliceOp::create( + rewriter, op.getLoc(), block->getArgument(1), + permutationY->getResult(0), indices); + + auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get( + op.getContext(), + /*updateWindowDims=*/{}, + /*insertedWindowDims=*/{numBatchDims}, + /*inputBatchingDims=*/batchDims, + /*scatterIndicesBatchingDims=*/batchDims, + /*scatterDimsToOperandDims=*/{numBatchDims}, + /*indexVectorDim=*/numBatchDims); + SmallVector scatterShape(sliceShape.begin(), + sliceShape.end() - 1); + auto permutationUpdate2 = stablehlo::ScatterOp::create( + rewriter, op.getLoc(), + TypeRange{permutationUpdate1->getResult(0).getType()}, + ValueRange(permutationUpdate1->getResult(0)), pivotJ, + ValueRange(stablehlo::ReshapeOp::create( + rewriter, op.getLoc(), + RankedTensorType::get(scatterShape, + permutationX.getType().getElementType()), + permutationX)), + scatterDims); { OpBuilder::InsertionGuard guard(rewriter); - - Block *block = rewriter.createBlock(&pivotToPermWhileOp.getBody()); + auto *block = + rewriter.createBlock(&permutationUpdate2.getUpdateComputation()); + block->addArgument(RankedTensorType::get({}, blasIntType), op.getLoc()); + block->addArgument(RankedTensorType::get({}, blasIntType), op.getLoc()); rewriter.setInsertionPointToStart(block); - for (auto type : pivotToPermReturnTypes) - block->addArgument(type, pivotToPermWhileOp.getLoc()); - - auto iterArg = block->getArgument(0); - - auto updatedIter = stablehlo::AddOp::create( - rewriter, op.getLoc(), iterArg, - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), iterType, - cast(makeAttr(iterType, 1)))); - - /* - for i in range(pivot.shape[-1]): - j = pivot[..., i] # dynamic slice - x = permutation[..., i] # dynamic slice - y = permutation[j] # gather - permutation[..., i] = y # dynamic update slice - permutation[j] = x # scatter - */ - - SmallVector indices; - SmallVector sliceShape, batchDims; - for (int i = 0; i < numBatchDims; i++) { - indices.push_back(zeroConst); - sliceShape.push_back(pivotType.getShape()[i]); - batchDims.push_back(i); - } - indices.push_back(iterArg); - sliceShape.push_back(1); - SmallVector gatherSliceSizes(numBatchDims + 1, 1); - - auto pivotJ = stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), pivots0indexed, indices, sliceShape); - auto permutationX = stablehlo::DynamicSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), indices, sliceShape); - - auto gatherDims = stablehlo::GatherDimensionNumbersAttr::get( - op.getContext(), - /*offsetDims=*/{numBatchDims}, - /*collapsedSliceDims=*/{}, - /*operandBatchingDims=*/batchDims, - /*startIndicesBatchingDims=*/batchDims, - /*startIndexMap=*/{numBatchDims}, - /*indexVectorDim=*/numBatchDims); - auto permutationY = stablehlo::GatherOp::create( - rewriter, op.getLoc(), - RankedTensorType::get( - sliceShape, - cast(block->getArgument(1).getType()) - .getElementType()), - block->getArgument(1), pivotJ.getResult(), gatherDims, - gatherSliceSizes); - - auto permutationUpdate1 = stablehlo::DynamicUpdateSliceOp::create( - rewriter, op.getLoc(), block->getArgument(1), - permutationY->getResult(0), indices); - - auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get( - op.getContext(), - /*updateWindowDims=*/{}, - /*insertedWindowDims=*/{numBatchDims}, - /*inputBatchingDims=*/batchDims, - /*scatterIndicesBatchingDims=*/batchDims, - /*scatterDimsToOperandDims=*/{numBatchDims}, - 
/*indexVectorDim=*/numBatchDims); - SmallVector scatterShape(sliceShape.begin(), - sliceShape.end() - 1); - auto permutationUpdate2 = stablehlo::ScatterOp::create( - rewriter, op.getLoc(), - TypeRange{permutationUpdate1->getResult(0).getType()}, - ValueRange(permutationUpdate1->getResult(0)), pivotJ, - ValueRange(stablehlo::ReshapeOp::create( - rewriter, op.getLoc(), - RankedTensorType::get(scatterShape, - permutationX.getType().getElementType()), - permutationX)), - scatterDims); - - { - OpBuilder::InsertionGuard guard(rewriter); - auto *block = - rewriter.createBlock(&permutationUpdate2.getUpdateComputation()); - block->addArgument(RankedTensorType::get({}, blasIntType), - op.getLoc()); - block->addArgument(RankedTensorType::get({}, blasIntType), - op.getLoc()); - rewriter.setInsertionPointToStart(block); - - stablehlo::ReturnOp::create(rewriter, op.getLoc(), - ValueRange{block->getArgument(1)}); - } - - stablehlo::ReturnOp::create( - rewriter, op.getLoc(), - ValueRange{updatedIter, permutationUpdate2->getResult(0)}); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), + ValueRange{block->getArgument(1)}); } - auto finalPermutation = stablehlo::AddOp::create( - rewriter, op.getLoc(), pivotToPermWhileOp.getResult(1), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), blasPivotType, - cast(makeAttr(blasPivotType, 1)))); + stablehlo::ReturnOp::create( + rewriter, op.getLoc(), + ValueRange{updatedIter, permutationUpdate2->getResult(0)}); + } - rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); - rewriter.replaceAllUsesWith( - op.getResult(1), stablehlo::ConvertOp::create( - rewriter, op.getLoc(), pivotType, pivotResult)); - rewriter.replaceAllUsesWith( - op.getResult(2), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - finalPermutation)); - rewriter.replaceAllUsesWith( - op.getResult(3), stablehlo::ConvertOp::create(rewriter, op.getLoc(), - infoType, infoResult)); + auto finalPermutation = stablehlo::AddOp::create( + rewriter, op.getLoc(), pivotToPermWhileOp.getResult(1), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), blasPivotType, + cast(makeAttr(blasPivotType, 1)))); + + rewriter.replaceAllUsesWith(op.getResult(0), factorizedResult); + rewriter.replaceAllUsesWith( + op.getResult(1), stablehlo::ConvertOp::create(rewriter, op.getLoc(), + pivotType, pivotResult)); + rewriter.replaceAllUsesWith( + op.getResult(2), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + finalPermutation)); + rewriter.replaceAllUsesWith( + op.getResult(3), stablehlo::ConvertOp::create(rewriter, op.getLoc(), + infoType, infoResult)); + + std::map + batchedFunctionCache; + for (auto [batchOp, func] : llvm::zip(batchOps, batchFunctions)) { + if (failed(enzyme::batchutils::batchOperation(rewriter, batchOp, func, + batchedFunctionCache))) { + return rewriter.notifyMatchFailure(op, "failed to batch operation"); + } + } - return success(); - } else if (backend == "cuda") { - SmallVector aliases = {stablehlo::OutputOperandAliasAttr::get( - ctx, std::vector{0}, 0, std::vector{})}; + return success(); + } - SmallVector isColMajorArrOperands = {true}; - SmallVector operandRanks = {inputRank}; - SmallVector isColMajorArrOutputs = {true, true, true}; - SmallVector outputRanks = {inputRank, pivotRank, infoRank}; + LogicalResult matchAndRewriteCUDA(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { + auto ctx = op->getContext(); - auto pivotCuSolverType = - RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); - auto infoCuSolverType 
= - RankedTensorType::get(infoType.getShape(), rewriter.getI32Type()); + auto input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + auto inputRank = inputType.getRank(); + auto numBatchDims = inputRank - 2; - auto cusolverffi = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, - ValueRange{input}, rewriter.getStringAttr("cusolver_getrf_ffi"), + auto pivotType = cast(op.getResult(1).getType()); + auto pivotRank = pivotType.getRank(); + auto infoType = cast(op.getResult(3).getType()); + auto infoRank = infoType.getRank(); + + SmallVector aliases = {stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{0}, 0, std::vector{})}; + + SmallVector isColMajorArrOperands = {true}; + SmallVector operandRanks = {inputRank}; + SmallVector isColMajorArrOutputs = {true, true, true}; + SmallVector outputRanks = {inputRank, pivotRank, infoRank}; + + auto pivotCuSolverType = + RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); + auto infoCuSolverType = + RankedTensorType::get(infoType.getShape(), rewriter.getI32Type()); + + auto cusolverffi = stablehlo::CustomCallOp::create( + rewriter, op.getLoc(), + TypeRange{inputType, pivotCuSolverType, infoCuSolverType}, + ValueRange{input}, rewriter.getStringAttr("cusolver_getrf_ffi"), + /*has_side_effect*/ nullptr, + /*backend_config*/ nullptr, + /*api_version*/ + stablehlo::CustomCallApiVersionAttr::get( + rewriter.getContext(), + mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), + /*calledcomputations*/ nullptr, + /*operand_layouts*/ + getSHLOLayout(rewriter, operandRanks, isColMajorArrOperands, inputRank), + /*result_layouts*/ + getSHLOLayout(rewriter, outputRanks, isColMajorArrOutputs, inputRank), + /*output_operand_aliases*/ rewriter.getArrayAttr(aliases)); + + // unused custom call not getting optimized away. so adding a manual check + if (!op.getResult(2).getUses().empty()) { + auto pivots0indexed = stablehlo::SubtractOp::create( + rewriter, op.getLoc(), cusolverffi.getResult(1), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), pivotCuSolverType, + cast(makeAttr(pivotCuSolverType, 1)))); + + SmallVector outputRanksPermutation = {pivotRank}; + + auto permutation = stablehlo::CustomCallOp::create( + rewriter, op.getLoc(), TypeRange{pivotCuSolverType}, + ValueRange{pivots0indexed.getResult()}, + rewriter.getStringAttr("cu_lu_pivots_to_permutation"), /*has_side_effect*/ nullptr, /*backend_config*/ nullptr, /*api_version*/ @@ -531,169 +530,145 @@ struct LUFactorizationOpLowering mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), /*calledcomputations*/ nullptr, /*operand_layouts*/ - getSHLOLayout(rewriter, operandRanks, isColMajorArrOperands, - inputRank), + getSHLOLayout(rewriter, SmallVector{pivotRank}, + SmallVector{true}, inputRank), /*result_layouts*/ - getSHLOLayout(rewriter, outputRanks, isColMajorArrOutputs, inputRank), - /*output_operand_aliases*/ rewriter.getArrayAttr(aliases)); - - // unused custom call not getting optimized away. 
so adding a manual check - if (!op.getResult(2).getUses().empty()) { - auto pivots0indexed = stablehlo::SubtractOp::create( - rewriter, op.getLoc(), cusolverffi.getResult(1), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotCuSolverType, - cast(makeAttr(pivotCuSolverType, 1)))); - - SmallVector isColMajorArrOperandsPermutation = {true}; - SmallVector operandRanksPermutation = {pivotRank}; - SmallVector isColMajorArrOutputsPermutation = {true}; - SmallVector outputRanksPermutation = {pivotRank}; - - auto permutation = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), TypeRange{pivotCuSolverType}, - ValueRange{pivots0indexed.getResult()}, - rewriter.getStringAttr("cu_lu_pivots_to_permutation"), - /*has_side_effect*/ nullptr, - /*backend_config*/ nullptr, - /*api_version*/ - stablehlo::CustomCallApiVersionAttr::get( - rewriter.getContext(), - mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI), - /*calledcomputations*/ nullptr, - /*operand_layouts*/ - getSHLOLayout(rewriter, operandRanksPermutation, - isColMajorArrOperandsPermutation, inputRank), - /*result_layouts*/ - getSHLOLayout(rewriter, outputRanksPermutation, - isColMajorArrOutputsPermutation, inputRank), - /*output_operand_aliases*/ nullptr); - auto permutation1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotCuSolverType, - cast(makeAttr(pivotCuSolverType, 1))), - permutation.getResult(0)); - - rewriter.replaceAllUsesWith( - op.getResult(2), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - permutation1Indexed)); - } + getSHLOLayout(rewriter, SmallVector{pivotRank}, + SmallVector{true}, inputRank), + /*output_operand_aliases*/ nullptr); + auto permutation1Indexed = stablehlo::AddOp::create( + rewriter, op.getLoc(), + stablehlo::ConstantOp::create( + rewriter, op.getLoc(), pivotCuSolverType, + cast(makeAttr(pivotCuSolverType, 1))), + permutation.getResult(0)); - rewriter.replaceAllUsesWith(op.getResult(0), cusolverffi.getResult(0)); rewriter.replaceAllUsesWith( - op.getResult(1), + op.getResult(2), stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - cusolverffi.getResult(1))); - rewriter.replaceAllUsesWith( - op.getResult(3), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), infoType, - cusolverffi.getResult(2))); + permutation1Indexed)); + } - return success(); - } else if (backend == "tpu") { - SmallVector permutationShape; - for (int i = 0; i < numBatchDims; i++) { - permutationShape.push_back(inputShape[i]); - } - permutationShape.push_back(m); - auto permutationType = - RankedTensorType::get(permutationShape, rewriter.getI32Type()); + rewriter.replaceAllUsesWith(op.getResult(0), cusolverffi.getResult(0)); + rewriter.replaceAllUsesWith( + op.getResult(1), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, + cusolverffi.getResult(1))); + rewriter.replaceAllUsesWith( + op.getResult(3), + stablehlo::ConvertOp::create(rewriter, op.getLoc(), infoType, + cusolverffi.getResult(2))); - auto pivotTPUType = - RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); + return success(); + } - // TPU returns (LU, pivots, permutation). info isn't returned. based on - // how JAX operates, I am assuming info != 0 when there is a nan in the - // output. 
- auto customCall = stablehlo::CustomCallOp::create( - rewriter, op.getLoc(), - TypeRange{inputType, pivotTPUType, permutationType}, - ValueRange{input}, rewriter.getStringAttr("LuDecomposition"), - /*has_side_effect*/ nullptr, - /*backend_config*/ nullptr, - /*api_version*/ nullptr, - /*calledcomputations*/ nullptr, - /*operand_layouts*/ nullptr, - /*result_layouts*/ nullptr, - /*output_operand_aliases*/ nullptr); + LogicalResult matchAndRewriteTPU(enzymexla::LUFactorizationOp op, + PatternRewriter &rewriter) const { + auto input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); - // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots. We - // make it consistent with LAPACK by adding 1 to the pivots. - auto pivots1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), pivotType, - cast(makeAttr(pivotType, 1))), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType, - customCall.getResult(1))); + auto inputRank = inputType.getRank(); + auto numBatchDims = inputRank - 2; - auto permutation1Indexed = stablehlo::AddOp::create( - rewriter, op.getLoc(), - stablehlo::ConstantOp::create( - rewriter, op.getLoc(), permutationType, - cast(makeAttr(permutationType, 1))), - stablehlo::ConvertOp::create(rewriter, op.getLoc(), permutationType, - customCall.getResult(2))); + auto pivotType = cast(op.getResult(1).getType()); + auto infoType = cast(op.getResult(3).getType()); - auto isFinite = stablehlo::AndOp::create( - rewriter, op.getLoc(), - stablehlo::IsFiniteOp::create( - rewriter, op.getLoc(), - stablehlo::RealOp::create(rewriter, op.getLoc(), - customCall.getResult(0))), - stablehlo::IsFiniteOp::create( - rewriter, op.getLoc(), - stablehlo::ImagOp::create(rewriter, op.getLoc(), - customCall.getResult(0)))); - - SmallVector reductionDims; - for (int i = numBatchDims; i < inputRank; i++) - reductionDims.push_back(i); - auto initValType = RankedTensorType::get({}, rewriter.getI1Type()); - auto initVal = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), initValType, - cast(makeAttr(initValType, 1))); - - auto allFinite = stablehlo::ReduceOp::create( - rewriter, op.getLoc(), - RankedTensorType::get(infoType.getShape(), rewriter.getI1Type()), - ValueRange{isFinite.getResult()}, ValueRange{initVal}, - rewriter.getDenseI64ArrayAttr(reductionDims)); + SmallVector permutationShape(inputShape.begin(), + inputShape.end() - 2); + permutationShape.push_back(inputShape[inputRank - 2]); + auto permutationType = + RankedTensorType::get(permutationShape, rewriter.getI32Type()); - { - OpBuilder::InsertionGuard guard(rewriter); - auto ®ion = allFinite.getBody(); - auto *block = - rewriter.createBlock(®ion, {}, {initValType, initValType}, - {op.getLoc(), op.getLoc()}); + auto pivotTPUType = + RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type()); - rewriter.setInsertionPointToStart(block); - auto lhs = block->getArgument(0); - auto rhs = block->getArgument(1); - auto andOp = stablehlo::AndOp::create(rewriter, op.getLoc(), lhs, rhs); + // TPU returns (LU, pivots, permutation). info isn't returned. based on + // how JAX operates, I am assuming info != 0 when there is a nan in the + // output. 
+    auto customCall = stablehlo::CustomCallOp::create(
+        rewriter, op.getLoc(),
+        TypeRange{inputType, pivotTPUType, permutationType}, ValueRange{input},
+        rewriter.getStringAttr("LuDecomposition"),
+        /*has_side_effect*/ nullptr,
+        /*backend_config*/ nullptr,
+        /*api_version*/ nullptr,
+        /*calledcomputations*/ nullptr,
+        /*operand_layouts*/ nullptr,
+        /*result_layouts*/ nullptr,
+        /*output_operand_aliases*/ nullptr);
-      stablehlo::ReturnOp::create(rewriter, op.getLoc(),
-                                  ValueRange{andOp.getResult()});
-    }
+    // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots. We
+    // make it consistent with LAPACK by adding 1 to the pivots.
+    auto pivots1Indexed = stablehlo::AddOp::create(
+        rewriter, op.getLoc(),
+        stablehlo::ConstantOp::create(
+            rewriter, op.getLoc(), pivotType,
+            cast(makeAttr(pivotType, 1))),
+        stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType,
+                                     customCall.getResult(1)));
-    // info == 0 if all finite (success)
-    auto info = stablehlo::ConvertOp::create(
-        rewriter, op.getLoc(), infoType,
-        stablehlo::NotOp::create(rewriter, op.getLoc(),
-                                 allFinite.getResult(0)));
+    auto permutation1Indexed = stablehlo::AddOp::create(
+        rewriter, op.getLoc(),
+        stablehlo::ConstantOp::create(
+            rewriter, op.getLoc(), permutationType,
+            cast(makeAttr(permutationType, 1))),
+        stablehlo::ConvertOp::create(rewriter, op.getLoc(), permutationType,
+                                     customCall.getResult(2)));
-    rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0));
-    rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed);
-    rewriter.replaceAllUsesWith(op.getResult(2), permutation1Indexed);
-    rewriter.replaceAllUsesWith(op.getResult(3), info);
-    rewriter.eraseOp(op);
+    auto isFinite = stablehlo::AndOp::create(
+        rewriter, op.getLoc(),
+        stablehlo::IsFiniteOp::create(
+            rewriter, op.getLoc(),
+            stablehlo::RealOp::create(rewriter, op.getLoc(),
+                                      customCall.getResult(0))),
+        stablehlo::IsFiniteOp::create(
+            rewriter, op.getLoc(),
+            stablehlo::ImagOp::create(rewriter, op.getLoc(),
+                                      customCall.getResult(0))));
+
+    SmallVector reductionDims;
+    for (int i = numBatchDims; i < inputRank; i++)
+      reductionDims.push_back(i);
+    auto initValType = RankedTensorType::get({}, rewriter.getI1Type());
+    auto initVal = stablehlo::ConstantOp::create(
+        rewriter, op.getLoc(), initValType,
+        cast(makeAttr(initValType, 1)));
+
+    auto allFinite = stablehlo::ReduceOp::create(
+        rewriter, op.getLoc(),
+        RankedTensorType::get(infoType.getShape(), rewriter.getI1Type()),
+        ValueRange{isFinite.getResult()}, ValueRange{initVal},
+        rewriter.getDenseI64ArrayAttr(reductionDims));
-      return success();
-    } else {
-      return rewriter.notifyMatchFailure(op, "Unknown backend " + backend);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto &region = allFinite.getBody();
+      auto *block = rewriter.createBlock(
+          &region, {}, {initValType, initValType}, {op.getLoc(), op.getLoc()});
+
+      rewriter.setInsertionPointToStart(block);
+      auto lhs = block->getArgument(0);
+      auto rhs = block->getArgument(1);
+      auto andOp = stablehlo::AndOp::create(rewriter, op.getLoc(), lhs, rhs);
+
+      stablehlo::ReturnOp::create(rewriter, op.getLoc(),
+                                  ValueRange{andOp.getResult()});
     }
+
+    // info == 0 if all finite (success)
+    auto info = stablehlo::ConvertOp::create(
+        rewriter, op.getLoc(), infoType,
+        stablehlo::NotOp::create(rewriter, op.getLoc(),
+                                 allFinite.getResult(0)));
+
+    rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0));
+    rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed);
+    rewriter.replaceAllUsesWith(op.getResult(2), permutation1Indexed);
+    rewriter.replaceAllUsesWith(op.getResult(3), info);
+    rewriter.eraseOp(op);
+
+    return success();
   }
 };
@@ -710,14 +685,11 @@ struct SVDFactorizationOpLowering
   LogicalResult matchAndRewrite(enzymexla::SVDFactorizationOp op,
                                 PatternRewriter &rewriter) const override {
     if (backend == "cpu")
-      return this->matchAndRewrite_cpu(op, rewriter);
-
+      return matchAndRewriteCPU(op, rewriter);
     else if (backend == "cuda")
-      return this->matchAndRewrite_cuda(op, rewriter);
-
+      return matchAndRewriteCUDA(op, rewriter);
     else if (backend == "tpu")
-      return this->matchAndRewrite_tpu(op, rewriter);
-
+      return matchAndRewriteTPU(op, rewriter);
     else
       return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
                                                  "\"");
@@ -726,8 +698,8 @@ struct SVDFactorizationOpLowering
   // TODO get matrix sizes dynamically so that we don't need to create a
   // function wrapper for each op instance
   // TODO support more SVD algorithms (e.g. `gesdd`, `gesvj`)
-  LogicalResult matchAndRewrite_cpu(enzymexla::SVDFactorizationOp op,
-                                    PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewriteCPU(enzymexla::SVDFactorizationOp op,
+                                   PatternRewriter &rewriter) const {
     auto ctx = op->getContext();
 
     LLVMTypeConverter typeConverter(ctx);
@@ -755,7 +727,7 @@ struct SVDFactorizationOpLowering
 
     // TODO change SVD method with attributes
     std::string fn = "gesvd_";
-    if (auto prefix = lapack_precision_prefix(inputElementType)) {
+    if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
       fn = *prefix + fn;
     } else {
       op->emitOpError() << "Unsupported complex element type: "
@@ -940,10 +912,9 @@ struct SVDFactorizationOpLowering
     return success();
   }
 
-  LogicalResult matchAndRewrite_cuda(enzymexla::SVDFactorizationOp op,
-                                     PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewriteCUDA(enzymexla::SVDFactorizationOp op,
+                                    PatternRewriter &rewriter) const {
    auto ctx = op->getContext();
-    LLVMTypeConverter typeConverter(ctx);
 
    auto type_lapack_int = rewriter.getIntegerType(blasIntWidth);
@@ -1019,15 +990,9 @@ struct SVDFactorizationOpLowering
     return success();
   }
 
-  // TODO find registered TPU kernel
-  LogicalResult matchAndRewrite_tpu(enzymexla::SVDFactorizationOp op,
-                                    PatternRewriter &rewriter) const {
-
-    return rewriter.notifyMatchFailure(
-        op, "We don't know yet to which SVD TPU kernel to lower to :_(");
-
+  LogicalResult matchAndRewriteTPU(enzymexla::SVDFactorizationOp op,
+                                   PatternRewriter &rewriter) const {
    auto ctx = op->getContext();
-    LLVMTypeConverter typeConverter(ctx);
 
    auto input = op.getOperand();
    auto type_input = cast(input.getType());
@@ -1043,7 +1008,7 @@ struct SVDFactorizationOpLowering
 
    auto custom_call_op = stablehlo::CustomCallOp::create(
        rewriter, op.getLoc(), TypeRange{type_input, type_tau},
-        ValueRange{input}, rewriter.getStringAttr("Svd"),
+        ValueRange{input}, rewriter.getStringAttr("SVD"),
        /*has_side_effect*/ nullptr,
        /*backend_config*/ nullptr,
        /*api_version*/ nullptr,
@@ -1076,8 +1041,8 @@ struct LowerEnzymeXLALinalgPass
    auto context = getOperation()->getContext();
 
    RewritePatternSet patterns(context);
-    patterns.add(backend, blasIntWidth, context);
-    patterns.add(backend, blasIntWidth, context);
+    patterns.add(
+        backend, blasIntWidth, context);
 
    GreedyRewriteConfig config;
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 0ce38ff15..42622da99 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -416,6 +416,7 @@ def LowerEnzymeXLALinalgPass : Pass<"lower-enzymexla-linalg"> {
     "stablehlo::StablehloDialect",
     "enzymexla::EnzymeXLADialect",
     "LLVM::LLVMDialect",
+    "enzyme::EnzymeDialect",
   ];
 
   let options = [
diff --git a/test/lit_tests/linalg/lu.mlir b/test/lit_tests/linalg/lu.mlir
index 9ce775d76..821abc11d 100644
--- a/test/lit_tests/linalg/lu.mlir
+++ b/test/lit_tests/linalg/lu.mlir
@@ -9,37 +9,36 @@ module {
   }
 }
 
-// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
-// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID:[0-9]+]](%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
-// CPU-NEXT: %0 = llvm.mlir.constant(64 : i64) : i64
-// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64
-// CPU-NEXT: %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
-// CPU-NEXT: %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
-// CPU-NEXT: llvm.store %0, %2 : i64, !llvm.ptr
-// CPU-NEXT: llvm.store %0, %3 : i64, !llvm.ptr
-// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%2, %3, %arg0, %2, %arg1, %arg2) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+// CPU: func.func private @enzymexla_lapack_sgetrf_[[WRAPPER_ID:[0-9]+]](%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor) {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor
+// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor
+// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor)
+// CPU-NEXT: return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor
+// CPU-NEXT: }
+// CPU-NEXT: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) {
+// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
 // CPU-NEXT: llvm.return
 // CPU-NEXT: }
+// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
 // CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor) {
 // CPU-NEXT: %c = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
 // CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor
 // CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor
 // CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<64xi64>
 // CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor
-// CPU-NEXT: %c_4 = stablehlo.constant dense<-1> : tensor<64xi64>
-// CPU-NEXT: %c_5 = stablehlo.constant dense<-1> : tensor
-// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper_[[WRAPPER_ID]] (%arg0, %c_4, %c_5) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor)
+// CPU-NEXT: %0:3 = call @enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0) : (tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor)
 // CPU-NEXT: %1 = stablehlo.subtract %0#1, %c_2 : tensor<64xi64>
-// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_6 = %c) : tensor, tensor<64xi64>
-// CPU-NEXT: cond {
+// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_4 = %c) : tensor, tensor<64xi64>
+// CPU-NEXT: cond {
 // CPU-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor
 // CPU-NEXT: stablehlo.return %7 : tensor
 // CPU-NEXT: } do {
 // CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 : tensor
 // CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64>
-// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_6, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64>
-// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_6, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
-// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_6, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor) -> tensor<64xi64>
+// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %iterArg, sizes = [1] : (tensor<64xi64>, tensor) -> tensor<1xi64>
+// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
+// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor) -> tensor<64xi64>
 // CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor
 // CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({
 // CPU-NEXT: ^bb0(%arg1: tensor, %arg2: tensor):
@@ -77,27 +76,27 @@ module {
 // TPU-NEXT: }
 
 module {
+  // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_
   // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) {
   func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor) {
-    // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
     %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor)
     return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor
   }
 }
 
 module {
+  // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_
   // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) {
   func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) {
-    // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
     %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor)
     return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor
   }
 }
 
 module {
+  // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_
   // CPU: func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) {
   func.func @main(%arg0: tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor) {
-    // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_wrapper_[[WRAPPER_ID:[0-9]+]]
     %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex>) -> (tensor<64x64xcomplex>, tensor<64xi32>, tensor<64xi32>, tensor)
     return %0#0, %0#1, %0#3 : tensor<64x64xcomplex>, tensor<64xi32>, tensor
   }
 }
diff --git a/test/lit_tests/linalg/lu_batched.mlir b/test/lit_tests/linalg/lu_batched.mlir
index 1af0820a8..39e56fa06 100644
--- a/test/lit_tests/linalg/lu_batched.mlir
+++ b/test/lit_tests/linalg/lu_batched.mlir
@@ -9,71 +9,72 @@ module {
   }
 }
 
-// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
-// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_wrapper_0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
-// CPU-NEXT: %0 = llvm.mlir.constant(64 : i64) : i64
-// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64
-// CPU-NEXT: %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
-// CPU-NEXT: %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
-// CPU-NEXT: llvm.store %0, %2 : i64, !llvm.ptr
-// CPU-NEXT: llvm.store %0, %3 : i64, !llvm.ptr
-// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%2, %3, %arg0, %2, %arg1, %arg2) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+// CPU: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) {
+// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
 // CPU-NEXT: llvm.return
 // CPU-NEXT: }
+// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
 // CPU-NEXT: func.func @main(%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>) {
-// CPU: %c_0 = stablehlo.constant dense<64> : tensor
-// CPU-NEXT: %c_1 = stablehlo.constant dense<1> : tensor<4x3x64xi64>
-// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor
-// CPU-NEXT: %c_3 = stablehlo.constant dense<-1> : tensor
-// CPU-NEXT: %c_4 = stablehlo.constant dense<-1> : tensor<64xi64>
-// CPU-NEXT: %c_5 = stablehlo.constant dense<12> : tensor
-// CPU-NEXT: %c_6 = stablehlo.constant dense<-1> : tensor<12xi64>
-// CPU-NEXT: %c_7 = stablehlo.constant dense<-1> : tensor<12x64xi64>
-// CPU-NEXT: %c_8 = stablehlo.constant dense<0> : tensor
-// CPU-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<4x3x64x64xf32>) -> tensor<12x64x64xf32>
-// CPU-NEXT: %1:4 = stablehlo.while(%iterArg = %c_8, %iterArg_9 = %0, %iterArg_10 = %c_7, %iterArg_11 = %c_6) : tensor, tensor<12x64x64xf32>, tensor<12x64xi64>, tensor<12xi64>
-// CPU-NEXT: cond {
-// CPU-NEXT: %11 = stablehlo.compare LT, %iterArg, %c_5 : (tensor, tensor) -> tensor
-// CPU-NEXT: stablehlo.return %11 : tensor
-// CPU-NEXT: } do {
-// CPU-NEXT: %11 = stablehlo.dynamic_slice %iterArg_9, %iterArg, %c_8, %c_8, sizes = [1, 64, 64] : (tensor<12x64x64xf32>, tensor, tensor, tensor) -> tensor<1x64x64xf32>
-// CPU-NEXT: %12 = stablehlo.reshape %11 : (tensor<1x64x64xf32>) -> tensor<64x64xf32>
-// CPU-NEXT: %13:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper_0 (%12, %c_4, %c_3) {operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<64x64xf32>, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor)
-// CPU-NEXT: %14 = stablehlo.reshape %13#0 : (tensor<64x64xf32>) -> tensor<1x64x64xf32>
-// CPU-NEXT: %15 = stablehlo.dynamic_update_slice %iterArg_9, %14, %iterArg, %c_8, %c_8 : (tensor<12x64x64xf32>, tensor<1x64x64xf32>, tensor, tensor, tensor) -> tensor<12x64x64xf32>
-// CPU-NEXT: %16 = stablehlo.reshape %13#1 : (tensor<64xi64>) -> tensor<1x64xi64>
-// CPU-NEXT: %17 = stablehlo.dynamic_update_slice %iterArg_10, %16, %iterArg, %c_8 : (tensor<12x64xi64>, tensor<1x64xi64>, tensor, tensor) -> tensor<12x64xi64>
-// CPU-NEXT: %18 = stablehlo.reshape %13#2 : (tensor) -> tensor<1xi64>
-// CPU-NEXT: %19 = stablehlo.dynamic_update_slice %iterArg_11, %18, %iterArg : (tensor<12xi64>, tensor<1xi64>, tensor) -> tensor<12xi64>
-// CPU-NEXT: %20 = stablehlo.add %iterArg, %c_2 : tensor
-// CPU-NEXT: stablehlo.return %20, %15, %17, %19 : tensor, tensor<12x64x64xf32>, tensor<12x64xi64>, tensor<12xi64>
-// CPU-NEXT: }
-// CPU-NEXT: %2 = stablehlo.reshape %1#1 : (tensor<12x64x64xf32>) -> tensor<4x3x64x64xf32>
-// CPU-NEXT: %3 = stablehlo.reshape %1#2 : (tensor<12x64xi64>) -> tensor<4x3x64xi64>
-// CPU-NEXT: %4 = stablehlo.convert %1#3 : (tensor<12xi64>) -> tensor<12xi32>
-// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<12xi32>) -> tensor<4x3xi32>
-// CPU-NEXT: %6 = stablehlo.subtract %3, %c_1 : tensor<4x3x64xi64>
-// CPU-NEXT: %7:2 = stablehlo.while(%iterArg = %c_8, %iterArg_9 = %c) : tensor, tensor<4x3x64xi64>
-// CPU-NEXT: cond {
-// CPU-NEXT: %11 = stablehlo.compare LT, %iterArg, %c_0 : (tensor, tensor) -> tensor
-// CPU-NEXT: stablehlo.return %11 : tensor
+// CPU: %c_0 = stablehlo.constant dense<1> : tensor
+// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor
+// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor<4x3x64xi64>
+// CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor
+// CPU-NEXT: %0:3 = call @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID:[0-9]+]](%arg0) : (tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>)
+// CPU-NEXT: %1 = stablehlo.subtract %0#1, %c_2 : tensor<4x3x64xi64>
+// CPU-NEXT: %2:2 = stablehlo.while(%iterArg = %c_3, %iterArg_4 = %c) : tensor, tensor<4x3x64xi64>
+// CPU-NEXT: cond {
+// CPU-NEXT: %7 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor
+// CPU-NEXT: stablehlo.return %7 : tensor
 // CPU-NEXT: } do {
-// CPU-NEXT: %11 = stablehlo.add %iterArg, %c_2 : tensor
-// CPU-NEXT: %12 = stablehlo.dynamic_slice %6, %c_8, %c_8, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64>
-// CPU-NEXT: %13 = stablehlo.dynamic_slice %iterArg_9, %c_8, %c_8, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64>
-// CPU-NEXT: %14 = "stablehlo.gather"(%iterArg_9, %12) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<4x3x64xi64>, tensor<4x3x1xi64>) -> tensor<4x3x1xi64>
-// CPU-NEXT: %15 = stablehlo.dynamic_update_slice %iterArg_9, %14, %c_8, %c_8, %iterArg : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64>
-// CPU-NEXT: %16 = stablehlo.reshape %13 : (tensor<4x3x1xi64>) -> tensor<4x3xi64>
-// CPU-NEXT: %17 = "stablehlo.scatter"(%15, %12, %16) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({
+// CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 : tensor
+// CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %c_3, %c_3, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64>
+// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %c_3, %c_3, %iterArg, sizes = [4, 3, 1] : (tensor<4x3x64xi64>, tensor, tensor, tensor) -> tensor<4x3x1xi64>
+// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<4x3x64xi64>, tensor<4x3x1xi64>) -> tensor<4x3x1xi64>
+// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %c_3, %c_3, %iterArg : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64>
+// CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<4x3x1xi64>) -> tensor<4x3xi64>
+// CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({
 // CPU-NEXT: ^bb0(%arg1: tensor, %arg2: tensor):
 // CPU-NEXT: stablehlo.return %arg2 : tensor
 // CPU-NEXT: }) : (tensor<4x3x64xi64>, tensor<4x3x1xi64>, tensor<4x3xi64>) -> tensor<4x3x64xi64>
-// CPU-NEXT: stablehlo.return %11, %17 : tensor, tensor<4x3x64xi64>
+// CPU-NEXT: stablehlo.return %7, %13 : tensor, tensor<4x3x64xi64>
+// CPU-NEXT: }
+// CPU-NEXT: %3 = stablehlo.add %2#1, %c_2 : tensor<4x3x64xi64>
+// CPU-NEXT: %4 = stablehlo.convert %0#1 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32>
+// CPU-NEXT: %5 = stablehlo.convert %3 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32>
+// CPU-NEXT: %6 = stablehlo.convert %0#2 : (tensor<4x3xi64>) -> tensor<4x3xi32>
+// CPU-NEXT: return %0#0, %4, %5, %6 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>
+// CPU-NEXT: }
+// CPU-NEXT: func.func private @batched_enzymexla_lapack_sgetrf_[[WRAPPER_ID]](%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>) {
+// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor
+// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
+// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor
+// CPU-NEXT: %c_2 = stablehlo.constant dense<4> : tensor
+// CPU-NEXT: %c_3 = stablehlo.constant dense<1> : tensor
+// CPU-NEXT: %c_4 = stablehlo.constant dense<12> : tensor
+// CPU-NEXT: %cst = arith.constant dense<0> : tensor<4x3xi64>
+// CPU-NEXT: %cst_5 = arith.constant dense<0> : tensor<4x3x64xi64>
+// CPU-NEXT: %cst_6 = arith.constant dense<0.000000e+00> : tensor<4x3x64x64xf32>
+// CPU-NEXT: %c_7 = stablehlo.constant dense<0> : tensor
+// CPU-NEXT: %0:4 = stablehlo.while(%iterArg = %c_7, %iterArg_8 = %cst_6, %iterArg_9 = %cst_5, %iterArg_10 = %cst) : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
+// CPU-NEXT: cond {
+// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_4 : (tensor, tensor) -> tensor
+// CPU-NEXT: stablehlo.return %1 : tensor
+// CPU-NEXT: } do {
+// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_3 : tensor
+// CPU-NEXT: %2 = stablehlo.remainder %iterArg, %c_2 : tensor
+// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor
+// CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<1x1x64x64xf32>
+// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
+// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias, #stablehlo.output_operand_alias, #stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor, tensor, tensor<64x64xf32>, tensor, tensor<64xi64>, tensor) -> (tensor<64x64xf32>, tensor<64xi64>, tensor)
+// CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
+// CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor, tensor, tensor, tensor) -> tensor<4x3x64x64xf32>
+// CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>
+// CPU-NEXT: %10 = stablehlo.dynamic_update_slice %iterArg_9, %9, %2, %3, %c_7 : (tensor<4x3x64xi64>, tensor<1x1x64xi64>, tensor, tensor, tensor) -> tensor<4x3x64xi64>
+// CPU-NEXT: %11 = stablehlo.reshape %6#2 : (tensor) -> tensor<1x1xi64>
+// CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_10, %11, %2, %3 : (tensor<4x3xi64>, tensor<1x1xi64>, tensor, tensor) -> tensor<4x3xi64>
+// CPU-NEXT: stablehlo.return %1, %8, %10, %12 : tensor, tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
 // CPU-NEXT: }
-// CPU-NEXT: %8 = stablehlo.add %7#1, %c_1 : tensor<4x3x64xi64>
-// CPU-NEXT: %9 = stablehlo.convert %3 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32>
-// CPU-NEXT: %10 = stablehlo.convert %8 : (tensor<4x3x64xi64>) -> tensor<4x3x64xi32>
-// CPU-NEXT: return %2, %9, %10, %5 : tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>
+// CPU-NEXT: return %0#1, %0#2, %0#3 : tensor<4x3x64x64xf32>, tensor<4x3x64xi64>, tensor<4x3xi64>
 // CPU-NEXT: }