Skip to content

Conversation

@matthias-springer
Copy link
Member

Rename computeSizes to computeSize and make it compute just a single size. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.

@llvmbot
Copy link
Member

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Rename computeSizes to computeSize and make it compute just a single size. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.


Full diff: https://github.com/llvm/llvm-project/pull/153588.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h (+6-8)
  • (modified) mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp (+27-35)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+3-7)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+5-8)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f023cdc8..8e86808cc424a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -189,15 +189,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
   /// `unpack`.
   static unsigned getNumUnpackedValues() { return 2; }
 
-  /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`. `addressSpaces`
-  /// which must have the same length as `values`, is needed to handle layouts
-  /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
-  static void computeSizes(OpBuilder &builder, Location loc,
+  /// Builds and returns IR computing the size in bytes (suitable for opaque
+  /// allocation). `addressSpace` is needed to handle layouts where
+  /// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
+  static Value computeSize(OpBuilder &builder, Location loc,
                            const LLVMTypeConverter &typeConverter,
-                           ArrayRef<UnrankedMemRefDescriptor> values,
-                           ArrayRef<unsigned> addressSpaces,
-                           SmallVectorImpl<Value> &sizes);
+                           UnrankedMemRefDescriptor desc,
+                           unsigned addressSpace);
 
   /// TODO: The following accessors don't take alignment rules between elements
   /// of the descriptor struct into account. For some architectures, it might be
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f324b86..522e91421ff55 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
   results.push_back(d.memRefDescPtr(builder, loc));
 }
 
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
-    SmallVectorImpl<Value> &sizes) {
-  if (values.empty())
-    return;
-  assert(values.size() == addressSpaces.size() &&
-         "must provide address space for each descriptor");
+    UnrankedMemRefDescriptor desc, unsigned addressSpace) {
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();
 
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
       builder, loc, indexType,
       llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
 
-  sizes.reserve(sizes.size() + values.size());
-  for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
-    // Emit IR computing the memory necessary to store the descriptor. This
-    // assumes the descriptor to be
-    //   { type*, type*, index, index[rank], index[rank] }
-    // and densely packed, so the total size is
-    //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
-    // TODO: consider including the actual size (including eventual padding due
-    // to data layout) into the unranked descriptor.
-    Value pointerSize = createIndexAttrConstant(
-        builder, loc, indexType,
-        llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
-    Value doublePointerSize =
-        LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
-    // (1 + 2 * rank) * sizeof(index)
-    Value rank = desc.rank(builder, loc);
-    Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
-    Value doubleRankIncremented =
-        LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
-    Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
-                                              doubleRankIncremented, indexSize);
-
-    // Total allocation size.
-    Value allocationSize = LLVM::AddOp::create(
-        builder, loc, indexType, doublePointerSize, rankIndexSize);
-    sizes.push_back(allocationSize);
-  }
+  // Emit IR computing the memory necessary to store the descriptor. This
+  // assumes the descriptor to be
+  //   { type*, type*, index, index[rank], index[rank] }
+  // and densely packed, so the total size is
+  //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+  // TODO: consider including the actual size (including eventual padding due
+  // to data layout) into the unranked descriptor.
+  Value pointerSize = createIndexAttrConstant(
+      builder, loc, indexType,
+      llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+  Value doublePointerSize =
+      LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+  // (1 + 2 * rank) * sizeof(index)
+  Value rank = desc.rank(builder, loc);
+  Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+  Value doubleRankIncremented =
+      LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+  Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+                                            doubleRankIncremented, indexSize);
+
+  // Total allocation size.
+  Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+                                             doublePointerSize, rankIndexSize);
+  return allocationSize;
 }
 
 Value UnrankedMemRefDescriptor::allocatedPtr(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044f1fd32..72f41fd01fe7c 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   if (unrankedMemrefs.empty())
     return success();
 
-  // Compute allocation sizes.
-  SmallVector<Value> sizes;
-  UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, unrankedAddressSpaces,
-                                         sizes);
-
   // Get frequently used types.
   Type indexType = getTypeConverter()->getIndexType();
 
@@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     Type type = origTypes[i];
     if (!isa<UnrankedMemRefType>(type))
       continue;
-    Value allocationSize = sizes[unrankedMemrefPos++];
     UnrankedMemRefDescriptor desc(operands[i]);
+    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+        builder, loc, *getTypeConverter(), desc,
+        unrankedAddressSpaces[unrankedMemrefPos++]);
 
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9216e2a35a5ae..262e0e7a30c63 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
       auto result = UnrankedMemRefDescriptor::poison(
           rewriter, loc, typeConverter->convertType(resultTypeU));
       result.setRank(rewriter, loc, rank);
-      SmallVector<Value, 1> sizes;
-      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                             result, resultAddrSpace, sizes);
-      Value resultUnderlyingSize = sizes.front();
+      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
       Value resultUnderlyingDesc =
           LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                  rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ struct MemRefReshapeOpLowering
     auto targetDesc = UnrankedMemRefDescriptor::poison(
         rewriter, loc, typeConverter->convertType(targetType));
     targetDesc.setRank(rewriter, loc, resultRank);
-    SmallVector<Value, 4> sizes;
-    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                           targetDesc, addressSpace, sizes);
+    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+        rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
     Value underlyingDescPtr = LLVM::AllocaOp::create(
         rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
-        sizes.front());
+        allocationSize);
     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
 
     // Extract pointers and offset from the source memref.

@matthias-springer matthias-springer merged commit 0ff92fe into main Aug 14, 2025
12 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/compute_sizes branch August 14, 2025 15:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants