Skip to content

Commit b79aea3

Browse files
[mlir][LLVM] FuncToLLVM: Add 1:N support
1 parent 4c38917 commit b79aea3

File tree

7 files changed

+268
-127
lines changed

7 files changed

+268
-127
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,13 @@ class LLVMTypeConverter : public TypeConverter {
7474
/// LLVM-compatible type. In particular, if more than one value is returned,
7575
/// create an LLVM dialect structure type with elements that correspond to
7676
/// each of the types converted with `convertCallingConventionType`.
77-
Type packFunctionResults(TypeRange types,
78-
bool useBarePointerCallConv = false) const;
77+
///
78+
/// Populate the converted (unpacked) types into `groupedTypes`, if provided.
79+
/// `groupedTypes` contains one nested vector per input type. In case of a 1:N
80+
/// conversion, a nested vector may contain 0 or more than 1 converted type.
81+
Type packFunctionResults(
82+
TypeRange types, bool useBarePointerCallConv = false,
83+
SmallVector<SmallVector<Type>> *groupedTypes = nullptr) const;
7984

8085
/// Convert a non-empty list of types of values produced by an operation into
8186
/// an LLVM-compatible type. In particular, if more than one value is
@@ -88,15 +93,9 @@ class LLVMTypeConverter : public TypeConverter {
8893
/// UnrankedMemRefType, are converted following the specific rules for the
8994
/// calling convention. Calling convention independent types are converted
9095
/// following the default LLVM type conversions.
91-
Type convertCallingConventionType(Type type,
92-
bool useBarePointerCallConv = false) const;
93-
94-
/// Promote the bare pointers in 'values' that resulted from memrefs to
95-
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
96-
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
97-
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
98-
Location loc, ArrayRef<Type> stdTypes,
99-
SmallVectorImpl<Value> &values) const;
96+
LogicalResult
97+
convertCallingConventionType(Type type, SmallVectorImpl<Type> &result,
98+
bool useBarePointerCallConv = false) const;
10099

101100
/// Returns the MLIR context.
102101
MLIRContext &getContext() const;
@@ -111,7 +110,8 @@ class LLVMTypeConverter : public TypeConverter {
111110
/// of the platform-specific C/C++ ABI lowering related to struct argument
112111
/// passing.
113112
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
114-
ValueRange operands, OpBuilder &builder,
113+
ArrayRef<ValueRange> operands,
114+
OpBuilder &builder,
115115
bool useBarePtrCallConv = false) const;
116116

117117
/// Promote the LLVM struct representation of one MemRef descriptor to stack

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
527527
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
528528
using Super = CallOpInterfaceLowering<CallOpType>;
529529
using Base = ConvertOpToLLVMPattern<CallOpType>;
530+
using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
530531

531-
LogicalResult matchAndRewriteImpl(CallOpType callOp,
532-
typename CallOpType::Adaptor adaptor,
532+
LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
533533
ConversionPatternRewriter &rewriter,
534534
bool useBarePtrCallConv = false) const {
535535
// Pack the result types into a struct.
536536
Type packedResult = nullptr;
537+
SmallVector<SmallVector<Type>> groupedResultTypes;
537538
unsigned numResults = callOp.getNumResults();
538539
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
539-
540540
if (numResults != 0) {
541541
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
542-
resultTypes, useBarePtrCallConv)))
542+
resultTypes, useBarePtrCallConv, &groupedResultTypes)))
543543
return failure();
544544
}
545545

@@ -565,34 +565,61 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
565565
static_cast<int32_t>(promoted.size()), 0};
566566
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
567567

568-
SmallVector<Value, 4> results;
569-
if (numResults < 2) {
570-
// If < 2 results, packing did not do anything and we can just return.
571-
results.append(newOp.result_begin(), newOp.result_end());
572-
} else {
573-
// Otherwise, it had been converted to an operation producing a structure.
574-
// Extract individual results from the structure and return them as list.
575-
results.reserve(numResults);
576-
for (unsigned i = 0; i < numResults; ++i) {
577-
results.push_back(LLVM::ExtractValueOp::create(
578-
rewriter, callOp.getLoc(), newOp->getResult(0), i));
568+
// Helper function that extracts an individual result from the return value
569+
// of the new call op. llvm.call ops support only 0 or 1 result. In case of
570+
// 2 or more results, the results are packed into a structure.
571+
auto getUnpackedResult = [&](unsigned i) -> Value {
572+
assert(packedResult && "convert op has no results");
573+
if (!isa<LLVM::LLVMStructType>(packedResult)) {
574+
assert(i == 0 && "out of bounds: converted op has only one result");
575+
return newOp->getResult(0);
576+
}
577+
// Results have been converted to a structure. Extract individual results
578+
// from the structure.
579+
return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
580+
newOp->getResult(0), i);
581+
};
582+
583+
// Group the results into a vector of vectors, such that it is clear which
584+
// original op result is replaced with which range of values. (In case of a
585+
// 1:N conversion, there can be multiple replacements for a single result.)
586+
SmallVector<SmallVector<Value>> results;
587+
results.reserve(numResults);
588+
unsigned counter = 0;
589+
for (unsigned i = 0; i < numResults; ++i) {
590+
SmallVector<Value> &group = results.emplace_back();
591+
for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) {
592+
group.push_back(getUnpackedResult(counter++));
579593
}
580594
}
581595

