@@ -527,19 +527,19 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
527527 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
528528 using Super = CallOpInterfaceLowering<CallOpType>;
529529 using Base = ConvertOpToLLVMPattern<CallOpType>;
530+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
530531
531- LogicalResult matchAndRewriteImpl (CallOpType callOp,
532- typename CallOpType::Adaptor adaptor,
532+ LogicalResult matchAndRewriteImpl (CallOpType callOp, Adaptor adaptor,
533533 ConversionPatternRewriter &rewriter,
534534 bool useBarePtrCallConv = false ) const {
535535 // Pack the result types into a struct.
536536 Type packedResult = nullptr ;
537+ SmallVector<SmallVector<Type>> groupedResultTypes;
537538 unsigned numResults = callOp.getNumResults ();
538539 auto resultTypes = llvm::to_vector<4 >(callOp.getResultTypes ());
539-
540540 if (numResults != 0 ) {
541541 if (!(packedResult = this ->getTypeConverter ()->packFunctionResults (
542- resultTypes, useBarePtrCallConv)))
542+ resultTypes, useBarePtrCallConv, &groupedResultTypes )))
543543 return failure ();
544544 }
545545
@@ -565,34 +565,61 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
565565 static_cast <int32_t >(promoted.size ()), 0 };
566566 newOp.getProperties ().op_bundle_sizes = rewriter.getDenseI32ArrayAttr ({});
567567
568- SmallVector<Value, 4 > results;
569- if (numResults < 2 ) {
570- // If < 2 results, packing did not do anything and we can just return.
571- results.append (newOp.result_begin (), newOp.result_end ());
572- } else {
573- // Otherwise, it had been converted to an operation producing a structure.
574- // Extract individual results from the structure and return them as list.
575- results.reserve (numResults);
576- for (unsigned i = 0 ; i < numResults; ++i) {
577- results.push_back (LLVM::ExtractValueOp::create (
578- rewriter, callOp.getLoc (), newOp->getResult (0 ), i));
568+ // Helper function that extracts an individual result from the return value
569+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
570+ // 2 or more results, the results are packed into a structure.
571+ auto getUnpackedResult = [&](unsigned i) -> Value {
572+ assert (packedResult && " convert op has no results" );
573+ if (!isa<LLVM::LLVMStructType>(packedResult)) {
574+ assert (i == 0 && " out of bounds: converted op has only one result" );
575+ return newOp->getResult (0 );
576+ }
577+ // Results have been converted to a structure. Extract individual results
578+ // from the structure.
579+ return LLVM::ExtractValueOp::create (rewriter, callOp.getLoc (),
580+ newOp->getResult (0 ), i);
581+ };
582+
583+ // Group the results into a vector of vectors, such that it is clear which
584+ // original op result is replaced with which range of values. (In case of a
585+ // 1:N conversion, there can be multiple replacements for a single result.)
586+ SmallVector<SmallVector<Value>> results;
587+ results.reserve (numResults);
588+ unsigned counter = 0 ;
589+ for (unsigned i = 0 ; i < numResults; ++i) {
590+ SmallVector<Value> &group = results.emplace_back ();
591+ for (unsigned j = 0 , e = groupedResultTypes[i].size (); j < e; ++j) {
592+ group.push_back (getUnpackedResult (counter++));
579593 }
580594 }
581595
582- if (useBarePtrCallConv) {
583- // For the bare-ptr calling convention, promote memref results to
584- // descriptors.
585- assert (results.size () == resultTypes.size () &&
586- " The number of arguments and types doesn't match" );
587- this ->getTypeConverter ()->promoteBarePtrsToDescriptors (
588- rewriter, callOp.getLoc (), resultTypes, results);
589- } else if (failed (this ->copyUnrankedDescriptors (rewriter, callOp.getLoc (),
590- resultTypes, results,
591- /* toDynamic=*/ false ))) {
592- return failure ();
596+ // Special handling for MemRef types.
597+ for (unsigned i = 0 ; i < numResults; ++i) {
598+ Type origType = resultTypes[i];
599+ auto memrefType = dyn_cast<MemRefType>(origType);
600+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
601+ if (useBarePtrCallConv && memrefType) {
602+ // For the bare-ptr calling convention, promote memref results to
603+ // descriptors.
604+ assert (results[i].size () == 1 && " expected one converted result" );
605+ results[i].front () = MemRefDescriptor::fromStaticShape (
606+ rewriter, callOp.getLoc (), *this ->getTypeConverter (), memrefType,
607+ results[i].front ());
608+ }
609+ if (unrankedMemrefType) {
610+ assert (!useBarePtrCallConv && " unranked memref is not supported in the "
611+ " bare-ptr calling convention" );
612+ assert (results[i].size () == 1 && " expected one converted result" );
613+ Value desc = this ->copyUnrankedDescriptor (
614+ rewriter, callOp.getLoc (), unrankedMemrefType, results[i].front (),
615+ /* toDynamic=*/ false );
616+ if (!desc)
617+ return failure ();
618+ results[i].front () = desc;
619+ }
593620 }
594621
595- rewriter.replaceOp (callOp, results);
622+ rewriter.replaceOpWithMultiple (callOp, results);
596623 return success ();
597624 }
598625};
@@ -606,7 +633,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
606633 symbolTables(symbolTables) {}
607634
608635 LogicalResult
609- matchAndRewrite (func::CallOp callOp, OpAdaptor adaptor,
636+ matchAndRewrite (func::CallOp callOp, OneToNOpAdaptor adaptor,
610637 ConversionPatternRewriter &rewriter) const override {
611638 bool useBarePtrCallConv = false ;
612639 if (getTypeConverter ()->getOptions ().useBarePtrCallConv ) {
@@ -636,7 +663,7 @@ struct CallIndirectOpLowering
636663 using Super::Super;
637664
638665 LogicalResult
639- matchAndRewrite (func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
666+ matchAndRewrite (func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
640667 ConversionPatternRewriter &rewriter) const override {
641668 return matchAndRewriteImpl (callIndirectOp, adaptor, rewriter);
642669 }
@@ -679,47 +706,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
679706 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
680707
681708 LogicalResult
682- matchAndRewrite (func::ReturnOp op, OpAdaptor adaptor,
709+ matchAndRewrite (func::ReturnOp op, OneToNOpAdaptor adaptor,
683710 ConversionPatternRewriter &rewriter) const override {
684711 Location loc = op.getLoc ();
685- unsigned numArguments = op.getNumOperands ();
686712 SmallVector<Value, 4 > updatedOperands;
687713
688714 auto funcOp = op->getParentOfType <LLVM::LLVMFuncOp>();
689715 bool useBarePtrCallConv =
690716 shouldUseBarePtrCallConv (funcOp, this ->getTypeConverter ());
691717
692- for (auto [oldOperand, newOperand ] :
718+ for (auto [oldOperand, newOperands ] :
693719 llvm::zip_equal (op->getOperands (), adaptor.getOperands ())) {
694720 Type oldTy = oldOperand.getType ();
695721 if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
722+ assert (newOperands.size () == 1 && " expected one converted result" );
696723 if (useBarePtrCallConv &&
697724 getTypeConverter ()->canConvertToBarePtr (memRefType)) {
698725 // For the bare-ptr calling convention, extract the aligned pointer to
699726 // be returned from the memref descriptor.
700- MemRefDescriptor memrefDesc (newOperand );
727+ MemRefDescriptor memrefDesc (newOperands. front () );
701728 updatedOperands.push_back (memrefDesc.allocatedPtr (rewriter, loc));
702729 continue ;
703730 }
704731 } else if (auto unrankedMemRefType =
705732 dyn_cast<UnrankedMemRefType>(oldTy)) {
733+ assert (newOperands.size () == 1 && " expected one converted result" );
706734 if (useBarePtrCallConv) {
707735 // Unranked memref is not supported in the bare pointer calling
708736 // convention.
709737 return failure ();
710738 }
711- Value updatedDesc = copyUnrankedDescriptor (
712- rewriter, loc, unrankedMemRefType, newOperand, /* toDynamic=*/ true );
739+ Value updatedDesc =
740+ copyUnrankedDescriptor (rewriter, loc, unrankedMemRefType,
741+ newOperands.front (), /* toDynamic=*/ true );
713742 if (!updatedDesc)
714743 return failure ();
715744 updatedOperands.push_back (updatedDesc);
716745 continue ;
717746 }
718- updatedOperands.push_back (newOperand);
747+
748+ llvm::append_range (updatedOperands, newOperands);
719749 }
720750
721751 // If ReturnOp has 0 or 1 operand, create it and return immediately.
722- if (numArguments <= 1 ) {
752+ if (updatedOperands. size () <= 1 ) {
723753 rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(
724754 op, TypeRange (), updatedOperands, op->getAttrs ());
725755 return success ();
0 commit comments