Skip to content

Commit e2ae634

Browse files
[mlir][LLVM][NFC] Simplify copyUnrankedDescriptors (#153597)
Split the function into two: one that copies a single unranked descriptor and one that copies multiple unranked descriptors. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.
1 parent 1945753 commit e2ae634

File tree

3 files changed

+89
-80
lines changed

3 files changed

+89
-80
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,20 @@ class ConvertToLLVMPattern : public ConversionPattern {
183183
ArrayRef<Value> sizes, ArrayRef<Value> strides,
184184
ConversionPatternRewriter &rewriter) const;
185185

186+
/// Copies the given unranked memory descriptor to heap-allocated memory (if
187+
/// toDynamic is true) or to stack-allocated memory (otherwise) and returns
188+
/// the new descriptor. Also frees the previously used memory (that is assumed
189+
/// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
190+
/// on failure.
191+
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc,
192+
UnrankedMemRefType memRefType, Value operand,
193+
bool toDynamic) const;
194+
186195
/// Copies the memory descriptor for any operands that were unranked
187196
/// descriptors originally to heap-allocated memory (if toDynamic is true) or
188-
/// to stack-allocated memory (otherwise). Also frees the previously used
189-
/// memory (that is assumed to be heap-allocated) if toDynamic is false.
197+
/// to stack-allocated memory (otherwise). The vector of descriptors is
198+
/// updated in place. Also frees the previously used memory (that is assumed
199+
/// to be heap-allocated) if toDynamic is false.
190200
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
191201
TypeRange origTypes,
192202
SmallVectorImpl<Value> &operands,

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -688,28 +688,34 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
688688
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
689689
bool useBarePtrCallConv =
690690
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
691-
if (useBarePtrCallConv) {
692-
// For the bare-ptr calling convention, extract the aligned pointer to
693-
// be returned from the memref descriptor.
694-
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
695-
Type oldTy = std::get<0>(it).getType();
696-
Value newOperand = std::get<1>(it);
697-
if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
698-
cast<BaseMemRefType>(oldTy))) {
691+
692+
for (auto [oldOperand, newOperand] :
693+
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
694+
Type oldTy = oldOperand.getType();
695+
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
696+
if (useBarePtrCallConv &&
697+
getTypeConverter()->canConvertToBarePtr(memRefType)) {
698+
// For the bare-ptr calling convention, extract the aligned pointer to
699+
// be returned from the memref descriptor.
699700
MemRefDescriptor memrefDesc(newOperand);
700-
newOperand = memrefDesc.allocatedPtr(rewriter, loc);
701-
} else if (isa<UnrankedMemRefType>(oldTy)) {
701+
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
702+
continue;
703+
}
704+
} else if (auto unrankedMemRefType =
705+
dyn_cast<UnrankedMemRefType>(oldTy)) {
706+
if (useBarePtrCallConv) {
702707
// Unranked memref is not supported in the bare pointer calling
703708
// convention.
704709
return failure();
705710
}
706-
updatedOperands.push_back(newOperand);
711+
Value updatedDesc = copyUnrankedDescriptor(
712+
rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
713+
if (!updatedDesc)
714+
return failure();
715+
updatedOperands.push_back(updatedDesc);
716+
continue;
707717
}
708-
} else {
709-
updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
710-
(void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
711-
updatedOperands,
712-
/*toDynamic=*/true);
718+
updatedOperands.push_back(newOperand);
713719
}
714720

715721
// If ReturnOp has 0 or 1 operand, create it and return immediately.

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 55 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
216216
return memRefDescriptor;
217217
}
218218

