-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][LLVM][NFC] Simplify copyUnrankedDescriptors
#153597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM][NFC] Simplify copyUnrankedDescriptors
#153597
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesSplit the function into two: one that copies a single unranked descriptor and one that copies multiple unranked descriptors. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns. Full diff: https://github.com/llvm/llvm-project/pull/153597.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 969154abe8830..8b72a6c5db9c2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -183,10 +183,20 @@ class ConvertToLLVMPattern : public ConversionPattern {
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const;
+ /// Copies the given unranked memory descriptor to heap-allocated memory (if
+ /// toDynamic is true) or to stack-allocated memory (otherwise) and returns
+ /// the new descriptor. Also frees the previously used memory (that is assumed
+ /// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
+ /// on failure.
+ Value copyUnrankedDescriptor(OpBuilder &builder, Location loc,
+ UnrankedMemRefType memRefType, Value &operand,
+ bool toDynamic) const;
+
/// Copies the memory descriptor for any operands that were unranked
/// descriptors originally to heap-allocated memory (if toDynamic is true) or
- /// to stack-allocated memory (otherwise). Also frees the previously used
- /// memory (that is assumed to be heap-allocated) if toDynamic is false.
+ /// to stack-allocated memory (otherwise). The vector of descriptors is
+ /// updated in place. Also frees the previously used memory (that is assumed
+ /// to be heap-allocated) if toDynamic is false.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
TypeRange origTypes,
SmallVectorImpl<Value> &operands,
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 67bb1c14c99a2..704492a83d680 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -688,28 +688,32 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, extract the aligned pointer to
- // be returned from the memref descriptor.
- for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
- Type oldTy = std::get<0>(it).getType();
- Value newOperand = std::get<1>(it);
- if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
- cast<BaseMemRefType>(oldTy))) {
+
+ for (auto it : llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+ Type oldTy = std::get<0>(it).getType();
+ Value newOperand = std::get<1>(it);
+ if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ if (useBarePtrCallConv &&
+ getTypeConverter()->canConvertToBarePtr(memRefType)) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
MemRefDescriptor memrefDesc(newOperand);
newOperand = memrefDesc.allocatedPtr(rewriter, loc);
- } else if (isa<UnrankedMemRefType>(oldTy)) {
+ }
+ } else if (auto unrankedMemRefType =
+ dyn_cast<UnrankedMemRefType>(oldTy)) {
+ if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- updatedOperands.push_back(newOperand);
+ Value updatedDesc = copyUnrankedDescriptor(
+ rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
+ if (!updatedDesc)
+ return failure();
+ newOperand = updatedDesc;
}
- } else {
- updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
- (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
- updatedOperands,
- /*toDynamic=*/true);
+ updatedOperands.push_back(newOperand);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 72f41fd01fe7c..de7528a0f1a2b 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
return memRefDescriptor;
}
-LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
- OpBuilder &builder, Location loc, TypeRange origTypes,
- SmallVectorImpl<Value> &operands, bool toDynamic) const {
- assert(origTypes.size() == operands.size() &&
- "expected as may original types as operands");
-
- // Find operands of unranked memref type and store them.
- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
- SmallVector<unsigned> unrankedAddressSpaces;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
- unrankedMemrefs.emplace_back(operands[i]);
- FailureOr<unsigned> addressSpace =
- getTypeConverter()->getMemRefAddressSpace(memRefType);
- if (failed(addressSpace))
- return failure();
- unrankedAddressSpaces.emplace_back(*addressSpace);
- }
- }
-
- if (unrankedMemrefs.empty())
- return success();
+Value ConvertToLLVMPattern::copyUnrankedDescriptor(
+ OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
+ Value &operand, bool toDynamic) const {
+ // Convert memory space.
+ FailureOr<unsigned> addressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memRefType);
+ if (failed(addressSpace))
+ return {};
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
- return failure();
+ return {};
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
- return failure();
+ return {};
}
- unsigned unrankedMemrefPos = 0;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- Type type = origTypes[i];
- if (!isa<UnrankedMemRefType>(type))
- continue;
- UnrankedMemRefDescriptor desc(operands[i]);
- Value allocationSize = UnrankedMemRefDescriptor::computeSize(
- builder, loc, *getTypeConverter(), desc,
- unrankedAddressSpaces[unrankedMemrefPos++]);
-
- // Allocate memory, copy, and free the source if necessary.
- Value memory =
- toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
- allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
- Value source = desc.memRefDescPtr(builder, loc);
- LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
- if (!toDynamic)
- LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
-
- // Create a new descriptor. The same descriptor can be returned multiple
- // times, attempting to modify its pointer can lead to memory leaks
- // (allocated twice and overwritten) or double frees (the caller does not
- // know if the descriptor points to the same memory).
- Type descriptorType = getTypeConverter()->convertType(type);
- if (!descriptorType)
- return failure();
- auto updatedDesc =
- UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
- Value rank = desc.rank(builder, loc);
- updatedDesc.setRank(builder, loc, rank);
- updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ UnrankedMemRefDescriptor desc(operand);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ builder, loc, *getTypeConverter(), desc, *addressSpace);
+
+ // Allocate memory, copy, and free the source if necessary.
+ Value memory = toDynamic
+ ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
+ Value source = desc.memRefDescPtr(builder, loc);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
+ if (!toDynamic)
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
+
+ // Create a new descriptor. The same descriptor can be returned multiple
+ // times, attempting to modify its pointer can lead to memory leaks
+ // (allocated twice and overwritten) or double frees (the caller does not
+ // know if the descriptor points to the same memory).
+ Type descriptorType = getTypeConverter()->convertType(memRefType);
+ if (!descriptorType)
+ return {};
+ auto updatedDesc =
+ UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+ Value rank = desc.rank(builder, loc);
+ updatedDesc.setRank(builder, loc, rank);
+ updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ return updatedDesc;
+}
- operands[i] = updatedDesc;
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+ OpBuilder &builder, Location loc, TypeRange origTypes,
+ SmallVectorImpl<Value> &operands, bool toDynamic) const {
+ assert(origTypes.size() == operands.size() &&
+ "expected as may original types as operands");
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
+ Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
+ operands[i], toDynamic);
+ if (!updatedDesc)
+ return failure();
+ operands[i] = updatedDesc;
+ }
}
-
return success();
}
|
gysit
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo optional nits.
6b77603 to
137f5f3
Compare
Split the function into two: one that copies a single unranked descriptor and one that copies multiple unranked descriptors. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.