Skip to content

Commit 34974ee

Browse files
Incorporate review feedback
1 parent a69e982 commit 34974ee

File tree

5 files changed

+1055
-501
lines changed

5 files changed

+1055
-501
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)
5656

5757
NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
5858
CodeGenOptLevel OptLevel)
59-
: SelectionDAGISel(tm, OptLevel), TM(tm) {
60-
doMulWide = (OptLevel > CodeGenOptLevel::None);
61-
}
59+
: SelectionDAGISel(tm, OptLevel), TM(tm) {}
6260

6361
bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
6462
Subtarget = &MF.getSubtarget<NVPTXSubtarget>();

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ struct NVPTXScopes {
4040
class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4141
const NVPTXTargetMachine &TM;
4242

43-
// If true, generate mul.wide from sext and mul
44-
bool doMulWide;
45-
4643
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
4744
bool usePrecSqrtF32(const SDNode *N) const;
4845
bool useF32FTZ() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5276,7 +5276,8 @@ static SDValue PerformADDCombine(SDNode *N,
52765276
SDValue N1 = N->getOperand(1);
52775277

52785278
// Skip non-integer, non-scalar case
5279-
if (N->getValueType(0).isVector() || N->getValueType(0) != MVT::i32)
5279+
EVT VT = N0.getValueType();
5280+
if (VT.isVector() || VT != MVT::i32)
52805281
return SDValue();
52815282

52825283
// First try with the default operand order.
@@ -5409,46 +5410,35 @@ static SDValue PerformREMCombine(SDNode *N,
54095410
}
54105411

54115412
// (any_extend|sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
5412-
static SDValue
5413-
PerformExtendMULWIDECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5414-
unsigned ExtOpcode = N->getOpcode();
5415-
assert(ExtOpcode == ISD::ANY_EXTEND || ExtOpcode == ISD::SIGN_EXTEND ||
5416-
ExtOpcode == ISD::ZERO_EXTEND);
5417-
EVT ToVT = N->getValueType(0);
5418-
if (!(ToVT == MVT::i32 || ToVT == MVT::i64))
5413+
static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5414+
CodeGenOptLevel OptLevel) {
5415+
if (OptLevel == CodeGenOptLevel::None)
54195416
return SDValue();
5417+
54205418
SDValue Op = N->getOperand(0);
5421-
if (!(Op.getOpcode() == ISD::MUL || Op.getOpcode() == ISD::SHL))
5422-
return SDValue();
5423-
if (Op.getOpcode() == ISD::SHL && !isa<ConstantSDNode>(Op.getOperand(1)))
5419+
if (!Op.hasOneUse())
54245420
return SDValue();
5421+
EVT ToVT = N->getValueType(0);
54255422
EVT FromVT = Op.getValueType();
5426-
if (!(FromVT == MVT::i16 || FromVT == MVT::i32))
5427-
return SDValue();
5428-
if (ExtOpcode == ISD::SIGN_EXTEND && !Op->getFlags().hasNoSignedWrap())
5423+
if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
5424+
(ToVT == MVT::i64 && FromVT == MVT::i32)))
54295425
return SDValue();
5430-
if (ExtOpcode == ISD::ZERO_EXTEND && !Op->getFlags().hasNoUnsignedWrap())
5431-
return SDValue();
5432-
if (ExtOpcode == ISD::ANY_EXTEND && !Op->getFlags().hasNoSignedWrap() &&
5433-
!Op->getFlags().hasNoUnsignedWrap())
5426+
if (!(Op.getOpcode() == ISD::MUL ||
5427+
(Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1)))))
54345428
return SDValue();
54355429

