Skip to content

Commit 42e0d30

Browse files
[NVPTX] Enhance mul.wide and mad.wide peepholes (#150477)
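Implements `(sign_extend|zero_extend (mul|shl x, y)) -> (mul.wide x, y)` as a DAG combine, applied only when the inner `mul`/`shl` carries the matching no-wrap flag (`nsw` for `sign_extend`, `nuw` for `zero_extend`). Implements `(add (mul.wide a, b), c) -> (mad.wide a, b, c)` in instruction selection.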
1 parent e8e9bef commit 42e0d30

File tree

8 files changed

+1429
-115
lines changed

8 files changed

+1429
-115
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)
5656

5757
NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
5858
CodeGenOptLevel OptLevel)
59-
: SelectionDAGISel(tm, OptLevel), TM(tm) {
60-
doMulWide = (OptLevel > CodeGenOptLevel::None);
61-
}
59+
: SelectionDAGISel(tm, OptLevel), TM(tm) {}
6260

6361
bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
6462
Subtarget = &MF.getSubtarget<NVPTXSubtarget>();

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ struct NVPTXScopes {
4040
class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4141
const NVPTXTargetMachine &TM;
4242

43-
// If true, generate mul.wide from sext and mul
44-
bool doMulWide;
45-
4643
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
4744
bool usePrecSqrtF32(const SDNode *N) const;
4845
bool useF32FTZ() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
843843
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
844844
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
845845
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
846-
ISD::STORE});
846+
ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
847847

848848
// setcc for f16x2 and bf16x2 needs special handling to prevent
849849
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5219,6 +5219,42 @@ static SDValue PerformREMCombine(SDNode *N,
52195219
return SDValue();
52205220
}
52215221

5222+
// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
5223+
static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5224+
CodeGenOptLevel OptLevel) {
5225+
if (OptLevel == CodeGenOptLevel::None)
5226+
return SDValue();
5227+
5228+
SDValue Op = N->getOperand(0);
5229+
if (!Op.hasOneUse())
5230+
return SDValue();
5231+
EVT ToVT = N->getValueType(0);
5232+
EVT FromVT = Op.getValueType();
5233+
if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
5234+
(ToVT == MVT::i64 && FromVT == MVT::i32)))
5235+
return SDValue();
5236+
if (!(Op.getOpcode() == ISD::MUL ||
5237+
(Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1)))))
5238+
return SDValue();
5239+
5240+
SDLoc DL(N);
5241+
unsigned ExtOpcode = N->getOpcode();
5242+
unsigned Opcode = 0;
5243+
if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
5244+
Opcode = NVPTXISD::MUL_WIDE_SIGNED;
5245+
else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
5246+
Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
5247+
else
5248+
return SDValue();
5249+
SDValue RHS = Op.getOperand(1);
5250+
if (Op.getOpcode() == ISD::SHL) {
5251+
const auto ShiftAmt = Op.getConstantOperandVal(1);
5252+
const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt;
5253+
RHS = DCI.DAG.getConstant(MulVal, DL, ToVT);
5254+
}
5255+
return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS);
5256+
}
5257+
52225258
enum OperandSignedness {
52235259
Signed = 0,
52245260
Unsigned,
@@ -5825,6 +5861,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58255861
return combineADDRSPACECAST(N, DCI);
58265862
case ISD::AND:
58275863
return PerformANDCombine(N, DCI);
5864+
case ISD::SIGN_EXTEND:
5865+
case ISD::ZERO_EXTEND:
5866+
return combineMulWide(N, DCI, OptLevel);
58285867
case ISD::BUILD_VECTOR:
58295868
return PerformBUILD_VECTORCombine(N, DCI);
58305869
case ISD::EXTRACT_VECTOR_ELT:

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 37 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
125125
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
126126
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
127127

128-
def doMulWide : Predicate<"doMulWide">;
129-
130128
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
131129
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
132130
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -836,36 +834,28 @@ def MULWIDES64 :
836834
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
837835
def MULWIDES64Imm :
838836
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
839-
def MULWIDES64Imm64 :
840-
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">;
841837

842838
def MULWIDEU64 :
843839
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
844840
def MULWIDEU64Imm :
845841
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
846-
def MULWIDEU64Imm64 :
847-
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">;
848842

849843
def MULWIDES32 :
850844
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
851845
def MULWIDES32Imm :
852846
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
853-
def MULWIDES32Imm32 :
854-
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">;
855847

856848
def MULWIDEU32 :
857849
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
858850
def MULWIDEU32Imm :
859851
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
860-
def MULWIDEU32Imm32 :
861-
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
862852

863-
def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
864-
def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
865-
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
853+
def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
854+
def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
855+
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
866856

867857
// Matchers for signed, unsigned mul.wide ISD nodes.
868-
let Predicates = [doMulWide] in {
858+
let Predicates = [hasOptEnabled] in {
869859
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
870860
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
871861
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
@@ -877,85 +867,6 @@ let Predicates = [doMulWide] in {
877867
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
878868
}
879869

880-
// Predicates used for converting some patterns to mul.wide.
881-
def SInt32Const : PatLeaf<(imm), [{
882-
const APInt &v = N->getAPIntValue();
883-
return v.isSignedIntN(32);
884-
}]>;
885-
886-
def UInt32Const : PatLeaf<(imm), [{
887-
const APInt &v = N->getAPIntValue();
888-
return v.isIntN(32);
889-
}]>;
890-
891-
def SInt16Const : PatLeaf<(imm), [{
892-
const APInt &v = N->getAPIntValue();
893-
return v.isSignedIntN(16);
894-
}]>;
895-
896-
def UInt16Const : PatLeaf<(imm), [{
897-
const APInt &v = N->getAPIntValue();
898-
return v.isIntN(16);
899-
}]>;
900-
901-
def IntConst_0_30 : PatLeaf<(imm), [{
902-
// Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
903-
const APInt &v = N->getAPIntValue();
904-
return v.sge(0) && v.slt(31);
905-
}]>;
906-
907-
def IntConst_0_14 : PatLeaf<(imm), [{
908-
// Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
909-
const APInt &v = N->getAPIntValue();
910-
return v.sge(0) && v.slt(15);
911-
}]>;
912-
913-
def SHL2MUL32 : SDNodeXForm<imm, [{
914-
const APInt &v = N->getAPIntValue();
915-
APInt temp(32, 1);
916-
return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
917-
}]>;
918-
919-
def SHL2MUL16 : SDNodeXForm<imm, [{
920-
const APInt &v = N->getAPIntValue();
921-
APInt temp(16, 1);
922-
return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
923-
}]>;
924-
925-
// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
926-
let Predicates = [doMulWide] in {
927-
def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
928-
(MULWIDES64Imm $a, (SHL2MUL32 $b))>;
929-
def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
930-
(MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
931-
932-
def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
933-
(MULWIDES32Imm $a, (SHL2MUL16 $b))>;
934-
def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
935-
(MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
936-
937-
// Convert "sign/zero-extend then multiply" to mul.wide.
938-
def : Pat<(mul (sext i32:$a), (sext i32:$b)),
939-
(MULWIDES64 $a, $b)>;
940-
def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
941-
(MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
942-
943-
def : Pat<(mul (zext i32:$a), (zext i32:$b)),
944-
(MULWIDEU64 $a, $b)>;
945-
def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
946-
(MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
947-
948-
def : Pat<(mul (sext i16:$a), (sext i16:$b)),
949-
(MULWIDES32 $a, $b)>;
950-
def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
951-
(MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
952-
953-
def : Pat<(mul (zext i16:$a), (zext i16:$b)),
954-
(MULWIDEU32 $a, $b)>;
955-
def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
956-
(MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
957-
}
958-
959870
//
960871
// Integer multiply-add
961872
//
@@ -991,6 +902,39 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
991902
defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
992903
}
993904

905+
// Widening integer multiply-add. Every variant prints as
// "mad.wide.<PtxSuffix>" and is selected for (add (Op a, b), c), where the
// multiplicands $a/$b use the narrow type SmallT and the addend $c and the
// result $dst use the wide type BigT.
//   PtxSuffix - type suffix appended to "mad.wide." in the assembly string.
//   Op        - multiply fragment matched beneath the add (a OneUse2 wrapper;
//               presumably restricts the match to a single-use multiply).
//   BigT      - type info for $c and $dst.
//   SmallT    - type info for $a and $b.
// The def suffix spells the operand kinds of (a, b, c): 'r' is a register,
// 'i' is an immediate. $a is always a register.
multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> {
906+
def rrr:
907+
BasicNVPTXInst<(outs BigT.RC:$dst),
908+
(ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c),
909+
"mad.wide." # PtxSuffix,
910+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>;
911+
def rri:
912+
BasicNVPTXInst<(outs BigT.RC:$dst),
913+
(ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c),
914+
"mad.wide." # PtxSuffix,
915+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>;
916+
def rir:
917+
BasicNVPTXInst<(outs BigT.RC:$dst),
918+
(ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c),
919+
"mad.wide." # PtxSuffix,
920+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>;
921+
def rii:
922+
BasicNVPTXInst<(outs BigT.RC:$dst),
923+
(ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c),
924+
"mad.wide." # PtxSuffix,
925+
[(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>;
926+
}
927+
928+
def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>;
929+
def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>;
930+
931+
// Instantiate mad.wide for the unsigned and signed multiply fragments in two
// widths: I16RT multiplicands with an I32RT addend/result (u16/s16), and
// I32RT multiplicands with an I64RT addend/result (u32/s32). All of these
// definitions are gated on the hasOptEnabled predicate.
let Predicates = [hasOptEnabled] in {
932+
defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>;
933+
defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>;
934+
defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>;
935+
defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>;
936+
}
937+
994938
foreach t = [I16RT, I32RT, I64RT] in {
995939
def NEG_S # t.Size :
996940
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),

llvm/test/CodeGen/NVPTX/bug26185-2.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ define ptx_kernel void @spam(ptr addrspace(1) noalias nocapture readonly %arg, p
1616
; CHECK: .maxntid 1, 1, 1
1717
; CHECK-NEXT: {
1818
; CHECK-NEXT: .reg .b32 %r<2>;
19-
; CHECK-NEXT: .reg .b64 %rd<9>;
19+
; CHECK-NEXT: .reg .b64 %rd<8>;
2020
; CHECK-EMPTY:
2121
; CHECK-NEXT: // %bb.0: // %bb
2222
; CHECK-NEXT: ld.param.b64 %rd1, [spam_param_0];
@@ -25,10 +25,9 @@ define ptx_kernel void @spam(ptr addrspace(1) noalias nocapture readonly %arg, p
2525
; CHECK-NEXT: add.s64 %rd4, %rd1, %rd3;
2626
; CHECK-NEXT: ld.param.b64 %rd5, [spam_param_1];
2727
; CHECK-NEXT: ld.global.nc.s16 %r1, [%rd4+16];
28-
; CHECK-NEXT: mul.wide.s32 %rd6, %r1, %r1;
29-
; CHECK-NEXT: ld.global.b64 %rd7, [%rd5];
30-
; CHECK-NEXT: add.s64 %rd8, %rd6, %rd7;
31-
; CHECK-NEXT: st.global.b64 [%rd5], %rd8;
28+
; CHECK-NEXT: ld.global.b64 %rd6, [%rd5];
29+
; CHECK-NEXT: mad.wide.s32 %rd7, %r1, %r1, %rd6;
30+
; CHECK-NEXT: st.global.b64 [%rd5], %rd7;
3231
; CHECK-NEXT: ret;
3332
bb:
3433
%tmp5 = add nsw i64 %arg3, 8

0 commit comments

Comments
 (0)