@@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
527527 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
528528 using Super = CallOpInterfaceLowering<CallOpType>;
529529 using Base = ConvertOpToLLVMPattern<CallOpType>;
530+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
530531
531- LogicalResult matchAndRewriteImpl (CallOpType callOp,
532- typename CallOpType::Adaptor adaptor,
532+ LogicalResult matchAndRewriteImpl (CallOpType callOp, Adaptor adaptor,
533533 ConversionPatternRewriter &rewriter,
534534 bool useBarePtrCallConv = false ) const {
535535 // Pack the result types into a struct.
536536 Type packedResult = nullptr ;
537+ SmallVector<SmallVector<Type>> groupedResultTypes;
537538 unsigned numResults = callOp.getNumResults ();
538539 auto resultTypes = llvm::to_vector<4 >(callOp.getResultTypes ());
539-
540+ int64_t numConvertedTypes = 0 ;
540541 if (numResults != 0 ) {
541542 if (!(packedResult = this ->getTypeConverter ()->packFunctionResults (
542- resultTypes, useBarePtrCallConv)))
543+ resultTypes, useBarePtrCallConv, &groupedResultTypes,
544+ &numConvertedTypes)))
543545 return failure ();
544546 }
545547
@@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
565567 static_cast <int32_t >(promoted.size ()), 0 };
566568 newOp.getProperties ().op_bundle_sizes = rewriter.getDenseI32ArrayAttr ({});
567569
568- SmallVector<Value, 4 > results;
569- if (numResults < 2 ) {
570- // If < 2 results, packing did not do anything and we can just return.
571- results.append (newOp.result_begin (), newOp.result_end ());
572- } else {
573- // Otherwise, it had been converted to an operation producing a structure.
574- // Extract individual results from the structure and return them as list.
575- results.reserve (numResults);
576- for (unsigned i = 0 ; i < numResults; ++i) {
577- results.push_back (LLVM::ExtractValueOp::create (
578- rewriter, callOp.getLoc (), newOp->getResult (0 ), i));
570+ // Helper function that extracts an individual result from the return value
571+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
572+ // 2 or more results, the results are packed into a structure.
573+ //
574+ // The new call op may have more than 2 results because:
575+ // a. The original call op has more than 2 results.
576+ // b. An original op result type-converted to more than 1 result.
577+ auto getUnpackedResult = [&](unsigned i) -> Value {
578+ assert (numConvertedTypes > 0 && " convert op has no results" );
579+ if (numConvertedTypes == 1 ) {
580+ assert (i == 0 && " out of bounds: converted op has only one result" );
581+ return newOp->getResult (0 );
579582 }
583+ // Results have been converted to a structure. Extract individual results
584+ // from the structure.
585+ return LLVM::ExtractValueOp::create (rewriter, callOp.getLoc (),
586+ newOp->getResult (0 ), i);
587+ };
588+
589+ // Group the results into a vector of vectors, such that it is clear which
590+ // original op result is replaced with which range of values. (In case of a
591+ // 1:N conversion, there can be multiple replacements for a single result.)
592+ SmallVector<SmallVector<Value>> results;
593+ results.reserve (numResults);
594+ unsigned counter = 0 ;
595+ for (unsigned i = 0 ; i < numResults; ++i) {
596+ SmallVector<Value> &group = results.emplace_back ();
597+ for (unsigned j = 0 , e = groupedResultTypes[i].size (); j < e; ++j)
598+ group.push_back (getUnpackedResult (counter++));
580599 }
581600
582- if (useBarePtrCallConv) {
583- // For the bare-ptr calling convention, promote memref results to
584- // descriptors.
585- assert (results.size () == resultTypes.size () &&
586- " The number of arguments and types doesn't match" );
587- this ->getTypeConverter ()->promoteBarePtrsToDescriptors (
588- rewriter, callOp.getLoc (), resultTypes, results);
589- } else if (failed (this ->copyUnrankedDescriptors (rewriter, callOp.getLoc (),
590- resultTypes, results,
591- /* toDynamic=*/ false ))) {
592- return failure ();
601+ // Special handling for MemRef types.
602+ for (unsigned i = 0 ; i < numResults; ++i) {
603+ Type origType = resultTypes[i];
604+ auto memrefType = dyn_cast<MemRefType>(origType);
605+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
606+ if (useBarePtrCallConv && memrefType) {
607+ // For the bare-ptr calling convention, promote memref results to
608+ // descriptors.
609+ assert (results[i].size () == 1 && " expected one converted result" );
610+ results[i].front () = MemRefDescriptor::fromStaticShape (
611+ rewriter, callOp.getLoc (), *this ->getTypeConverter (), memrefType,
612+ results[i].front ());
613+ }
614+ if (unrankedMemrefType) {
615+ assert (!useBarePtrCallConv && " unranked memref is not supported in the "
616+ " bare-ptr calling convention" );
617+ assert (results[i].size () == 1 && " expected one converted result" );
618+ Value desc = this ->copyUnrankedDescriptor (
619+ rewriter, callOp.getLoc (), unrankedMemrefType, results[i].front (),
620+ /* toDynamic=*/ false );
621+ if (!desc)
622+ return failure ();
623+ results[i].front () = desc;
624+ }
593625 }
594626
595- rewriter.replaceOp (callOp, results);
627+ rewriter.replaceOpWithMultiple (callOp, results);
596628 return success ();
597629 }
598630};
@@ -606,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
606638 symbolTables(symbolTables) {}
607639
608640 LogicalResult
609- matchAndRewrite (func::CallOp callOp, OpAdaptor adaptor,
641+ matchAndRewrite (func::CallOp callOp, OneToNOpAdaptor adaptor,
610642 ConversionPatternRewriter &rewriter) const override {
611643 bool useBarePtrCallConv = false ;
612644 if (getTypeConverter ()->getOptions ().useBarePtrCallConv ) {
@@ -636,7 +668,7 @@ struct CallIndirectOpLowering
636668 using Super::Super;
637669
638670 LogicalResult
639- matchAndRewrite (func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
671+ matchAndRewrite (func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
640672 ConversionPatternRewriter &rewriter) const override {
641673 return matchAndRewriteImpl (callIndirectOp, adaptor, rewriter);
642674 }
@@ -679,47 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
679711 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
680712
681713 LogicalResult
682- matchAndRewrite (func::ReturnOp op, OpAdaptor adaptor,
714+ matchAndRewrite (func::ReturnOp op, OneToNOpAdaptor adaptor,
683715 ConversionPatternRewriter &rewriter) const override {
684716 Location loc = op.getLoc ();
685- unsigned numArguments = op.getNumOperands ();
686717 SmallVector<Value, 4 > updatedOperands;
687718
688719 auto funcOp = op->getParentOfType <LLVM::LLVMFuncOp>();
689720 bool useBarePtrCallConv =
690721 shouldUseBarePtrCallConv (funcOp, this ->getTypeConverter ());
691722
692- for (auto [oldOperand, newOperand ] :
723+ for (auto [oldOperand, newOperands ] :
693724 llvm::zip_equal (op->getOperands (), adaptor.getOperands ())) {
694725 Type oldTy = oldOperand.getType ();
695726 if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
727+ assert (newOperands.size () == 1 && " expected one converted result" );
696728 if (useBarePtrCallConv &&
697729 getTypeConverter ()->canConvertToBarePtr (memRefType)) {
698730 // For the bare-ptr calling convention, extract the aligned pointer to
699731 // be returned from the memref descriptor.
700- MemRefDescriptor memrefDesc (newOperand );
732+ MemRefDescriptor memrefDesc (newOperands. front () );
701733 updatedOperands.push_back (memrefDesc.allocatedPtr (rewriter, loc));
702734 continue ;
703735 }
704736 } else if (auto unrankedMemRefType =
705737 dyn_cast<UnrankedMemRefType>(oldTy)) {
738+ assert (newOperands.size () == 1 && " expected one converted result" );
706739 if (useBarePtrCallConv) {
707740 // Unranked memref is not supported in the bare pointer calling
708741 // convention.
709742 return failure ();
710743 }
711- Value updatedDesc = copyUnrankedDescriptor (
712- rewriter, loc, unrankedMemRefType, newOperand, /* toDynamic=*/ true );
744+ Value updatedDesc =
745+ copyUnrankedDescriptor (rewriter, loc, unrankedMemRefType,
746+ newOperands.front (), /* toDynamic=*/ true );
713747 if (!updatedDesc)
714748 return failure ();
715749 updatedOperands.push_back (updatedDesc);
716750 continue ;
717751 }
718- updatedOperands.push_back (newOperand);
752+
753+ llvm::append_range (updatedOperands, newOperands);
719754 }
720755
721756 // If ReturnOp has 0 or 1 operand, create it and return immediately.
722- if (numArguments <= 1 ) {
757+ if (updatedOperands. size () <= 1 ) {
723758 rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(
724759 op, TypeRange (), updatedOperands, op->getAttrs ());
725760 return success ();
0 commit comments