From eafdb3d2052a88bfd4a2e348aad36acff5ea5a93 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 14 Aug 2025 13:45:48 +0000 Subject: [PATCH] [mlir][LLVM] FuncToLLVM: Add 1:N support --- .../Conversion/LLVMCommon/TypeConverter.h | 31 +++-- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 109 +++++++++++------ .../Conversion/LLVMCommon/TypeConverter.cpp | 115 ++++++++++-------- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 24 ++-- .../MemRefToLLVM/type-conversion.mlir | 97 +++++++++++++-- mlir/test/lib/Dialect/LLVM/TestPatterns.cpp | 30 +++++ 6 files changed, 281 insertions(+), 125 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index 38b5e492a8ed8..2096bcb9896a5 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -74,8 +74,14 @@ class LLVMTypeConverter : public TypeConverter { /// LLVM-compatible type. In particular, if more than one value is returned, /// create an LLVM dialect structure type with elements that correspond to /// each of the types converted with `convertCallingConventionType`. - Type packFunctionResults(TypeRange types, - bool useBarePointerCallConv = false) const; + /// + /// Populate the converted (unpacked) types into `groupedTypes`, if provided. + /// `groupedType` contains one nested vector per input type. In case of a 1:N + /// conversion, a nested vector may contain 0 or more then 1 converted type. + Type + packFunctionResults(TypeRange types, bool useBarePointerCallConv = false, + SmallVector> *groupedTypes = nullptr, + int64_t *numConvertedTypes = nullptr) const; /// Convert a non-empty list of types of values produced by an operation into /// an LLVM-compatible type. In particular, if more than one value is @@ -88,15 +94,9 @@ class LLVMTypeConverter : public TypeConverter { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. - Type convertCallingConventionType(Type type, - bool useBarePointerCallConv = false) const; - - /// Promote the bare pointers in 'values' that resulted from memrefs to - /// descriptors. 'stdTypes' holds the types of 'values' before the conversion - /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). - void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, - Location loc, ArrayRef stdTypes, - SmallVectorImpl &values) const; + LogicalResult + convertCallingConventionType(Type type, SmallVectorImpl &result, + bool useBarePointerCallConv = false) const; /// Returns the MLIR context. MLIRContext &getContext() const; @@ -109,9 +109,14 @@ class LLVMTypeConverter : public TypeConverter { /// Promote the LLVM representation of all operands including promoting MemRef /// descriptors to stack and use pointers to struct to avoid the complexity /// of the platform-specific C/C++ ABI lowering related to struct argument - /// passing. + /// passing. (The ArrayRef variant is for 1:N.) + SmallVector promoteOperands(Location loc, ValueRange opOperands, + ArrayRef adaptorOperands, + OpBuilder &builder, + bool useBarePtrCallConv = false) const; SmallVector promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, + ValueRange adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv = false) const; /// Promote the LLVM struct representation of one MemRef descriptor to stack diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index a4a6ae250640f..42c76ed475b4c 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering; using Base = ConvertOpToLLVMPattern; + using Adaptor = typename ConvertOpToLLVMPattern::OneToNOpAdaptor; - LogicalResult matchAndRewriteImpl(CallOpType callOp, - typename CallOpType::Adaptor adaptor, + LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor, ConversionPatternRewriter &rewriter, bool useBarePtrCallConv = false) const { // Pack the result types into a struct. Type packedResult = nullptr; + SmallVector> groupedResultTypes; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - + int64_t numConvertedTypes = 0; if (numResults != 0) { if (!(packedResult = this->getTypeConverter()->packFunctionResults( - resultTypes, useBarePtrCallConv))) + resultTypes, useBarePtrCallConv, &groupedResultTypes, + &numConvertedTypes))) return failure(); } @@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { static_cast(promoted.size()), 0}; newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); - SmallVector results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newOp.result_begin(), newOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(LLVM::ExtractValueOp::create( - rewriter, callOp.getLoc(), newOp->getResult(0), i)); + // Helper function that extracts an individual result from the return value + // of the new call op. llvm.call ops support only 0 or 1 result. In case of + // 2 or more results, the results are packed into a structure. + // + // The new call op may have more than 2 results because: + // a. The original call op has more than 2 results. + // b. An original op result type-converted to more than 1 result. + auto getUnpackedResult = [&](unsigned i) -> Value { + assert(numConvertedTypes > 0 && "convert op has no results"); + if (numConvertedTypes == 1) { + assert(i == 0 && "out of bounds: converted op has only one result"); + return newOp->getResult(0); } + // Results have been converted to a structure. Extract individual results + // from the structure. + return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(), + newOp->getResult(0), i); + }; + + // Group the results into a vector of vectors, such that it is clear which + // original op result is replaced with which range of values. (In case of a + // 1:N conversion, there can be multiple replacements for a single result.) + SmallVector> results; + results.reserve(numResults); + unsigned counter = 0; + for (unsigned i = 0; i < numResults; ++i) { + SmallVector &group = results.emplace_back(); + for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) + group.push_back(getUnpackedResult(counter++)); } - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, promote memref results to - // descriptors. - assert(results.size() == resultTypes.size() && - "The number of arguments and types doesn't match"); - this->getTypeConverter()->promoteBarePtrsToDescriptors( - rewriter, callOp.getLoc(), resultTypes, results); - } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), - resultTypes, results, - /*toDynamic=*/false))) { - return failure(); + // Special handling for MemRef types. + for (unsigned i = 0; i < numResults; ++i) { + Type origType = resultTypes[i]; + auto memrefType = dyn_cast(origType); + auto unrankedMemrefType = dyn_cast(origType); + if (useBarePtrCallConv && memrefType) { + // For the bare-ptr calling convention, promote memref results to + // descriptors. + assert(results[i].size() == 1 && "expected one converted result"); + results[i].front() = MemRefDescriptor::fromStaticShape( + rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType, + results[i].front()); + } + if (unrankedMemrefType) { + assert(!useBarePtrCallConv && "unranked memref is not supported in the " + "bare-ptr calling convention"); + assert(results[i].size() == 1 && "expected one converted result"); + Value desc = this->copyUnrankedDescriptor( + rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(), + /*toDynamic=*/false); + if (!desc) + return failure(); + results[i].front() = desc; + } } - rewriter.replaceOp(callOp, results); + rewriter.replaceOpWithMultiple(callOp, results); return success(); } }; @@ -606,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering { symbolTables(symbolTables) {} LogicalResult - matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool useBarePtrCallConv = false; if (getTypeConverter()->getOptions().useBarePtrCallConv) { @@ -636,7 +668,7 @@ struct CallIndirectOpLowering using Super::Super; LogicalResult - matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, + matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); } @@ -679,47 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - unsigned numArguments = op.getNumOperands(); SmallVector updatedOperands; auto funcOp = op->getParentOfType(); bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); - for (auto [oldOperand, newOperand] : + for (auto [oldOperand, newOperands] : llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { Type oldTy = oldOperand.getType(); if (auto memRefType = dyn_cast(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); if (useBarePtrCallConv && getTypeConverter()->canConvertToBarePtr(memRefType)) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. - MemRefDescriptor memrefDesc(newOperand); + MemRefDescriptor memrefDesc(newOperands.front()); updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc)); continue; } } else if (auto unrankedMemRefType = dyn_cast(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); if (useBarePtrCallConv) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); } - Value updatedDesc = copyUnrankedDescriptor( - rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true); + Value updatedDesc = + copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType, + newOperands.front(), /*toDynamic=*/true); if (!updatedDesc) return failure(); updatedOperands.push_back(updatedDesc); continue; } - updatedOperands.push_back(newOperand); + + llvm::append_range(updatedOperands, newOperands); } // If ReturnOp has 0 or 1 operand, create it and return immediately. - if (numArguments <= 1) { + if (updatedOperands.size() <= 1) { rewriter.replaceOpWithNewOp( op, TypeRange(), updatedOperands, op->getAttrs()); return success(); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 1a9bf569086da..cb9dea108cc48 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl( useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; + // Convert argument types one by one and check for errors. for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { SmallVector converted; @@ -658,27 +659,19 @@ FailureOr LLVMTypeConverter::convertVectorType(VectorType type) const { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. -Type LLVMTypeConverter::convertCallingConventionType( - Type type, bool useBarePtrCallConv) const { - if (useBarePtrCallConv) - if (auto memrefTy = dyn_cast(type)) - return convertMemRefToBarePtr(memrefTy); - - return convertType(type); -} +LogicalResult LLVMTypeConverter::convertCallingConventionType( + Type type, SmallVectorImpl &result, bool useBarePtrCallConv) const { + if (useBarePtrCallConv) { + if (auto memrefTy = dyn_cast(type)) { + Type converted = convertMemRefToBarePtr(memrefTy); + if (!converted) + return failure(); + result.push_back(converted); + return success(); + } + } -/// Promote the bare pointers in 'values' that resulted from memrefs to -/// descriptors. 'stdTypes' holds they types of 'values' before the conversion -/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). -void LLVMTypeConverter::promoteBarePtrsToDescriptors( - ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, - SmallVectorImpl &values) const { - assert(stdTypes.size() == values.size() && - "The number of types and values doesn't match"); - for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = dyn_cast(stdTypes[i])) - values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, - memrefTy, values[i]); + return convertType(type, result); } /// Convert a non-empty list of types of values produced by an operation into an @@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const { /// LLVM-compatible type. In particular, if more than one value is returned, /// create an LLVM dialect structure type with elements that correspond to each /// of the types converted with `convertCallingConventionType`. -Type LLVMTypeConverter::packFunctionResults(TypeRange types, - bool useBarePtrCallConv) const { +Type LLVMTypeConverter::packFunctionResults( + TypeRange types, bool useBarePtrCallConv, + SmallVector> *groupedTypes, + int64_t *numConvertedTypes) const { assert(!types.empty() && "expected non-empty list of type"); + assert((!groupedTypes || groupedTypes->empty()) && + "expected groupedTypes to be empty"); useBarePtrCallConv |= options.useBarePtrCallConv; - if (types.size() == 1) - return convertCallingConventionType(types.front(), useBarePtrCallConv); - SmallVector resultTypes; resultTypes.reserve(types.size()); + size_t sizeBefore = 0; for (auto t : types) { - auto converted = convertCallingConventionType(t, useBarePtrCallConv); - if (!converted || !LLVM::isCompatibleType(converted)) + if (failed( + convertCallingConventionType(t, resultTypes, useBarePtrCallConv))) return {}; - resultTypes.push_back(converted); + if (groupedTypes) { + SmallVector &group = groupedTypes->emplace_back(); + llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore)); + } + sizeBefore = resultTypes.size(); } + if (numConvertedTypes) + *numConvertedTypes = resultTypes.size(); + if (resultTypes.size() == 1) + return resultTypes.front(); + if (resultTypes.empty()) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); } @@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, return allocated; } -SmallVector -LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, - bool useBarePtrCallConv) const { +SmallVector LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ValueRange adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { + SmallVector ranges; + for (size_t i = 0, e = adaptorOperands.size(); i < e; i++) + ranges.push_back(adaptorOperands.slice(i, 1)); + return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv); +} + +SmallVector LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ArrayRef adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { SmallVector promotedOperands; - promotedOperands.reserve(operands.size()); + promotedOperands.reserve(adaptorOperands.size()); useBarePtrCallConv |= options.useBarePtrCallConv; - for (auto it : llvm::zip(opOperands, operands)) { - auto operand = std::get<0>(it); - auto llvmOperand = std::get<1>(it); - + for (auto [operand, llvmOperand] : + llvm::zip_equal(opOperands, adaptorOperands)) { if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. if (isa(operand.getType())) { - MemRefDescriptor desc(llvmOperand); - llvmOperand = desc.alignedPtr(builder, loc); + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor desc(llvmOperand.front()); + promotedOperands.push_back(desc.alignedPtr(builder, loc)); + continue; } else if (isa(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (isa(operand.getType())) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(), promotedOperands); continue; } if (auto memrefType = dyn_cast(operand.getType())) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType, promotedOperands); continue; } } - promotedOperands.push_back(llvmOperand); + llvm::append_range(promotedOperands, llvmOperand); } return promotedOperands; } @@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, result.append(converted.begin(), converted.end()); return success(); } - auto converted = converter.convertType(type); - if (!converted) - return failure(); - result.push_back(converted); - return success(); + return converter.convertType(type, result); } /// Callback to convert function argument types. It converts MemRef function @@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, LogicalResult mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { - auto llvmTy = converter.convertCallingConventionType( - type, /*useBarePointerCallConv=*/true); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); + return converter.convertCallingConventionType( + type, result, + /*useBarePointerCallConv=*/true); } diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index f7f5381799529..c6c5ab356f256 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1106,12 +1106,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LDBG() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + LDBG() << "Generating warpgroup.descriptor: " << "leading_off:" + << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle + << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); @@ -1401,14 +1399,12 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LDBG() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" - << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM - << "][" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN - << "])"; + LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK + << "(A[" << (iterationM * wgmmaM) << ":" + << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK) + << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B[" + << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK) + << "][" << 0 << ":" << wgmmaN << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir index 0288aa11313c7..c1751f282b002 100644 --- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir @@ -1,12 +1,13 @@ -// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file +// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s // Test the argument materializer for ranked MemRef types. // CHECK-LABEL: func @construct_ranked_memref_descriptor( -// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-COUNT-7: llvm.insertvalue // CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32> -func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) { +func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) attributes {is_legal} { %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>) "test.legal_op"(%0) : (memref<5x4xf32>) -> () return @@ -21,7 +22,7 @@ func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr // CHECK-LABEL: func @invalid_ranked_memref_descriptor( // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32> // CHECK: "test.legal_op"(%[[cast]]) -func.func @invalid_ranked_memref_descriptor(%arg0: i1) { +func.func @invalid_ranked_memref_descriptor(%arg0: i1) attributes {is_legal} { %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>) "test.legal_op"(%0) : (memref<5x4xf32>) -> () return @@ -32,10 +33,10 @@ func.func @invalid_ranked_memref_descriptor(%arg0: i1) { // Test the argument materializer for unranked MemRef types. // CHECK-LABEL: func @construct_unranked_memref_descriptor( -// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)> +// CHECK: llvm.mlir.poison : !llvm.struct<(i64, ptr)> // CHECK-COUNT-2: llvm.insertvalue // CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32> -func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) { +func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) attributes {is_legal} { %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>) "test.legal_op"(%0) : (memref<*xf32>) -> () return @@ -50,8 +51,90 @@ func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) { // CHECK-LABEL: func @invalid_unranked_memref_descriptor( // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32> // CHECK: "test.legal_op"(%[[cast]]) -func.func @invalid_unranked_memref_descriptor(%arg0: i1) { +func.func @invalid_unranked_memref_descriptor(%arg0: i1) attributes {is_legal} { %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>) "test.legal_op"(%0) : (memref<*xf32>) -> () return } + +// ----- + +// CHECK-LABEL: llvm.func @simple_func_conversion( +// CHECK-SAME: %[[arg0:.*]]: i64) -> i64 +// CHECK: llvm.return %[[arg0]] : i64 +func.func @simple_func_conversion(%arg0: i64) -> i64 { + return %arg0 : i64 +} + +// ----- + +// CHECK-LABEL: llvm.func @one_to_n_argument_conversion( +// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18) +// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]] : i18, i18 to i17 +// CHECK: "test.legal_op"(%[[cast]]) : (i17) -> () +func.func @one_to_n_argument_conversion(%arg0: i17) { + "test.legal_op"(%arg0) : (i17) -> () + return +} + +// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18) +// CHECK: llvm.call @one_to_n_argument_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> () +func.func @caller(%arg0: i17) { + func.call @one_to_n_argument_conversion(%arg0) : (i17) -> () + return +} + +// ----- + +// CHECK-LABEL: llvm.func @one_to_n_return_conversion( +// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18) -> !llvm.struct<(i18, i18)> +// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)> +// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18)> +// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18)> +// CHECK: llvm.return %[[p3]] +func.func @one_to_n_return_conversion(%arg0: i17) -> i17 { + return %arg0 : i17 +} + +// CHECK: llvm.func @caller(%[[arg0:.*]]: i18, %[[arg1:.*]]: i18) +// CHECK: %[[res:.*]] = llvm.call @one_to_n_return_conversion(%[[arg0]], %[[arg1]]) : (i18, i18) -> !llvm.struct<(i18, i18)> +// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18)> +// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18)> +// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)> +// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0] : !llvm.struct<(i18, i18)> +// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1] : !llvm.struct<(i18, i18)> +// CHECK: llvm.return %[[i2]] +func.func @caller(%arg0: i17) -> (i17) { + %res = func.call @one_to_n_return_conversion(%arg0) : (i17) -> (i17) + return %res : i17 +} + +// ----- + +// CHECK-LABEL: llvm.func @multi_return( +// CHECK-SAME: %[[arg0:.*]]: i18, %[[arg1:.*]]: i18, %[[arg2:.*]]: i1) -> !llvm.struct<(i18, i18, i1)> +// CHECK: %[[p1:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[p2:.*]] = llvm.insertvalue %[[arg0]], %[[p1]][0] : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[p3:.*]] = llvm.insertvalue %[[arg1]], %[[p2]][1] : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[p4:.*]] = llvm.insertvalue %[[arg2]], %[[p3]][2] : !llvm.struct<(i18, i18, i1)> +// CHECK: llvm.return %[[p4]] +func.func @multi_return(%arg0: i17, %arg1: i1) -> (i17, i1) { + return %arg0, %arg1 : i17, i1 +} + +// CHECK: llvm.func @caller(%[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18) +// CHECK: %[[res:.*]] = llvm.call @multi_return(%[[arg1]], %[[arg2]], %[[arg0]]) : (i18, i18, i1) -> !llvm.struct<(i18, i18, i1)> +// CHECK: %[[e0:.*]] = llvm.extractvalue %[[res]][0] : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[e1:.*]] = llvm.extractvalue %[[res]][1] : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[e2:.*]] = llvm.extractvalue %[[res]][2] : !llvm.struct<(i18, i18, i1)> +// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18, i1, i18, i18)> +// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e0]], %[[i0]][0] +// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e1]], %[[i1]][1] +// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e2]], %[[i2]][2] +// CHECK: %[[i4:.*]] = llvm.insertvalue %[[e0]], %[[i3]][3] +// CHECK: %[[i5:.*]] = llvm.insertvalue %[[e1]], %[[i4]][4] +// CHECK: llvm.return %[[i5]] +func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) { + %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1) + return %res#0, %res#1, %res#0 : i17, i1, i17 +} diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp index ab02866970b1d..fe9aa0f2a9902 100644 --- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp @@ -6,7 +6,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Pass/Pass.h" @@ -34,6 +36,10 @@ struct TestLLVMLegalizePatternsPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass) + TestLLVMLegalizePatternsPass() = default; + TestLLVMLegalizePatternsPass(const TestLLVMLegalizePatternsPass &other) + : PassWrapper(other) {} + StringRef getArgument() const final { return "test-llvm-legalize-patterns"; } StringRef getDescription() const final { return "Run LLVM dialect legalization patterns"; @@ -45,22 +51,46 @@ struct TestLLVMLegalizePatternsPass void runOnOperation() override { MLIRContext *ctx = &getContext(); + + // Set up type converter. LLVMTypeConverter converter(ctx); + converter.addConversion( + [&](IntegerType type, SmallVectorImpl &result) { + if (type.isInteger(17)) { + // Convert i17 -> (i18, i18). + result.append(2, Builder(ctx).getIntegerType(18)); + return success(); + } + + result.push_back(type); + return success(); + }); + + // Populate patterns. mlir::RewritePatternSet patterns(ctx); patterns.add(ctx, converter); + populateFuncToLLVMConversionPatterns(converter, patterns); // Define the conversion target used for the test. ConversionTarget target(*ctx); target.addLegalOp(OperationName("test.legal_op", ctx)); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](func::FuncOp funcOp) { return funcOp->hasAttr("is_legal"); }); // Handle a partial conversion. DenseSet unlegalizedOps; ConversionConfig config; config.unlegalizedOps = &unlegalizedOps; + config.allowPatternRollback = allowPatternRollback; if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), config))) getOperation()->emitError() << "applyPartialConversion failed"; } + + Option allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace