@@ -419,9 +419,10 @@ static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
419419}
420420
421421template <typename DstOpTy>
422- static void generateAIEVecOpsForReductionOp (ConversionPatternRewriter &rewriter,
423- vector::ReductionOp srcOp,
424- int shiftIndex, Value curValue) {
422+ static aievec::ExtElemOp
423+ generateAIEVecOpsForReductionOp (ConversionPatternRewriter &rewriter,
424+ vector::ReductionOp srcOp, int shiftIndex,
425+ Value curValue) {
425426 assert (shiftIndex > 0 && (shiftIndex & (shiftIndex - 1 )) == 0 &&
426427 " shiftIndex must be power of 2" );
427428
@@ -447,8 +448,8 @@ static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
447448
448449 auto zeroConstOp =
449450 rewriter.create <arith::ConstantOp>(loc, rewriter.getI32IntegerAttr (0 ));
450- rewriter.replaceOpWithNewOp <aievec::ExtElemOp>(srcOp , scalarType, curOp,
451- zeroConstOp.getResult ());
451+ return rewriter.create <aievec::ExtElemOp>(loc , scalarType, curOp,
452+ zeroConstOp.getResult ());
452453}
453454
454455static func::FuncOp getOrInsertFuncDecl (ConversionPatternRewriter &rewriter,
@@ -1476,6 +1477,8 @@ using LowerVectorMinimumFOpToAIEVecMinOp =
14761477 LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MinimumFOp, aievec::MinOp>;
// Alias of the LowerVectorMinMaxOpToAIEVecMinMaxOp pattern template
// instantiated with arith::MaximumFOp and aievec::MaxOp.
14771478using LowerVectorMaximumFOpToAIEVecMaxOp =
14781479 LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MaximumFOp, aievec::MaxOp>;
// Alias of the same pattern template instantiated with arith::MaxNumFOp and
// aievec::MaxOp.
// NOTE(review): the name spells "MaxNumFFOp" with a doubled "F", unlike the
// sibling aliases; a rename would also have to touch every place the alias is
// referenced.
// NOTE(review): arith.maxnumf and arith.maximumf differ in how they treat NaN
// operands, yet both are paired with aievec::MaxOp here -- confirm that the
// target op's NaN behaviour is acceptable for both.
1480+ using LowerVectorMaxNumFFOpToAIEVecMaxOp =
1481+ LowerVectorMinMaxOpToAIEVecMinMaxOp<arith::MaxNumFOp, aievec::MaxOp>;
14791482
14801483template <typename SrcOpTy, typename CmpTy>
14811484struct LowerVectorCmpOpToAIEVecCmpOp : OpConversionPattern<SrcOpTy> {
@@ -1591,8 +1594,14 @@ struct LowerVectorReductionMinOp : OpConversionPattern<vector::ReductionOp> {
15911594 return failure ();
15921595
15931596 int shiftIndex = laneSize / 2 ;
1594- generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
1595- srcOp.getVector ());
1597+ auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MinOp>(
1598+ rewriter, srcOp, shiftIndex, srcOp.getVector ());
1599+
1600+ if (srcOp.getAcc ())
1601+ rewriter.replaceOpWithNewOp <arith::MinimumFOp>(
1602+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1603+ else
1604+ rewriter.replaceOp (srcOp, reduceResultOp);
15961605 return success ();
15971606 }
15981607};
@@ -1605,7 +1614,8 @@ struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
16051614 ConversionPatternRewriter &rewriter) const override {
16061615 if (auto kind = srcOp.getKind (); kind != vector::CombiningKind::MAXSI &&
16071616 kind != vector::CombiningKind::MAXUI &&
1608- kind != vector::CombiningKind::MAXIMUMF)
1617+ kind != vector::CombiningKind::MAXIMUMF &&
1618+ kind != vector::CombiningKind::MAXNUMF)
16091619 return failure ();
16101620
16111621 auto vType = cast<VectorType>(srcOp.getVector ().getType ());
@@ -1617,8 +1627,14 @@ struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
16171627 return failure ();
16181628
16191629 int shiftIndex = laneSize / 2 ;
1620- generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
1621- srcOp.getVector ());
1630+ auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MaxOp>(
1631+ rewriter, srcOp, shiftIndex, srcOp.getVector ());
1632+
1633+ if (srcOp.getAcc ())
1634+ rewriter.replaceOpWithNewOp <arith::MaximumFOp>(
1635+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1636+ else
1637+ rewriter.replaceOp (srcOp, reduceResultOp);
16221638 return success ();
16231639 }
16241640};
@@ -1659,11 +1675,22 @@ struct LowerVectorReductionAddIntOp : OpConversionPattern<vector::ReductionOp> {
16591675 loc, lExtOp.getResult ().getType (), lExtOp.getResult (),
16601676 rExtOp.getResult ());
16611677 shiftIndex /= 2 ;
1662- generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1678+ auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
16631679 rewriter, srcOp, shiftIndex, addElemOp.getResult ());
1664- } else
1665- generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1680+ if (srcOp.getAcc ())
1681+ rewriter.replaceOpWithNewOp <arith::AddIOp>(
1682+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1683+ else
1684+ rewriter.replaceOp (srcOp, reduceResultOp);
1685+ } else {
1686+ auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
16661687 rewriter, srcOp, shiftIndex, srcOp.getVector ());
1688+ if (srcOp.getAcc ())
1689+ rewriter.replaceOpWithNewOp <arith::AddIOp>(
1690+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1691+ else
1692+ rewriter.replaceOp (srcOp, reduceResultOp);
1693+ }
16671694
16681695 return success ();
16691696 }
@@ -1717,8 +1744,14 @@ struct LowerVectorReductionAddFloatOp
17171744
17181745 auto zeroConstOp =
17191746 rewriter.create <arith::ConstantOp>(loc, rewriter.getI32IntegerAttr (0 ));
1720- rewriter.replaceOpWithNewOp <aievec::ExtElemOp>(srcOp, scalarType, curOp,
1721- zeroConstOp.getResult ());
1747+ auto reduceResultOp = rewriter.create <aievec::ExtElemOp>(
1748+ srcOp.getLoc (), scalarType, curOp, zeroConstOp.getResult ());
1749+
1750+ if (srcOp.getAcc ())
1751+ rewriter.replaceOpWithNewOp <arith::AddFOp>(
1752+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1753+ else
1754+ rewriter.replaceOp (srcOp, reduceResultOp);
17221755 return success ();
17231756 }
17241757};
@@ -1779,8 +1812,14 @@ struct LowerVectorReductionAddBfloat16Op
17791812
17801813 auto zeroConstOp =
17811814 rewriter.create <arith::ConstantOp>(loc, rewriter.getI32IntegerAttr (0 ));
1782- rewriter.replaceOpWithNewOp <aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1783- zeroConstOp.getResult ());
1815+ auto reduceResultOp = rewriter.create <aievec::ExtElemOp>(
1816+ srcOp.getLoc (), scalarType, concatOp, zeroConstOp.getResult ());
1817+
1818+ if (srcOp.getAcc ())
1819+ rewriter.replaceOpWithNewOp <arith::AddFOp>(
1820+ srcOp, reduceResultOp.getResult (), srcOp.getAcc ());
1821+ else
1822+ rewriter.replaceOp (srcOp, reduceResultOp);
17841823 return success ();
17851824 }
17861825};
@@ -3125,6 +3164,7 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
31253164 LowerVectorMinimumFOpToAIEVecMinOp,
31263165 LowerVectorMaxSIOpToAIEVecMaxOp,
31273166 LowerVectorMaximumFOpToAIEVecMaxOp,
3167+ LowerVectorMaxNumFFOpToAIEVecMaxOp,
31283168 LowerVectorCmpIOpToAIEVecCmpOp,
31293169 LowerVectorCmpFOpToAIEVecCmpOp,
31303170 LowerVectorSelectOpToAIEVecSelOp,
@@ -3705,6 +3745,17 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target,
37053745 return !elWidthSet.count (resultElWidth) || laneSize * resultElWidth != 512 ;
37063746 });
37073747
3748+ target.addDynamicallyLegalOp <arith::MaxNumFOp>([=](arith::MaxNumFOp op) {
3749+ auto resultType = dyn_cast<VectorType>(op.getType ());
3750+ if (!resultType)
3751+ return true ;
3752+
3753+ auto resultElWidth = resultType.getElementType ().getIntOrFloatBitWidth ();
3754+ unsigned laneSize = getVectorLaneSize (resultType);
3755+
3756+ return !elWidthSet.count (resultElWidth) || laneSize * resultElWidth != 512 ;
3757+ });
3758+
37083759 target.addDynamicallyLegalOp <arith::CmpIOp>([=](arith::CmpIOp op) {
37093760 auto lhsType = dyn_cast<VectorType>(op.getLhs ().getType ());
37103761 if (!lhsType)
@@ -3746,7 +3797,8 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target,
37463797 kind != vector::CombiningKind::MINIMUMF &&
37473798 kind != vector::CombiningKind::MAXSI &&
37483799 kind != vector::CombiningKind::MAXUI &&
3749- kind != vector::CombiningKind::MAXIMUMF)
3800+ kind != vector::CombiningKind::MAXIMUMF &&
3801+ kind != vector::CombiningKind::MAXNUMF)
37503802 return true ;
37513803
37523804 auto vType = dyn_cast<VectorType>(op.getVector ().getType ());
0 commit comments