219-
LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
220-
OpBuilder &builder, Location loc, TypeRange origTypes,
221-
SmallVectorImpl<Value> &operands, bool toDynamic) const {
222-
assert(origTypes.size() == operands.size() &&
223-
"expected as may original types as operands");
224-
225-
// Find operands of unranked memref type and store them.
226-
SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
227-
SmallVector<unsigned> unrankedAddressSpaces;
228-
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
229-
if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
230-
unrankedMemrefs.emplace_back(operands[i]);
231-
FailureOr<unsigned> addressSpace =
232-
getTypeConverter()->getMemRefAddressSpace(memRefType);
233-
if (failed(addressSpace))
234-
return failure();
235-
unrankedAddressSpaces.emplace_back(*addressSpace);
236-
}
237-
}
238-
239-
if (unrankedMemrefs.empty())
240-
return success();
219+
Value ConvertToLLVMPattern::copyUnrankedDescriptor(
220+
OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
221+
Value operand, bool toDynamic) const {
222+
// Convert memory space.
223+
FailureOr<unsigned> addressSpace =
224+
getTypeConverter()->getMemRefAddressSpace(memRefType);
225+
if (failed(addressSpace))
226+
return {};
241227

242228
// Get frequently used types.
243229
Type indexType = getTypeConverter()->getIndexType();
@@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
248234
if (toDynamic) {
249235
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
250236
if (failed(mallocFunc))
251-
return failure();
237+
return {};
252238
}
253239
if (!toDynamic) {
254240
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
255241
if (failed(freeFunc))
256-
return failure();
242+
return {};
257243
}
258244

259-
unsigned unrankedMemrefPos = 0;
260-
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
261-
Type type = origTypes[i];
262-
if (!isa<UnrankedMemRefType>(type))
263-
continue;
264-
UnrankedMemRefDescriptor desc(operands[i]);
265-
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
266-
builder, loc, *getTypeConverter(), desc,
267-
unrankedAddressSpaces[unrankedMemrefPos++]);
268-
269-
// Allocate memory, copy, and free the source if necessary.
270-
Value memory =
271-
toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
272-
allocationSize)
273-
.getResult()
274-
: LLVM::AllocaOp::create(builder, loc, getPtrType(),
275-
IntegerType::get(getContext(), 8),
276-
allocationSize,
277-
/*alignment=*/0);
278-
Value source = desc.memRefDescPtr(builder, loc);
279-
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
280-
if (!toDynamic)
281-
LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
282-
283-
// Create a new descriptor. The same descriptor can be returned multiple
284-
// times, attempting to modify its pointer can lead to memory leaks
285-
// (allocated twice and overwritten) or double frees (the caller does not
286-
// know if the descriptor points to the same memory).
287-
Type descriptorType = getTypeConverter()->convertType(type);
288-
if (!descriptorType)
289-
return failure();
290-
auto updatedDesc =
291-
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
292-
Value rank = desc.rank(builder, loc);
293-
updatedDesc.setRank(builder, loc, rank);
294-
updatedDesc.setMemRefDescPtr(builder, loc, memory);
245+
UnrankedMemRefDescriptor desc(operand);
246+
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
247+
builder, loc, *getTypeConverter(), desc, *addressSpace);
248+
249+
// Allocate memory, copy, and free the source if necessary.
250+
Value memory = toDynamic
251+
? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
252+
allocationSize)
253+
.getResult()
254+
: LLVM::AllocaOp::create(builder, loc, getPtrType(),
255+
IntegerType::get(getContext(), 8),
256+
allocationSize,
257+
/*alignment=*/0);
258+
Value source = desc.memRefDescPtr(builder, loc);
259+
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
260+
if (!toDynamic)
261+
LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
262+
263+
// Create a new descriptor. The same descriptor can be returned multiple
264+
// times, attempting to modify its pointer can lead to memory leaks
265+
// (allocated twice and overwritten) or double frees (the caller does not
266+
// know if the descriptor points to the same memory).
267+
Type descriptorType = getTypeConverter()->convertType(memRefType);
268+
if (!descriptorType)
269+
return {};
270+
auto updatedDesc =
271+
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
272+
Value rank = desc.rank(builder, loc);
273+
updatedDesc.setRank(builder, loc, rank);
274+
updatedDesc.setMemRefDescPtr(builder, loc, memory);
275+
return updatedDesc;
276+
}
295277

296-
operands[i] = updatedDesc;
278+
LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
279+
OpBuilder &builder, Location loc, TypeRange origTypes,
280+
SmallVectorImpl<Value> &operands, bool toDynamic) const {
281+
assert(origTypes.size() == operands.size() &&
282+
"expected as may original types as operands");
283+
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
284+
if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
285+
Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
286+
operands[i], toDynamic);
287+
if (!updatedDesc)
288+
return failure();
289+
operands[i] = updatedDesc;
290+
}
297291
}
298-
299292
return success();
300293
}
301294

0 commit comments

Comments
 (0)