54365430
SDLoc DL(N);
5431+
unsigned ExtOpcode = N->getOpcode();
54375432
unsigned Opcode = 0;
5438-
if (ExtOpcode == ISD::SIGN_EXTEND)
5433+
if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
54395434
Opcode = NVPTXISD::MUL_WIDE_SIGNED;
5440-
else if (ExtOpcode == ISD::ZERO_EXTEND)
5435+
else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
54415436
Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
5442-
else if (ExtOpcode == ISD::ANY_EXTEND && Op->getFlags().hasNoUnsignedWrap())
5443-
Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
5444-
else if (ExtOpcode == ISD::ANY_EXTEND && Op->getFlags().hasNoSignedWrap())
5445-
Opcode = NVPTXISD::MUL_WIDE_SIGNED;
54465437
else
5447-
assert(false);
5438+
return SDValue();
54485439
SDValue RHS = Op.getOperand(1);
54495440
if (Op.getOpcode() == ISD::SHL) {
5450-
const auto ShiftAmt =
5451-
cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue();
5441+
const auto ShiftAmt = Op.getConstantOperandVal(1);
54525442
const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt;
54535443
RHS = DCI.DAG.getConstant(MulVal, DL, ToVT);
54545444
}
@@ -5977,10 +5967,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59775967
return combineADDRSPACECAST(N, DCI);
59785968
case ISD::AND:
59795969
return PerformANDCombine(N, DCI);
5980-
case ISD::ANY_EXTEND:
59815970
case ISD::SIGN_EXTEND:
59825971
case ISD::ZERO_EXTEND:
5983-
return PerformExtendMULWIDECombine(N, DCI);
5972+
return combineMulWide(N, DCI, OptLevel);
59845973
case ISD::BUILD_VECTOR:
59855974
return PerformBUILD_VECTORCombine(N, DCI);
59865975
case ISD::EXTRACT_VECTOR_ELT:

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 25 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
125125
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
126126
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
127127

128-
def doMulWide : Predicate<"doMulWide">;
129-
130128
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
131129
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
132130
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -860,11 +858,11 @@ def MULWIDEU32Imm32 :
860858
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
861859

862860
def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
863-
def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
864-
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
861+
def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
862+
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
865863

866864
// Matchers for signed, unsigned mul.wide ISD nodes.
867-
let Predicates = [doMulWide] in {
865+
let Predicates = [hasOptEnabled] in {
868866
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
869867
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
870868
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
@@ -876,85 +874,6 @@ let Predicates = [doMulWide] in {
876874
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
877875
}
878876

879-
// Predicates used for converting some patterns to mul.wide.
880-
def SInt32Const : PatLeaf<(imm), [{
881-
const APInt &v = N->getAPIntValue();
882-
return v.isSignedIntN(32);
883-
}]>;
884-
885-
def UInt32Const : PatLeaf<(imm), [{
886-
const APInt &v = N->getAPIntValue();
887-
return v.isIntN(32);
888-
}]>;
889-
890-
def SInt16Const : PatLeaf<(imm), [{
891-
const APInt &v = N->getAPIntValue();
892-
return v.isSignedIntN(16);
893-
}]>;
894-
895-
def UInt16Const : PatLeaf<(imm), [{
896-
const APInt &v = N->getAPIntValue();
897-
return v.isIntN(16);
898-
}]>;
899-
900-
def IntConst_0_30 : PatLeaf<(imm), [{
901-
// Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
902-
const APInt &v = N->getAPIntValue();
903-
return v.sge(0) && v.slt(31);
904-
}]>;
905-
906-
def IntConst_0_14 : PatLeaf<(imm), [{
907-
// Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
908-
const APInt &v = N->getAPIntValue();
909-
return v.sge(0) && v.slt(15);
910-
}]>;
911-
912-
def SHL2MUL32 : SDNodeXForm<imm, [{
913-
const APInt &v = N->getAPIntValue();
914-
APInt temp(32, 1);
915-
return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
916-
}]>;
917-
918-
def SHL2MUL16 : SDNodeXForm<imm, [{
919-
const APInt &v = N->getAPIntValue();
920-
APInt temp(16, 1);
921-
return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
922-
}]>;
923-
924-
// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
925-
let Predicates = [doMulWide] in {
926-
def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
927-
(MULWIDES64Imm $a, (SHL2MUL32 $b))>;
928-
def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
929-
(MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
930-
931-
def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
932-
(MULWIDES32Imm $a, (SHL2MUL16 $b))>;
933-
def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
934-
(MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
935-
936-
// Convert "sign/zero-extend then multiply" to mul.wide.
937-
def : Pat<(mul (sext i32:$a), (sext i32:$b)),
938-
(MULWIDES64 $a, $b)>;
939-
def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
940-
(MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
941-
942-
def : Pat<(mul (zext i32:$a), (zext i32:$b)),
943-
(MULWIDEU64 $a, $b)>;
944-
def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
945-
(MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
946-
947-
def : Pat<(mul (sext i16:$a), (sext i16:$b)),
948-
(MULWIDES32 $a, $b)>;
949-
def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
950-
(MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
951-
952-
def : Pat<(mul (zext i16:$a), (zext i16:$b)),
953-
(MULWIDEU32 $a, $b)>;
954-
def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
955-
(MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
956-
}
957-
958877
//
959878
// Integer multiply-add
960879
//
@@ -990,33 +909,38 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
990909
defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
991910
}
992911

993-
multiclass MAD_WIDE<string PtxSuffix, SDNode Op, ValueType BigVT, NVPTXRegClass BigReg, Operand BigImm, ValueType SmallVT, NVPTXRegClass SmallReg, Operand SmallImm> {
912+
multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> {
994913
def rrr:
995-
BasicNVPTXInst<(outs BigReg:$dst),
996-
(ins SmallReg:$a, SmallReg:$b, BigReg:$c),
914+
BasicNVPTXInst<(outs BigT.RC:$dst),
915+
(ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c),
997916
"mad.wide." # PtxSuffix,
998-
[(set BigVT:$dst, (add (Op SmallVT:$a, SmallVT:$b), BigVT:$c))]>;
917+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>;
999918
def rri:
1000-
BasicNVPTXInst<(outs BigReg:$dst),
1001-
(ins SmallReg:$a, SmallReg:$b, BigImm:$c),
919+
BasicNVPTXInst<(outs BigT.RC:$dst),
920+
(ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c),
1002921
"mad.wide." # PtxSuffix,
1003-
[(set BigVT:$dst, (add (Op SmallVT:$a, SmallVT:$b), imm:$c))]>;
922+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>;
1004923
def rir:
1005-
BasicNVPTXInst<(outs BigReg:$dst),
1006-
(ins SmallReg:$a, SmallImm:$b, BigReg:$c),
924+
BasicNVPTXInst<(outs BigT.RC:$dst),
925+
(ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c),
1007926
"mad.wide." # PtxSuffix,
1008-
[(set BigVT:$dst, (add (Op SmallVT:$a, imm:$b), BigVT:$c))]>;
927+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>;
1009928
def rii:
1010-
BasicNVPTXInst<(outs BigReg:$dst),
1011-
(ins SmallReg:$a, SmallImm:$b, BigImm:$c),
929+
BasicNVPTXInst<(outs BigT.RC:$dst),
930+
(ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c),
1012931
"mad.wide." # PtxSuffix,
1013-
[(set BigVT:$dst, (add (Op SmallVT:$a, imm:$b), imm:$c))]>;
932+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>;
1014933
}
1015934

1016-
defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned, i32, B32, i32imm, i16, B16, i16imm>;
1017-
defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed, i32, B32, i32imm, i16, B16, i16imm>;
1018-
defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned, i64, B64, i64imm, i32, B32, i32imm>;
1019-
defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed, i64, B64, i64imm, i32, B32, i32imm>;
935+
def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>;
936+
def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>;
937+
938+
let Predicates = [hasOptEnabled] in {
939+
defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>;
940+
defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>;
941+
defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>;
942+
defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>;
943+
}
1020944

1021945
foreach t = [I16RT, I32RT, I64RT] in {
1022946
def NEG_S # t.Size :

0 commit comments

Comments
 (0)