Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,20 @@ class ConvertToLLVMPattern : public ConversionPattern {
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const;

/// Copies the given unranked memory descriptor to heap-allocated memory (if
/// toDynamic is true) or to stack-allocated memory (otherwise) and returns
/// the new descriptor. Also frees the previously used memory (that is assumed
/// to be heap-allocated) if toDynamic is false. Returns a "null" SSA value
/// on failure.
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc,
UnrankedMemRefType memRefType, Value operand,
bool toDynamic) const;

/// Copies the memory descriptor for any operands that were unranked
/// descriptors originally to heap-allocated memory (if toDynamic is true) or
/// to stack-allocated memory (otherwise). Also frees the previously used
/// memory (that is assumed to be heap-allocated) if toDynamic is false.
/// to stack-allocated memory (otherwise). The vector of descriptors is
/// updated in place. Also frees the previously used memory (that is assumed
/// to be heap-allocated) if toDynamic is false.
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
TypeRange origTypes,
SmallVectorImpl<Value> &operands,
Expand Down
38 changes: 22 additions & 16 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,28 +688,34 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
Type oldTy = std::get<0>(it).getType();
Value newOperand = std::get<1>(it);
if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
cast<BaseMemRefType>(oldTy))) {

for (auto [oldOperand, newOperand] :
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
Type oldTy = oldOperand.getType();
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
if (useBarePtrCallConv &&
getTypeConverter()->canConvertToBarePtr(memRefType)) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
MemRefDescriptor memrefDesc(newOperand);
newOperand = memrefDesc.allocatedPtr(rewriter, loc);
} else if (isa<UnrankedMemRefType>(oldTy)) {
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
continue;
}
} else if (auto unrankedMemRefType =
dyn_cast<UnrankedMemRefType>(oldTy)) {
if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
updatedOperands.push_back(newOperand);
Value updatedDesc = copyUnrankedDescriptor(
rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
if (!updatedDesc)
return failure();
updatedOperands.push_back(updatedDesc);
continue;
}
} else {
updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
(void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
updatedOperands,
/*toDynamic=*/true);
updatedOperands.push_back(newOperand);
}

// If ReturnOp has 0 or 1 operand, create it and return immediately.
Expand Down
117 changes: 55 additions & 62 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
return memRefDescriptor;
}

LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
OpBuilder &builder, Location loc, TypeRange origTypes,
SmallVectorImpl<Value> &operands, bool toDynamic) const {
assert(origTypes.size() == operands.size() &&
"expected as may original types as operands");

// Find operands of unranked memref type and store them.
SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
SmallVector<unsigned> unrankedAddressSpaces;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
unrankedMemrefs.emplace_back(operands[i]);
FailureOr<unsigned> addressSpace =
getTypeConverter()->getMemRefAddressSpace(memRefType);
if (failed(addressSpace))
return failure();
unrankedAddressSpaces.emplace_back(*addressSpace);
}
}

if (unrankedMemrefs.empty())
return success();
Value ConvertToLLVMPattern::copyUnrankedDescriptor(
OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
Value operand, bool toDynamic) const {
// Convert memory space.
FailureOr<unsigned> addressSpace =
getTypeConverter()->getMemRefAddressSpace(memRefType);
if (failed(addressSpace))
return {};

// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
Expand All @@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
return failure();
return {};
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
return failure();
return {};
}

unsigned unrankedMemrefPos = 0;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
Type type = origTypes[i];
if (!isa<UnrankedMemRefType>(type))
continue;
UnrankedMemRefDescriptor desc(operands[i]);
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
builder, loc, *getTypeConverter(), desc,
unrankedAddressSpaces[unrankedMemrefPos++]);

// Allocate memory, copy, and free the source if necessary.
Value memory =
toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
allocationSize)
.getResult()
: LLVM::AllocaOp::create(builder, loc, getPtrType(),
IntegerType::get(getContext(), 8),
allocationSize,
/*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
LLVM::CallOp::create(builder, loc, freeFunc.value(), source);

// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
// (allocated twice and overwritten) or double frees (the caller does not
// know if the descriptor points to the same memory).
Type descriptorType = getTypeConverter()->convertType(type);
if (!descriptorType)
return failure();
auto updatedDesc =
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
Value rank = desc.rank(builder, loc);
updatedDesc.setRank(builder, loc, rank);
updatedDesc.setMemRefDescPtr(builder, loc, memory);
UnrankedMemRefDescriptor desc(operand);
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
builder, loc, *getTypeConverter(), desc, *addressSpace);

// Allocate memory, copy, and free the source if necessary.
Value memory = toDynamic
? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
allocationSize)
.getResult()
: LLVM::AllocaOp::create(builder, loc, getPtrType(),
IntegerType::get(getContext(), 8),
allocationSize,
/*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
LLVM::CallOp::create(builder, loc, freeFunc.value(), source);

// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
// (allocated twice and overwritten) or double frees (the caller does not
// know if the descriptor points to the same memory).
Type descriptorType = getTypeConverter()->convertType(memRefType);
if (!descriptorType)
return {};
auto updatedDesc =
UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
Value rank = desc.rank(builder, loc);
updatedDesc.setRank(builder, loc, rank);
updatedDesc.setMemRefDescPtr(builder, loc, memory);
return updatedDesc;
}

operands[i] = updatedDesc;
LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
OpBuilder &builder, Location loc, TypeRange origTypes,
SmallVectorImpl<Value> &operands, bool toDynamic) const {
assert(origTypes.size() == operands.size() &&
"expected as may original types as operands");
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
operands[i], toDynamic);
if (!updatedDesc)
return failure();
operands[i] = updatedDesc;
}
}

return success();
}

Expand Down