@@ -852,6 +852,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
852852 if (STI.allowFP16Math () || STI.hasBF16Math ())
853853 setTargetDAGCombine (ISD::SETCC);
854854
855+ // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
856+ // shuffles when one of their lanes is a no-op.
857+ if (STI.allowFP16Math () || STI.hasBF16Math ())
858+ // already added above: FADD, ADD, AND
859+ setTargetDAGCombine ({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
860+ ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
861+
855862 // Promote fp16 arithmetic if fp16 hardware isn't available or the
856863 // user passed --nvptx-no-fp16-math. The flag is useful because,
857864 // although sm_53+ GPUs have some sort of FP16 support in
@@ -5069,20 +5076,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
50695076 return PerformStoreCombineHelper (N, 2 , 0 );
50705077}
50715078
5079+ // / For vector reductions, the final result needs to be a scalar. The default
5080+ // / expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
5081+ // / This requires a packed operation where one of the lanes is undef.
5082+ // /
5083+ // / ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
5084+ // /
5085+ // / v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
5086+ // / v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
5087+ // / v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
5088+ // / vR: f16 = extractelt v3, 1
5089+ // /
5090+ // / We wish to replace vR, v3, and v2 with:
5091+ // / vR: f16 = fadd reassoc (extractelt v1, 1) (extractelt v1, 0)
5092+ // /
5093+ // / ...so that we get:
5094+ // / v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
5095+ // / s1: f16 = extractelt v1, 1
5096+ // / s2: f16 = extractelt v1, 0
5097+ // / vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
5098+ // /
5099+ // / So for this example, this rule will replace v3 and v2, returning a vector
5100+ // / with the result in lane 0 and an undef in lane 1, which we expect will be
5101+ // / folded into the extractelt in vR.
5102+ static SDValue PerformPackedOpCombine (SDNode *N,
5103+ TargetLowering::DAGCombinerInfo &DCI) {
5104+ // Convert:
5105+ // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
5106+ // ...or...
5107+ // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
5108+ // ...where i is a valid index and u is poison.
5109+ const EVT VectorVT = N->getValueType (0 );
5110+ if (!Isv2x16VT (VectorVT))
5111+ return SDValue ();
5112+
5113+ SDLoc DL (N);
5114+
5115+ SDValue ShufOp = N->getOperand (0 );
5116+ SDValue VectOp = N->getOperand (1 );
5117+ bool Swapped = false ;
5118+
5119+ // canonicalize shuffle to op0
5120+ if (VectOp.getOpcode () == ISD::VECTOR_SHUFFLE) {
5121+ std::swap (ShufOp, VectOp);
5122+ Swapped = true ;
5123+ }
5124+
5125+ if (ShufOp.getOpcode () != ISD::VECTOR_SHUFFLE)
5126+ return SDValue ();
5127+
5128+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
5129+ int LiveLane; // exclusively live lane
5130+ for (LiveLane = 0 ; LiveLane < 2 ; ++LiveLane) {
5131+ // check if the current lane is live and the other lane is dead
5132+ if (ShuffleOp->getMaskElt (LiveLane) != PoisonMaskElem &&
5133+ ShuffleOp->getMaskElt (!LiveLane) == PoisonMaskElem)
5134+ break ;
5135+ }
5136+ if (LiveLane == 2 )
5137+ return SDValue ();
5138+
5139+ int ElementIdx = ShuffleOp->getMaskElt (LiveLane);
5140+ const EVT ScalarVT = VectorVT.getScalarType ();
5141+ SDValue Lanes[2 ] = {};
5142+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
5143+ if (LaneID == (unsigned )LiveLane) {
5144+ SDValue Operands[2 ] = {
5145+ DCI.DAG .getExtractVectorElt (DL, ScalarVT, ShufOp.getOperand (0 ),
5146+ ElementIdx),
5147+ DCI.DAG .getExtractVectorElt (DL, ScalarVT, VectOp, LiveLane)};
5148+ // preserve the order of operands
5149+ if (Swapped)
5150+ std::swap (Operands[0 ], Operands[1 ]);
5151+ LaneVal = DCI.DAG .getNode (N->getOpcode (), DL, ScalarVT, Operands);
5152+ } else {
5153+ LaneVal = DCI.DAG .getUNDEF (ScalarVT);
5154+ }
5155+ }
5156+ return DCI.DAG .getBuildVector (VectorVT, DL, Lanes);
5157+ }
5158+
50725159// / PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
50735160// /
50745161static SDValue PerformADDCombine (SDNode *N,
50755162 TargetLowering::DAGCombinerInfo &DCI,
50765163 CodeGenOptLevel OptLevel) {
5077- if (OptLevel == CodeGenOptLevel::None)
5078- return SDValue ();
5079-
50805164 SDValue N0 = N->getOperand (0 );
50815165 SDValue N1 = N->getOperand (1 );
50825166
50835167 // Skip non-integer, non-scalar case
50845168 EVT VT = N0.getValueType ();
5085- if (VT.isVector () || VT != MVT::i32 )
5169+ if (VT.isVector ())
5170+ return PerformPackedOpCombine (N, DCI);
5171+ if (VT != MVT::i32 )
5172+ return SDValue ();
5173+
5174+ if (OptLevel == CodeGenOptLevel::None)
50865175 return SDValue ();
50875176
50885177 // First try with the default operand order.
@@ -5102,7 +5191,10 @@ static SDValue PerformFADDCombine(SDNode *N,
51025191 SDValue N1 = N->getOperand (1 );
51035192
51045193 EVT VT = N0.getValueType ();
5105- if (VT.isVector () || !(VT == MVT::f32 || VT == MVT::f64 ))
5194+ if (VT.isVector ())
5195+ return PerformPackedOpCombine (N, DCI);
5196+
5197+ if (!(VT == MVT::f32 || VT == MVT::f64 ))
51065198 return SDValue ();
51075199
51085200 // First try with the default operand order.
@@ -5205,7 +5297,7 @@ static SDValue PerformANDCombine(SDNode *N,
52055297 DCI.CombineTo (N, Val, AddTo);
52065298 }
52075299
5208- return SDValue ( );
5300+ return PerformPackedOpCombine (N, DCI );
52095301}
52105302
52115303static SDValue PerformREMCombine (SDNode *N,
@@ -5686,6 +5778,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56865778 return PerformADDCombine (N, DCI, OptLevel);
56875779 case ISD::FADD:
56885780 return PerformFADDCombine (N, DCI, OptLevel);
5781+ case ISD::FMUL:
5782+ case ISD::FMINNUM:
5783+ case ISD::FMAXIMUM:
5784+ case ISD::UMIN:
5785+ case ISD::UMAX:
5786+ case ISD::SMIN:
5787+ case ISD::SMAX:
5788+ case ISD::OR:
5789+ case ISD::XOR:
5790+ return PerformPackedOpCombine (N, DCI);
56895791 case ISD::MUL:
56905792 return PerformMULCombine (N, DCI, OptLevel);
56915793 case ISD::SHL:
0 commit comments