diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 969154abe8830..79b102b43a15f 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -183,10 +183,20 @@ class ConvertToLLVMPattern : public ConversionPattern { ArrayRef sizes, ArrayRef strides, ConversionPatternRewriter &rewriter) const; + /// Copies the given unranked memory descriptor to heap-allocated memory (if + /// toDynamic is true) or to stack-allocated memory (otherwise) and returns + /// the new descriptor. Also frees the previously used memory (that is assumed + /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value + /// on failure. + Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, + UnrankedMemRefType memRefType, Value operand, + bool toDynamic) const; + /// Copies the memory descriptor for any operands that were unranked /// descriptors originally to heap-allocated memory (if toDynamic is true) or - /// to stack-allocated memory (otherwise). Also frees the previously used - /// memory (that is assumed to be heap-allocated) if toDynamic is false. + /// to stack-allocated memory (otherwise). The vector of descriptors is + /// updated in place. Also frees the previously used memory (that is assumed + /// to be heap-allocated) if toDynamic is false. LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl &operands, diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 67bb1c14c99a2..a4a6ae250640f 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -688,28 +688,34 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { auto funcOp = op->getParentOfType(); bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, extract the aligned pointer to - // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { - Type oldTy = std::get<0>(it).getType(); - Value newOperand = std::get<1>(it); - if (isa(oldTy) && getTypeConverter()->canConvertToBarePtr( - cast(oldTy))) { + + for (auto [oldOperand, newOperand] : + llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { + Type oldTy = oldOperand.getType(); + if (auto memRefType = dyn_cast(oldTy)) { + if (useBarePtrCallConv && + getTypeConverter()->canConvertToBarePtr(memRefType)) { + // For the bare-ptr calling convention, extract the aligned pointer to + // be returned from the memref descriptor. MemRefDescriptor memrefDesc(newOperand); - newOperand = memrefDesc.allocatedPtr(rewriter, loc); - } else if (isa(oldTy)) { + updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc)); + continue; + } + } else if (auto unrankedMemRefType = + dyn_cast(oldTy)) { + if (useBarePtrCallConv) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); } - updatedOperands.push_back(newOperand); + Value updatedDesc = copyUnrankedDescriptor( + rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true); + if (!updatedDesc) + return failure(); + updatedOperands.push_back(updatedDesc); + continue; } - } else { - updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); - (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), - updatedOperands, - /*toDynamic=*/true); + updatedOperands.push_back(newOperand); } // If ReturnOp has 0 or 1 operand, create it and return immediately. diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 72f41fd01fe7c..48a03198fd465 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( return memRefDescriptor; } -LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( - OpBuilder &builder, Location loc, TypeRange origTypes, - SmallVectorImpl &operands, bool toDynamic) const { - assert(origTypes.size() == operands.size() && - "expected as may original types as operands"); - - // Find operands of unranked memref type and store them. - SmallVector unrankedMemrefs; - SmallVector unrankedAddressSpaces; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (auto memRefType = dyn_cast(origTypes[i])) { - unrankedMemrefs.emplace_back(operands[i]); - FailureOr addressSpace = - getTypeConverter()->getMemRefAddressSpace(memRefType); - if (failed(addressSpace)) - return failure(); - unrankedAddressSpaces.emplace_back(*addressSpace); - } - } - - if (unrankedMemrefs.empty()) - return success(); +Value ConvertToLLVMPattern::copyUnrankedDescriptor( + OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, + Value operand, bool toDynamic) const { + // Convert memory space. + FailureOr addressSpace = + getTypeConverter()->getMemRefAddressSpace(memRefType); + if (failed(addressSpace)) + return {}; // Get frequently used types. Type indexType = getTypeConverter()->getIndexType(); @@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( if (toDynamic) { mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType); if (failed(mallocFunc)) - return failure(); + return {}; } if (!toDynamic) { freeFunc = LLVM::lookupOrCreateFreeFn(builder, module); if (failed(freeFunc)) - return failure(); + return {}; } - unsigned unrankedMemrefPos = 0; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - Type type = origTypes[i]; - if (!isa(type)) - continue; - UnrankedMemRefDescriptor desc(operands[i]); - Value allocationSize = UnrankedMemRefDescriptor::computeSize( - builder, loc, *getTypeConverter(), desc, - unrankedAddressSpaces[unrankedMemrefPos++]); - - // Allocate memory, copy, and free the source if necessary. - Value memory = - toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), - allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); - Value source = desc.memRefDescPtr(builder, loc); - LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); - if (!toDynamic) - LLVM::CallOp::create(builder, loc, freeFunc.value(), source); - - // Create a new descriptor. The same descriptor can be returned multiple - // times, attempting to modify its pointer can lead to memory leaks - // (allocated twice and overwritten) or double frees (the caller does not - // know if the descriptor points to the same memory). - Type descriptorType = getTypeConverter()->convertType(type); - if (!descriptorType) - return failure(); - auto updatedDesc = - UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); - Value rank = desc.rank(builder, loc); - updatedDesc.setRank(builder, loc, rank); - updatedDesc.setMemRefDescPtr(builder, loc, memory); + UnrankedMemRefDescriptor desc(operand); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + builder, loc, *getTypeConverter(), desc, *addressSpace); + + // Allocate memory, copy, and free the source if necessary. + Value memory = toDynamic + ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); + Value source = desc.memRefDescPtr(builder, loc); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); + if (!toDynamic) + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); + + // Create a new descriptor. The same descriptor can be returned multiple + // times, attempting to modify its pointer can lead to memory leaks + // (allocated twice and overwritten) or double frees (the caller does not + // know if the descriptor points to the same memory). + Type descriptorType = getTypeConverter()->convertType(memRefType); + if (!descriptorType) + return {}; + auto updatedDesc = + UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); + Value rank = desc.rank(builder, loc); + updatedDesc.setRank(builder, loc, rank); + updatedDesc.setMemRefDescPtr(builder, loc, memory); + return updatedDesc; +} - operands[i] = updatedDesc; +LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( + OpBuilder &builder, Location loc, TypeRange origTypes, + SmallVectorImpl &operands, bool toDynamic) const { + assert(origTypes.size() == operands.size() && + "expected as may original types as operands"); + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (auto memRefType = dyn_cast(origTypes[i])) { + Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType, + operands[i], toDynamic); + if (!updatedDesc) + return failure(); + operands[i] = updatedDesc; + } } - return success(); }