Skip to content

Commit b365375

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent 9fbd9d1 commit b365375

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828828
// We have some custom DAG combine patterns for these nodes
829829
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
831+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832+
ISD::TRUNCATE});
832833

833834
// setcc for f16x2 and bf16x2 needs special handling to prevent
834835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5752,6 +5753,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57525753
return SDValue();
57535754
}
57545755

5756+
// Target DAG combine for ISD::TRUNCATE nodes.
//
// Only an i64 -> i32 truncate is considered. When the i64 source is a bitcast
// of a v2f32 BUILD_VECTOR (either directly, or shifted right by the constant
// 32), the truncate is replaced by an i32 bitcast of the matching f32 operand
// of the BUILD_VECTOR.
//
// \param N   the TRUNCATE node being combined.
// \param DCI combiner info; its DAG is used to create the replacement node.
// \returns the replacement value, or an empty SDValue if no pattern matched.
static SDValue PerformTRUNCATECombine(SDNode *N,
5757+
TargetLowering::DAGCombinerInfo &DCI) {
5758+
SDLoc DL(N);
5759+
SDValue Op = N->getOperand(0);
5760+
EVT FromVT = Op.getValueType();
5761+
EVT ResultVT = N->getValueType(0);
5762+
5763+
if (FromVT == MVT::i64 && ResultVT == MVT::i32) {
5764+
// Pattern 1 (low half):
//   i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5765+
//   -> i32 = bitcast (f32 A)
5766+
if (Op.getOpcode() == ISD::BITCAST) {
5767+
SDValue BV = Op.getOperand(0);
5768+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5769+
BV.getValueType() == MVT::v2f32) {
5770+
// Operand 0 of the BUILD_VECTOR is used as the low 32 bits of the i64.
5771+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT, BV.getOperand(0));
5772+
}
5773+
}
5774+
5775+
// Pattern 2 (high half):
//   i32 = truncate (i64 = srl
5776+
//                     (i64 = bitcast
5777+
//                       (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5778+
//   -> i32 = bitcast (f32 B)
5779+
if (Op.getOpcode() == ISD::SRL) {
5780+
if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand(1));
5781+
ShAmt && ShAmt->getAsAPIntVal() == 32) {
5782+
SDValue Cast = Op.getOperand(0);
5783+
if (Cast.getOpcode() == ISD::BITCAST) {
5784+
SDValue BV = Cast.getOperand(0);
5785+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5786+
BV.getValueType() == MVT::v2f32) {
5787+
// Operand 1 of the BUILD_VECTOR is used as the high 32 bits of the i64.
5788+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT,
5789+
BV.getOperand(1));
5790+
}
5791+
}
5792+
}
5793+
}
5794+
}
5795+
5796+
// No pattern matched; leave the node unchanged.
return SDValue();
5797+
}
5798+
57555799
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57565800
DAGCombinerInfo &DCI) const {
57575801
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5790,6 +5834,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57905834
return combineADDRSPACECAST(N, DCI);
57915835
case ISD::FP_ROUND:
57925836
return PerformFP_ROUNDCombine(N, DCI);
5837+
case ISD::TRUNCATE:
5838+
return PerformTRUNCATECombine(N, DCI);
57935839
}
57945840
return SDValue();
57955841
}

0 commit comments

Comments
 (0)