582-
if (useBarePtrCallConv) {
583-
// For the bare-ptr calling convention, promote memref results to
584-
// descriptors.
585-
assert(results.size() == resultTypes.size() &&
586-
"The number of arguments and types doesn't match");
587-
this->getTypeConverter()->promoteBarePtrsToDescriptors(
588-
rewriter, callOp.getLoc(), resultTypes, results);
589-
} else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
590-
resultTypes, results,
591-
/*toDynamic=*/false))) {
592-
return failure();
596+
// Special handling for MemRef types.
597+
for (unsigned i = 0; i < numResults; ++i) {
598+
Type origType = resultTypes[i];
599+
auto memrefType = dyn_cast<MemRefType>(origType);
600+
auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
601+
if (useBarePtrCallConv && memrefType) {
602+
// For the bare-ptr calling convention, promote memref results to
603+
// descriptors.
604+
assert(results[i].size() == 1 && "expected one converted result");
605+
results[i].front() = MemRefDescriptor::fromStaticShape(
606+
rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
607+
results[i].front());
608+
}
609+
if (unrankedMemrefType) {
610+
assert(!useBarePtrCallConv && "unranked memref is not supported in the "
611+
"bare-ptr calling convention");
612+
assert(results[i].size() == 1 && "expected one converted result");
613+
Value desc = this->copyUnrankedDescriptor(
614+
rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
615+
/*toDynamic=*/false);
616+
if (!desc)
617+
return failure();
618+
results[i].front() = desc;
619+
}
593620
}
594621

595-
rewriter.replaceOp(callOp, results);
622+
rewriter.replaceOpWithMultiple(callOp, results);
596623
return success();
597624
}
598625
};
@@ -606,7 +633,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
606633
symbolTables(symbolTables) {}
607634

608635
LogicalResult
609-
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
636+
matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
610637
ConversionPatternRewriter &rewriter) const override {
611638
bool useBarePtrCallConv = false;
612639
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +663,7 @@ struct CallIndirectOpLowering
636663
using Super::Super;
637664

638665
LogicalResult
639-
matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
666+
matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
640667
ConversionPatternRewriter &rewriter) const override {
641668
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
642669
}
@@ -679,47 +706,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
679706
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
680707

681708
LogicalResult
682-
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
709+
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
683710
ConversionPatternRewriter &rewriter) const override {
684711
Location loc = op.getLoc();
685-
unsigned numArguments = op.getNumOperands();
686712
SmallVector<Value, 4> updatedOperands;
687713

688714
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
689715
bool useBarePtrCallConv =
690716
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
691717

692-
for (auto [oldOperand, newOperand] :
718+
for (auto [oldOperand, newOperands] :
693719
llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
694720
Type oldTy = oldOperand.getType();
695721
if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
722+
assert(newOperands.size() == 1 && "expected one converted result");
696723
if (useBarePtrCallConv &&
697724
getTypeConverter()->canConvertToBarePtr(memRefType)) {
698725
// For the bare-ptr calling convention, extract the aligned pointer to
699726
// be returned from the memref descriptor.
700-
MemRefDescriptor memrefDesc(newOperand);
727+
MemRefDescriptor memrefDesc(newOperands.front());
701728
updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
702729
continue;
703730
}
704731
} else if (auto unrankedMemRefType =
705732
dyn_cast<UnrankedMemRefType>(oldTy)) {
733+
assert(newOperands.size() == 1 && "expected one converted result");
706734
if (useBarePtrCallConv) {
707735
// Unranked memref is not supported in the bare pointer calling
708736
// convention.
709737
return failure();
710738
}
711-
Value updatedDesc = copyUnrankedDescriptor(
712-
rewriter, loc, unrankedMemRefType, newOperand, /*toDynamic=*/true);
739+
Value updatedDesc =
740+
copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
741+
newOperands.front(), /*toDynamic=*/true);
713742
if (!updatedDesc)
714743
return failure();
715744
updatedOperands.push_back(updatedDesc);
716745
continue;
717746
}
718-
updatedOperands.push_back(newOperand);
747+
748+
llvm::append_range(updatedOperands, newOperands);
719749
}
720750

721751
// If ReturnOp has 0 or 1 operand, create it and return immediately.
722-
if (numArguments <= 1) {
752+
if (updatedOperands.size() <= 1) {
723753
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
724754
op, TypeRange(), updatedOperands, op->getAttrs());
725755
return success();

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,8 +719,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
719719
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
720720
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
721721

722+
SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
723+
adaptor.getOperands(), [](Value v) { return ValueRange(v); });
722724
auto arguments = getTypeConverter()->promoteOperands(
723-
loc, op->getOperands(), adaptor.getOperands(), rewriter);
725+
loc, op->getOperands(), adaptorOperands, rewriter);
724726
arguments.push_back(elementSize);
725727
hostRegisterCallBuilder.create(loc, rewriter, arguments);
726728

@@ -741,8 +743,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
741743
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
742744
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
743745

746+
SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
747+
adaptor.getOperands(), [](Value v) { return ValueRange(v); });
744748
auto arguments = getTypeConverter()->promoteOperands(
745-
loc, op->getOperands(), adaptor.getOperands(), rewriter);
749+
loc, op->getOperands(), adaptorOperands, rewriter);
746750
arguments.push_back(elementSize);
747751
hostUnregisterCallBuilder.create(loc, rewriter, arguments);
748752

@@ -973,8 +977,10 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
973977
// Note: If `useBarePtrCallConv` is set in the type converter's options,
974978
// the value of `kernelBarePtrCallConv` will be ignored.
975979
OperandRange origArguments = launchOp.getKernelOperands();
980+
SmallVector<ValueRange> adaptorOperands = llvm::map_to_vector(
981+
adaptor.getKernelOperands(), [](Value v) { return ValueRange(v); });
976982
SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
977-
loc, origArguments, adaptor.getKernelOperands(), rewriter,
983+
loc, origArguments, adaptorOperands, rewriter,
978984
/*useBarePtrCallConv=*/kernelBarePtrCallConv);
979985
SmallVector<Value, 8> llvmArgumentsWithSizes;
980986

0 commit comments

Comments
 (0)