Skip to content

Commit f0967fc

Browse files
[mlir][LLVM] FuncToLLVM: Add 1:N type conversion support (#153823)
Add support for 1:N type conversions to the `FuncToLLVM` lowering patterns. This commit does not change the lowering of any types (such as `MemRefType`). It just sets up the infrastructure, such that 1:N type conversions can be used during `FuncToLLVM`. Note: When the converted result types of a `func.func` have more than 1 type, then the results are wrapped in an `llvm.struct`. That's because `llvm.func` does not support multiple result values. This "wrapping" was already implemented for cases where the original `func.func` has multiple results. With 1:N conversions, even a single result can now expand to multiple converted results, triggering the same wrapping mechanism. The test cases are exercised with both the old and the new no-rollback conversion driver.
1 parent 1d73b2c commit f0967fc

File tree

6 files changed

+281
-125
lines changed

6 files changed

+281
-125
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,14 @@ class LLVMTypeConverter : public TypeConverter {
7474
/// LLVM-compatible type. In particular, if more than one value is returned,
7575
/// create an LLVM dialect structure type with elements that correspond to
7676
/// each of the types converted with `convertCallingConventionType`.
77-
Type packFunctionResults(TypeRange types,
78-
bool useBarePointerCallConv = false) const;
77+
///
78+
/// Populate the converted (unpacked) types into `groupedTypes`, if provided.
79+
/// `groupedType` contains one nested vector per input type. In case of a 1:N
80+
/// conversion, a nested vector may contain 0 or more then 1 converted type.
81+
Type
82+
packFunctionResults(TypeRange types, bool useBarePointerCallConv = false,
83+
SmallVector<SmallVector<Type>> *groupedTypes = nullptr,
84+
int64_t *numConvertedTypes = nullptr) const;
7985

8086
/// Convert a non-empty list of types of values produced by an operation into
8187
/// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +94,9 @@ class LLVMTypeConverter : public TypeConverter {
8894
/// UnrankedMemRefType, are converted following the specific rules for the
8995
/// calling convention. Calling convention independent types are converted
9096
/// following the default LLVM type conversions.
91-
Type convertCallingConventionType(Type type,
92-
bool useBarePointerCallConv = false) const;
93-
94-
/// Promote the bare pointers in 'values' that resulted from memrefs to
95-
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
96-
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
97-
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
98-
Location loc, ArrayRef<Type> stdTypes,
99-
SmallVectorImpl<Value> &values) const;
97+
LogicalResult
98+
convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
99+
bool useBarePointerCallConv = false) const;
100100

101101
/// Returns the MLIR context.
102102
MLIRContext &getContext() const;
@@ -109,9 +109,14 @@ class LLVMTypeConverter : public TypeConverter {
109109
/// Promote the LLVM representation of all operands including promoting MemRef
110110
/// descriptors to stack and use pointers to struct to avoid the complexity
111111
/// of the platform-specific C/C++ ABI lowering related to struct argument
112-
/// passing.
112+
/// passing. (The ArrayRef variant is for 1:N.)
113+
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
114+
ArrayRef<ValueRange> adaptorOperands,
115+
OpBuilder &builder,
116+
bool useBarePtrCallConv = false) const;
113117
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
114-
ValueRange operands, OpBuilder &builder,
118+
ValueRange adaptorOperands,
119+
OpBuilder &builder,
115120
bool useBarePtrCallConv = false) const;
116121

117122
/// Promote the LLVM struct representation of one MemRef descriptor to stack

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
527527
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
528528
using Super = CallOpInterfaceLowering<CallOpType>;
529529
using Base = ConvertOpToLLVMPattern<CallOpType>;
530+
using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
530531

531-
LogicalResult matchAndRewriteImpl(CallOpType callOp,
532-
typename CallOpType::Adaptor adaptor,
532+
LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
533533
ConversionPatternRewriter &rewriter,
534534
bool useBarePtrCallConv = false) const {
535535
// Pack the result types into a struct.
536536
Type packedResult = nullptr;
537+
SmallVector<SmallVector<Type>> groupedResultTypes;
537538
unsigned numResults = callOp.getNumResults();
538539
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
539-
540+
int64_t numConvertedTypes = 0;
540541
if (numResults != 0) {
541542
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
542-
resultTypes, useBarePtrCallConv)))
543+
resultTypes, useBarePtrCallConv, &groupedResultTypes,
544+
&numConvertedTypes)))
543545
return failure();
544546
}
545547

@@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
565567
static_cast<int32_t>(promoted.size()), 0};
566568
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
567569

