
Conversation

@matthias-springer (Member)

Split copyUnrankedDescriptors into two functions: one that copies a single unranked descriptor and one that copies multiple unranked descriptors. This is in preparation for adding 1:N support to the Func->LLVM lowering patterns.

@llvmbot (Member) commented Aug 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Split copyUnrankedDescriptors into two functions: one that copies a single unranked descriptor and one that copies multiple unranked descriptors. This is in preparation for adding 1:N support to the Func->LLVM lowering patterns.
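
In short, the existing copyUnrankedDescriptors helper now loops over the operands and delegates to a new per-descriptor helper that returns a null Value on failure. A condensed sketch of the resulting API shape (names taken from the diff below; the body of the per-descriptor helper and error details are elided):

  // New per-descriptor helper; returns a "null" Value on failure.
  Value ConvertToLLVMPattern::copyUnrankedDescriptor(
      OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
      Value &operand, bool toDynamic) const;

  // Existing multi-descriptor helper, now a thin loop over the one above.
  LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
      OpBuilder &builder, Location loc, TypeRange origTypes,
      SmallVectorImpl<Value> &operands, bool toDynamic) const {
    // Only operands that were unranked memrefs originally need a copy;
    // the vector of descriptors is updated in place.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
        Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
                                                   operands[i], toDynamic);
        if (!updatedDesc)
          return failure();
        operands[i] = updatedDesc;
      }
    }
    return success();
  }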


Full diff: https://github.com/llvm/llvm-project/pull/153597.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+12-2)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+19-15)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+55-62)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 969154abe8830..8b72a6c5db9c2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -183,10 +183,20 @@ class ConvertToLLVMPattern : public ConversionPattern {
                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
                          ConversionPatternRewriter &rewriter) const;
 
+  /// Copies the given unranked memory descriptor to heap-allocated memory (if
+  /// toDynamic is true) or to stack-allocated memory (otherwise) and returns
+  /// the new descriptor. Also frees the previously used memory (that is assumed
+  /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
+  /// on failure.
+  Value copyUnrankedDescriptor(OpBuilder &builder, Location loc,
+                               UnrankedMemRefType memRefType, Value &operand,
+                               bool toDynamic) const;
+
   /// Copies the memory descriptor for any operands that were unranked
   /// descriptors originally to heap-allocated memory (if toDynamic is true) or
-  /// to stack-allocated memory (otherwise). Also frees the previously used
-  /// memory (that is assumed to be heap-allocated) if toDynamic is false.
+  /// to stack-allocated memory (otherwise). The vector of descriptors is
+  /// updated in place. Also frees the previously used memory (that is assumed
+  /// to be heap-allocated) if toDynamic is false.
   LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
                                         TypeRange origTypes,
                                         SmallVectorImpl<Value> &operands,
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 67bb1c14c99a2..704492a83d680 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -688,28 +688,32 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
     auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
     bool useBarePtrCallConv =
         shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
-    if (useBarePtrCallConv) {
-      // For the bare-ptr calling convention, extract the aligned pointer to
-      // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
-        Type oldTy = std::get<0>(it).getType();
-        Value newOperand = std::get<1>(it);
-        if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
-                                          cast<BaseMemRefType>(oldTy))) {
+
+    for (auto it : llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+      Type oldTy = std::get<0>(it).getType();
+      Value newOperand = std::get<1>(it);
+      if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+        if (useBarePtrCallConv &&
+            getTypeConverter()->canConvertToBarePtr(memRefType)) {
+          // For the bare-ptr calling convention, extract the aligned pointer to
+          // be returned from the memref descriptor.
           MemRefDescriptor memrefDesc(newOperand);
           newOperand = memrefDesc.allocatedPtr(rewriter, loc);
-        } else if (isa<UnrankedMemRefType>(oldTy)) {
+        }
+      } else if (auto unrankedMemRefType =
+                     dyn_cast<UnrankedMemRefType>(oldTy)) {
+        if (useBarePtrCallConv) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
           return failure();
         }
-        updatedOperands.push_back(newOperand);
+        Value updatedDesc = copyUnrankedDescriptor(
+            rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+        if (!updatedDesc)
+          return failure();
+        newOperand = updatedDesc;
       }
-    } else {
-      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
-      (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
-                                    updatedOperands,
-                                    /*toDynamic=*/true);
+      updatedOperands.push_back(newOperand);
     }
 
     // If ReturnOp has 0 or 1 operand, create it and return immediately.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 72f41fd01fe7c..de7528a0f1a2b 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
   return memRefDescriptor;
 }
 
-LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
-    OpBuilder &builder, Location loc, TypeRange origTypes,
-    SmallVectorImpl<Value> &operands, bool toDynamic) const {
-  assert(origTypes.size() == operands.size() &&
-         "expected as may original types as operands");
-
-  // Find operands of unranked memref type and store them.
-  SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
-  SmallVector<unsigned> unrankedAddressSpaces;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
-      unrankedMemrefs.emplace_back(operands[i]);
-      FailureOr<unsigned> addressSpace =
-          getTypeConverter()->getMemRefAddressSpace(memRefType);
-      if (failed(addressSpace))
-        return failure();
-      unrankedAddressSpaces.emplace_back(*addressSpace);
-    }
-  }
-
-  if (unrankedMemrefs.empty())
-    return success();
+Value ConvertToLLVMPattern::copyUnrankedDescriptor(
+    OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
+    Value &operand, bool toDynamic) const {
+  // Convert memory space.
+  FailureOr<unsigned> addressSpace =
+      getTypeConverter()->getMemRefAddressSpace(memRefType);
+  if (failed(addressSpace))
+    return {};
 
   // Get frequently used types.
   Type indexType = getTypeConverter()->getIndexType();
@@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   if (toDynamic) {
     mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
     if (failed(mallocFunc))
-      return failure();
+      return {};
   }
   if (!toDynamic) {
     freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
     if (failed(freeFunc))
-      return failure();
+      return {};
   }
 
-  unsigned unrankedMemrefPos = 0;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    Type type = origTypes[i];
-    if (!isa<UnrankedMemRefType>(type))
-      continue;
-    UnrankedMemRefDescriptor desc(operands[i]);
-    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
-        builder, loc, *getTypeConverter(), desc,
-        unrankedAddressSpaces[unrankedMemrefPos++]);
-
-    // Allocate memory, copy, and free the source if necessary.
-    Value memory =
-        toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
-                                         allocationSize)
-                        .getResult()
-                  : LLVM::AllocaOp::create(builder, loc, getPtrType(),
-                                           IntegerType::get(getContext(), 8),
-                                           allocationSize,
-                                           /*alignment=*/0);
-    Value source = desc.memRefDescPtr(builder, loc);
-    LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
-    if (!toDynamic)
-      LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
-
-    // Create a new descriptor. The same descriptor can be returned multiple
-    // times, attempting to modify its pointer can lead to memory leaks
-    // (allocated twice and overwritten) or double frees (the caller does not
-    // know if the descriptor points to the same memory).
-    Type descriptorType = getTypeConverter()->convertType(type);
-    if (!descriptorType)
-      return failure();
-    auto updatedDesc =
-        UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
-    Value rank = desc.rank(builder, loc);
-    updatedDesc.setRank(builder, loc, rank);
-    updatedDesc.setMemRefDescPtr(builder, loc, memory);
+  UnrankedMemRefDescriptor desc(operand);
+  Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+      builder, loc, *getTypeConverter(), desc, *addressSpace);
+
+  // Allocate memory, copy, and free the source if necessary.
+  Value memory = toDynamic
+                     ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+                                            allocationSize)
+                           .getResult()
+                     : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+                                              IntegerType::get(getContext(), 8),
+                                              allocationSize,
+                                              /*alignment=*/0);
+  Value source = desc.memRefDescPtr(builder, loc);
+  LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
+  if (!toDynamic)
+    LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
+
+  // Create a new descriptor. The same descriptor can be returned multiple
+  // times, attempting to modify its pointer can lead to memory leaks
+  // (allocated twice and overwritten) or double frees (the caller does not
+  // know if the descriptor points to the same memory).
+  Type descriptorType = getTypeConverter()->convertType(memRefType);
+  if (!descriptorType)
+    return {};
+  auto updatedDesc =
+      UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+  Value rank = desc.rank(builder, loc);
+  updatedDesc.setRank(builder, loc, rank);
+  updatedDesc.setMemRefDescPtr(builder, loc, memory);
+  return updatedDesc;
+}
 
-    operands[i] = updatedDesc;
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+    OpBuilder &builder, Location loc, TypeRange origTypes,
+    SmallVectorImpl<Value> &operands, bool toDynamic) const {
+  assert(origTypes.size() == operands.size() &&
+         "expected as may original types as operands");
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
+      Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
+                                                 operands[i], toDynamic);
+      if (!updatedDesc)
+        return failure();
+      operands[i] = updatedDesc;
+    }
   }
-
   return success();
 }
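
For callers, the per-descriptor helper is used the way the updated ReturnOpLowering uses it above. A minimal usage sketch inside a ConvertToLLVMPattern-derived lowering (the names origOperand and convertedOperand are illustrative, not from the patch):

  if (auto unrankedTy =
          dyn_cast<UnrankedMemRefType>(origOperand.getType())) {
    // toDynamic = true: copy the descriptor to heap-allocated (malloc'ed)
    // memory so it can safely outlive the current stack frame, e.g. when
    // it is returned from a function.
    Value copied = copyUnrankedDescriptor(rewriter, loc, unrankedTy,
                                          convertedOperand,
                                          /*toDynamic=*/true);
    if (!copied)
      return failure(); // null Value signals failure (e.g. malloc/free or
                        // type conversion could not be set up)
    convertedOperand = copied;
  }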
 

@gysit (Contributor) left a comment

LGTM modulo optional nits.

@matthias-springer force-pushed the users/matthias-springer/simplify_copy_unranked branch from 6b77603 to 137f5f3 on August 14, 2025 at 16:16
@matthias-springer merged commit e2ae634 into main on August 14, 2025 (9 checks passed)
@matthias-springer deleted the users/matthias-springer/simplify_copy_unranked branch on August 14, 2025 at 16:25