Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,14 @@ class LLVMTypeConverter : public TypeConverter {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
/// each of the types converted with `convertCallingConventionType`.
Type packFunctionResults(TypeRange types,
bool useBarePointerCallConv = false) const;
///
/// Populate the converted (unpacked) types into `groupedTypes`, if provided.
/// `groupedType` contains one nested vector per input type. In case of a 1:N
/// conversion, a nested vector may contain 0 or more then 1 converted type.
Type
packFunctionResults(TypeRange types, bool useBarePointerCallConv = false,
SmallVector<SmallVector<Type>> *groupedTypes = nullptr,
int64_t *numConvertedTypes = nullptr) const;

/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is
Expand All @@ -88,15 +94,9 @@ class LLVMTypeConverter : public TypeConverter {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type convertCallingConventionType(Type type,
bool useBarePointerCallConv = false) const;

/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<Type> stdTypes,
SmallVectorImpl<Value> &values) const;
LogicalResult
convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
bool useBarePointerCallConv = false) const;

/// Returns the MLIR context.
MLIRContext &getContext() const;
Expand All @@ -109,9 +109,14 @@ class LLVMTypeConverter : public TypeConverter {
/// Promote the LLVM representation of all operands including promoting MemRef
/// descriptors to stack and use pointers to struct to avoid the complexity
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
/// passing. (The ArrayRef variant is for 1:N.)
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
ArrayRef<ValueRange> adaptorOperands,
OpBuilder &builder,
bool useBarePtrCallConv = false) const;
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
ValueRange operands, OpBuilder &builder,
ValueRange adaptorOperands,
OpBuilder &builder,
bool useBarePtrCallConv = false) const;

/// Promote the LLVM struct representation of one MemRef descriptor to stack
Expand Down
109 changes: 72 additions & 37 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;

LogicalResult matchAndRewriteImpl(CallOpType callOp,
typename CallOpType::Adaptor adaptor,
LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

int64_t numConvertedTypes = 0;
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
resultTypes, useBarePtrCallConv)))
resultTypes, useBarePtrCallConv, &groupedResultTypes,
&numConvertedTypes)))
return failure();
}

Expand All @@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});

SmallVector<Value, 4> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
results.append(newOp.result_begin(), newOp.result_end());
} else {
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
results.push_back(LLVM::ExtractValueOp::create(
rewriter, callOp.getLoc(), newOp->getResult(0), i));
// Helper function that extracts an individual result from the return value
// of the new call op. llvm.call ops support only 0 or 1 result. In case of
// 2 or more results, the results are packed into a structure.
//
// The new call op may have more than 2 results because:
// a. The original call op has more than 2 results.
// b. An original op result type-converted to more than 1 result.
auto getUnpackedResult = [&](unsigned i) -> Value {
assert(numConvertedTypes > 0 && "convert op has no results");
if (numConvertedTypes == 1) {
assert(i == 0 && "out of bounds: converted op has only one result");
return newOp->getResult(0);
}
// Results have been converted to a structure. Extract individual results
// from the structure.
return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
newOp->getResult(0), i);
};

// Group the results into a vector of vectors, such that it is clear which
// original op result is replaced with which range of values. (In case of a
// 1:N conversion, there can be multiple replacements for a single result.)
SmallVector<SmallVector<Value>> results;
results.reserve(numResults);
unsigned counter = 0;
for (unsigned i = 0; i < numResults; ++i) {
SmallVector<Value> &group = results.emplace_back();
for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
group.push_back(getUnpackedResult(counter++));
}

if (useBarePtrCallConv) {
// For the bare-ptr calling convention, promote memref results to
// descriptors.
assert(results.size() == resultTypes.size() &&
"The number of arguments and types doesn't match");
this->getTypeConverter()->promoteBarePtrsToDescriptors(
rewriter, callOp.getLoc(), resultTypes, results);
} else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
resultTypes, results,
/*toDynamic=*/false))) {
return failure();
// Special handling for MemRef types.
for (unsigned i = 0; i < numResults; ++i) {
Type origType = resultTypes[i];
auto memrefType = dyn_cast<MemRefType>(origType);
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
if (useBarePtrCallConv && memrefType) {
// For the bare-ptr calling convention, promote memref results to
// descriptors.
assert(results[i].size() == 1 && "expected one converted result");
results[i].front() = MemRefDescriptor::fromStaticShape(
rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
results[i].front());
}
if (unrankedMemrefType) {
assert(!useBarePtrCallConv && "unranked memref is not supported in the "
"bare-ptr calling convention");
assert(results[i].size() == 1 && "expected one converted result");
Value desc = this->copyUnrankedDescriptor(
rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
/*toDynamic=*/false);
if (!desc)
return failure();
results[i].front() = desc;
}
}

rewriter.replaceOp(callOp, results);
rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
Expand All @@ -606,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
symbolTables(symbolTables) {}

LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
Expand Down Expand Up @@ -636,7 +668,7 @@ struct CallIndirectOpLowering
using Super::Super;

LogicalResult
matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
Expand Down Expand Up @@ -679,47 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;

auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());

for (auto [oldOperand, newOperand] :
for (auto [oldOperand, newOperands] :
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
Type oldTy = oldOperand.getType();
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
assert(newOperands.size() == 1 && "expected one converted result");
if (useBarePtrCallConv &&
getTypeConverter()->canConvertToBarePtr(memRefType)) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
MemRefDescriptor memrefDesc(newOperand);
MemRefDescriptor memrefDesc(newOperands.front());
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
continue;
}
} else if (auto unrankedMemRefType =
dyn_cast<UnrankedMemRefType>(oldTy)) {
assert(newOperands.size() == 1 && "expected one converted result");
if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
Value updatedDesc = copyUnrankedDescriptor(
rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
Value updatedDesc =
copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
newOperands.front(), /*toDynamic=*/true);
if (!updatedDesc)
return failure();
updatedOperands.push_back(updatedDesc);
continue;
}
updatedOperands.push_back(newOperand);

llvm::append_range(updatedOperands, newOperands);
}

// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments <= 1) {
if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
Expand Down
Loading