568-
SmallVector<Value, 4> results;
569-
if (numResults < 2) {
570-
// If < 2 results, packing did not do anything and we can just return.
571-
results.append(newOp.result_begin(), newOp.result_end());
572-
} else {
573-
// Otherwise, it had been converted to an operation producing a structure.
574-
// Extract individual results from the structure and return them as list.
575-
results.reserve(numResults);
576-
for (unsigned i = 0; i < numResults; ++i) {
577-
results.push_back(LLVM::ExtractValueOp::create(
578-
rewriter, callOp.getLoc(), newOp->getResult(0), i));
570+
// Helper function that extracts an individual result from the return value
571+
// of the new call op. llvm.call ops support only 0 or 1 result. In case of
572+
// 2 or more results, the results are packed into a structure.
573+
//
574+
// The new call op may have more than 2 results because:
575+
// a. The original call op has more than 2 results.
576+
// b. An original op result type-converted to more than 1 result.
577+
auto getUnpackedResult = [&](unsigned i) -> Value {
578+
assert(numConvertedTypes > 0 && "convert op has no results");
579+
if (numConvertedTypes == 1) {
580+
assert(i == 0 && "out of bounds: converted op has only one result");
581+
return newOp->getResult(0);
579582
}
583+
// Results have been converted to a structure. Extract individual results
584+
// from the structure.
585+
return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
586+
newOp->getResult(0), i);
587+
};
588+
589+
// Group the results into a vector of vectors, such that it is clear which
590+
// original op result is replaced with which range of values. (In case of a
591+
// 1:N conversion, there can be multiple replacements for a single result.)
592+
SmallVector<SmallVector<Value>> results;
593+
results.reserve(numResults);
594+
unsigned counter = 0;
595+
for (unsigned i = 0; i < numResults; ++i) {
596+
SmallVector<Value> &group = results.emplace_back();
597+
for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
598+
group.push_back(getUnpackedResult(counter++));
580599
}
581600

582-
if (useBarePtrCallConv) {
583-
// For the bare-ptr calling convention, promote memref results to
584-
// descriptors.
585-
assert(results.size() == resultTypes.size() &&
586-
"The number of arguments and types doesn't match");
587-
this->getTypeConverter()->promoteBarePtrsToDescriptors(
588-
rewriter, callOp.getLoc(), resultTypes, results);
589-
} else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
590-
resultTypes, results,
591-
/*toDynamic=*/false))) {
592-
return failure();
601+
// Special handling for MemRef types.
602+
for (unsigned i = 0; i < numResults; ++i) {
603+
Type origType = resultTypes[i];
604+
auto memrefType = dyn_cast<MemRefType>(origType);
605+
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
606+
if (useBarePtrCallConv && memrefType) {
607+
// For the bare-ptr calling convention, promote memref results to
608+
// descriptors.
609+
assert(results[i].size() == 1 && "expected one converted result");
610+
results[i].front() = MemRefDescriptor::fromStaticShape(
611+
rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
612+
results[i].front());
613+
}
614+
if (unrankedMemrefType) {
615+
assert(!useBarePtrCallConv && "unranked memref is not supported in the "
616+
"bare-ptr calling convention");
617+
assert(results[i].size() == 1 && "expected one converted result");
618+
Value desc = this->copyUnrankedDescriptor(
619+
rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
620+
/*toDynamic=*/false);
621+
if (!desc)
622+
return failure();
623+
results[i].front() = desc;
624+
}
593625
}
594626

595-
rewriter.replaceOp(callOp, results);
627+
rewriter.replaceOpWithMultiple(callOp, results);
596628
return success();
597629
}
598630
};
@@ -606,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
606638
symbolTables(symbolTables) {}
607639

608640
LogicalResult
609-
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
641+
matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
610642
ConversionPatternRewriter &rewriter) const override {
611643
bool useBarePtrCallConv = false;
612644
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +668,7 @@ struct CallIndirectOpLowering
636668
using Super::Super;
637669

638670
LogicalResult
639-
matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
671+
matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
640672
ConversionPatternRewriter &rewriter) const override {
641673
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
642674
}
@@ -679,47 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
679711
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
680712

681713
LogicalResult
682-
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
714+
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
683715
ConversionPatternRewriter &rewriter) const override {
684716
Location loc = op.getLoc();
685-
unsigned numArguments = op.getNumOperands();
686717
SmallVector<Value, 4> updatedOperands;
687718

688719
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
689720
bool useBarePtrCallConv =
690721
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
691722

692-
for (auto [oldOperand, newOperand] :
723+
for (auto [oldOperand, newOperands] :
693724
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
694725
Type oldTy = oldOperand.getType();
695726
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
727+
assert(newOperands.size() == 1 && "expected one converted result");
696728
if (useBarePtrCallConv &&
697729
getTypeConverter()->canConvertToBarePtr(memRefType)) {
698730
// For the bare-ptr calling convention, extract the aligned pointer to
699731
// be returned from the memref descriptor.
700-
MemRefDescriptor memrefDesc(newOperand);
732+
MemRefDescriptor memrefDesc(newOperands.front());
701733
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
702734
continue;
703735
}
704736
} else if (auto unrankedMemRefType =
705737
dyn_cast<UnrankedMemRefType>(oldTy)) {
738+
assert(newOperands.size() == 1 && "expected one converted result");
706739
if (useBarePtrCallConv) {
707740
// Unranked memref is not supported in the bare pointer calling
708741
// convention.
709742
return failure();
710743
}
711-
Value updatedDesc = copyUnrankedDescriptor(
712-
rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
744+
Value updatedDesc =
745+
copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
746+
newOperands.front(), /*toDynamic=*/true);
713747
if (!updatedDesc)
714748
return failure();
715749
updatedOperands.push_back(updatedDesc);
716750
continue;
717751
}
718-
updatedOperands.push_back(newOperand);
752+
753+
llvm::append_range(updatedOperands, newOperands);
719754
}
720755

721756
// If ReturnOp has 0 or 1 operand, create it and return immediately.
722-
if (numArguments <= 1) {
757+
if (updatedOperands.size() <= 1) {
723758
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
724759
op, TypeRange(), updatedOperands, op->getAttrs());
725760
return success();

0 commit comments

Comments
 (0)