-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][LLVM][NFC] Simplify computeSizes function
#153588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesRename Full diff: https://github.com/llvm/llvm-project/pull/153588.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f023cdc8..8e86808cc424a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -189,15 +189,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
- /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
- /// and appends the corresponding values into `sizes`. `addressSpaces`
- /// which must have the same length as `values`, is needed to handle layouts
- /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
- static void computeSizes(OpBuilder &builder, Location loc,
+ /// Builds and returns IR computing the size in bytes (suitable for opaque
+ /// allocation). `addressSpace` is needed to handle layouts where
+ /// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
+ static Value computeSize(OpBuilder &builder, Location loc,
const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values,
- ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes);
+ UnrankedMemRefDescriptor desc,
+ unsigned addressSpace);
/// TODO: The following accessors don't take alignment rules between elements
/// of the descriptor struct into account. For some architectures, it might be
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f324b86..522e91421ff55 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes) {
- if (values.empty())
- return;
- assert(values.size() == addressSpaces.size() &&
- "must provide address space for each descriptor");
+ UnrankedMemRefDescriptor desc, unsigned addressSpace) {
// Cache the index type.
Type indexType = typeConverter.getIndexType();
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
- sizes.reserve(sizes.size() + values.size());
- for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
- // Emit IR computing the memory necessary to store the descriptor. This
- // assumes the descriptor to be
- // { type*, type*, index, index[rank], index[rank] }
- // and densely packed, so the total size is
- // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
- // TODO: consider including the actual size (including eventual padding due
- // to data layout) into the unranked descriptor.
- Value pointerSize = createIndexAttrConstant(
- builder, loc, indexType,
- llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
- Value doublePointerSize =
- LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
- // (1 + 2 * rank) * sizeof(index)
- Value rank = desc.rank(builder, loc);
- Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
- Value doubleRankIncremented =
- LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
- Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
- doubleRankIncremented, indexSize);
-
- // Total allocation size.
- Value allocationSize = LLVM::AddOp::create(
- builder, loc, indexType, doublePointerSize, rankIndexSize);
- sizes.push_back(allocationSize);
- }
+ // Emit IR computing the memory necessary to store the descriptor. This
+ // assumes the descriptor to be
+ // { type*, type*, index, index[rank], index[rank] }
+ // and densely packed, so the total size is
+ // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+ // TODO: consider including the actual size (including eventual padding due
+ // to data layout) into the unranked descriptor.
+ Value pointerSize = createIndexAttrConstant(
+ builder, loc, indexType,
+ llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+ Value doublePointerSize =
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+ // (1 + 2 * rank) * sizeof(index)
+ Value rank = desc.rank(builder, loc);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+ Value doubleRankIncremented =
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
+
+ // Total allocation size.
+ Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+ doublePointerSize, rankIndexSize);
+ return allocationSize;
}
Value UnrankedMemRefDescriptor::allocatedPtr(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044f1fd32..72f41fd01fe7c 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (unrankedMemrefs.empty())
return success();
- // Compute allocation sizes.
- SmallVector<Value> sizes;
- UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
- unrankedMemrefs, unrankedAddressSpaces,
- sizes);
-
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Type type = origTypes[i];
if (!isa<UnrankedMemRefType>(type))
continue;
- Value allocationSize = sizes[unrankedMemrefPos++];
UnrankedMemRefDescriptor desc(operands[i]);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ builder, loc, *getTypeConverter(), desc,
+ unrankedAddressSpaces[unrankedMemrefPos++]);
// Allocate memory, copy, and free the source if necessary.
Value memory =
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9216e2a35a5ae..262e0e7a30c63 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
- SmallVector<Value, 1> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- result, resultAddrSpace, sizes);
- Value resultUnderlyingSize = sizes.front();
+ Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
Value resultUnderlyingDesc =
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ struct MemRefReshapeOpLowering
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
- SmallVector<Value, 4> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- targetDesc, addressSpace, sizes);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
Value underlyingDescPtr = LLVM::AllocaOp::create(
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
- sizes.front());
+ allocationSize);
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
|
zero9178
approved these changes
Aug 14, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Rename
computeSizestocomputeSizeand make it compute just a single size. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.