Skip to content

Commit 864cb6f

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent 3179640 commit 864cb6f

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
826826
// We have some custom DAG combine patterns for these nodes
827827
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
828828
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
829-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
829+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
830830

831831
// setcc for f16x2 and bf16x2 needs special handling to prevent
832832
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5713,6 +5713,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
57135713
return SDValue();
57145714
}
57155715

5716+
static SDValue PerformFP_ROUNDCombine(SDNode *N,
5717+
TargetLowering::DAGCombinerInfo &DCI) {
5718+
SDLoc DL(N);
5719+
SDValue Op = N->getOperand(0);
5720+
SDValue Trunc = N->getOperand(1);
5721+
EVT NarrowVT = N->getValueType(0);
5722+
EVT WideVT = Op.getValueType();
5723+
5724+
// v2[b]f16 = fp_round (v2f32 A)
5725+
// -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5726+
// ([b]f16 = fp_round (extractelt A, 1)))
5727+
if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5728+
WideVT == MVT::v2f32) {
5729+
SDValue F32Op0, F32Op1;
5730+
if (Op.getOpcode() == ISD::BUILD_VECTOR) {
5731+
F32Op0 = Op.getOperand(0);
5732+
F32Op1 = Op.getOperand(1);
5733+
} else {
5734+
F32Op0 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5735+
DCI.DAG.getIntPtrConstant(0, DL));
5736+
F32Op1 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5737+
DCI.DAG.getIntPtrConstant(1, DL));
5738+
}
5739+
return DCI.DAG.getBuildVector(
5740+
NarrowVT, DL,
5741+
{DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op0,
5742+
Trunc),
5743+
DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op1,
5744+
Trunc)});
5745+
}
5746+
5747+
return SDValue();
5748+
}
5749+
57165750
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57175751
DAGCombinerInfo &DCI) const {
57185752
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5749,6 +5783,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57495783
return PerformBUILD_VECTORCombine(N, DCI);
57505784
case ISD::ADDRSPACECAST:
57515785
return combineADDRSPACECAST(N, DCI);
5786+
case ISD::FP_ROUND:
5787+
return PerformFP_ROUNDCombine(N, DCI);
57525788
}
57535789
return SDValue();
57545790
}

0 commit comments

Comments
 (0)