@@ -125,8 +125,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
125
125
// Predicate records. Each one wraps a C++ condition string; the two below call
// helper functions by name (useF32FTZ, doRsqrtOpt).
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
126
126
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
127
127
128
- def doMulWide : Predicate<"doMulWide">;
129
-
130
128
// Predicates that query the Subtarget object; noHWROT32 is the exact negation
// of hasHWROT32.
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
131
129
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
132
130
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -836,36 +834,28 @@ def MULWIDES64 :
836
834
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
837
835
// mul.wide.s32 with an immediate second operand: B32 register times an i32
// immediate, producing a B64 result.
def MULWIDES64Imm :
838
836
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
839
- def MULWIDES64Imm64 :
840
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">;
841
837
842
838
// mul.wide.u32: register/register and register/immediate forms, B32 inputs
// with a B64 result.
def MULWIDEU64 :
843
839
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
844
840
def MULWIDEU64Imm :
845
841
BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
846
- def MULWIDEU64Imm64 :
847
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">;
848
842
849
843
// mul.wide.s16: register/register and register/immediate forms, B16 inputs
// with a B32 result.
def MULWIDES32 :
850
844
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
851
845
def MULWIDES32Imm :
852
846
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
853
- def MULWIDES32Imm32 :
854
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">;
855
847
856
848
// mul.wide.u16: register/register and register/immediate forms, B16 inputs
// with a B32 result.
def MULWIDEU32 :
857
849
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
858
850
def MULWIDEU32Imm :
859
851
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
860
- def MULWIDEU32Imm32 :
861
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
862
852
863
- def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
864
- def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
865
- def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
853
// Type profile shared by both mul.wide nodes: one result and two operands.
// The result (index 0) and first operand (index 1) are constrained to integer
// types, and the two operands (indices 1 and 2) must have the same type.
+ def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
854
// Signed and unsigned mul.wide SelectionDAG nodes. Both carry SDNPCommutative,
// so patterns written with one operand order also cover the swapped order.
+ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative] >;
855
+ def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative] >;
866
856
867
857
// Matchers for the signed and unsigned mul.wide ISD nodes.
868
- let Predicates = [doMulWide ] in {
858
+ let Predicates = [hasOptEnabled ] in {
869
859
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
870
860
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
871
861
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
@@ -877,85 +867,6 @@ let Predicates = [doMulWide] in {
877
867
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
878
868
}
879
869
880
- // Predicates used for converting some patterns to mul.wide.
881
- def SInt32Const : PatLeaf<(imm), [{
882
- const APInt &v = N->getAPIntValue();
883
- return v.isSignedIntN(32);
884
- }]>;
885
-
886
- def UInt32Const : PatLeaf<(imm), [{
887
- const APInt &v = N->getAPIntValue();
888
- return v.isIntN(32);
889
- }]>;
890
-
891
- def SInt16Const : PatLeaf<(imm), [{
892
- const APInt &v = N->getAPIntValue();
893
- return v.isSignedIntN(16);
894
- }]>;
895
-
896
- def UInt16Const : PatLeaf<(imm), [{
897
- const APInt &v = N->getAPIntValue();
898
- return v.isIntN(16);
899
- }]>;
900
-
901
- def IntConst_0_30 : PatLeaf<(imm), [{
902
- // Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
903
- const APInt &v = N->getAPIntValue();
904
- return v.sge(0) && v.slt(31);
905
- }]>;
906
-
907
- def IntConst_0_14 : PatLeaf<(imm), [{
908
- // Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
909
- const APInt &v = N->getAPIntValue();
910
- return v.sge(0) && v.slt(15);
911
- }]>;
912
-
913
- def SHL2MUL32 : SDNodeXForm<imm, [{
914
- const APInt &v = N->getAPIntValue();
915
- APInt temp(32, 1);
916
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
917
- }]>;
918
-
919
- def SHL2MUL16 : SDNodeXForm<imm, [{
920
- const APInt &v = N->getAPIntValue();
921
- APInt temp(16, 1);
922
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
923
- }]>;
924
-
925
- // Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
926
- let Predicates = [doMulWide] in {
927
- def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
928
- (MULWIDES64Imm $a, (SHL2MUL32 $b))>;
929
- def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
930
- (MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
931
-
932
- def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
933
- (MULWIDES32Imm $a, (SHL2MUL16 $b))>;
934
- def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
935
- (MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
936
-
937
- // Convert "sign/zero-extend then multiply" to mul.wide.
938
- def : Pat<(mul (sext i32:$a), (sext i32:$b)),
939
- (MULWIDES64 $a, $b)>;
940
- def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
941
- (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
942
-
943
- def : Pat<(mul (zext i32:$a), (zext i32:$b)),
944
- (MULWIDEU64 $a, $b)>;
945
- def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
946
- (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
947
-
948
- def : Pat<(mul (sext i16:$a), (sext i16:$b)),
949
- (MULWIDES32 $a, $b)>;
950
- def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
951
- (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
952
-
953
- def : Pat<(mul (zext i16:$a), (zext i16:$b)),
954
- (MULWIDEU32 $a, $b)>;
955
- def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
956
- (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
957
- }
958
-
959
870
//
960
871
// Integer multiply-add
961
872
//
@@ -991,6 +902,39 @@ defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
991
902
defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
992
903
}
993
904
905
// MAD_WIDE: defines the four operand-kind variants of a "mad.wide.<PtxSuffix>"
// instruction. Every variant selects (add (Op $a, $b), $c), where $a and $b use
// the SmallT type and $c and $dst use the BigT type.
//   PtxSuffix - text appended to "mad.wide." to form the assembly mnemonic.
//   Op        - OneUse2 fragment supplying the multiply node to match.
//   BigT      - RegTyInfo for the accumulator $c and the result $dst.
//   SmallT    - RegTyInfo for the multiplicands $a and $b.
// The def suffix names the kind of $a, $b, $c in order: r = register,
// i = immediate. $a is always a register.
+ multiclass MAD_WIDE<string PtxSuffix, OneUse2 Op, RegTyInfo BigT, RegTyInfo SmallT> {
906
// $b register, $c register.
+ def rrr:
907
+ BasicNVPTXInst<(outs BigT.RC:$dst),
908
+ (ins SmallT.RC:$a, SmallT.RC:$b, BigT.RC:$c),
909
+ "mad.wide." # PtxSuffix,
910
+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), BigT.Ty:$c))]>;
911
// $b register, $c immediate (BigT.Imm).
+ def rri:
912
+ BasicNVPTXInst<(outs BigT.RC:$dst),
913
+ (ins SmallT.RC:$a, SmallT.RC:$b, BigT.Imm:$c),
914
+ "mad.wide." # PtxSuffix,
915
+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, SmallT.Ty:$b), imm:$c))]>;
916
// $b immediate (SmallT.Imm), $c register.
+ def rir:
917
+ BasicNVPTXInst<(outs BigT.RC:$dst),
918
+ (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.RC:$c),
919
+ "mad.wide." # PtxSuffix,
920
+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), BigT.Ty:$c))]>;
921
// $b immediate (SmallT.Imm), $c immediate (BigT.Imm).
+ def rii:
922
+ BasicNVPTXInst<(outs BigT.RC:$dst),
923
+ (ins SmallT.RC:$a, SmallT.Imm:$b, BigT.Imm:$c),
924
+ "mad.wide." # PtxSuffix,
925
+ [(set BigT.Ty:$dst, (add (Op SmallT.Ty:$a, imm:$b), imm:$c))]>;
926
+ }
927
+
928
// OneUse2 wrappers around the two mul.wide nodes; these are what the MAD_WIDE
// instantiations pass as the Op argument.
// NOTE(review): OneUse2 is defined outside this view; presumably it limits the
// match to a multiply whose result has a single use -- confirm.
+ def mul_wide_unsigned_oneuse : OneUse2<mul_wide_unsigned>;
929
+ def mul_wide_signed_oneuse : OneUse2<mul_wide_signed>;
930
+
931
// Instantiate MAD_WIDE for each signedness and width, all gated on the
// hasOptEnabled predicate. The 16-bit forms multiply I16RT values into an I32RT
// result; the 32-bit forms multiply I32RT values into an I64RT result.
+ let Predicates = [hasOptEnabled] in {
932
+ defm MAD_WIDE_U16 : MAD_WIDE<"u16", mul_wide_unsigned_oneuse, I32RT, I16RT>;
933
+ defm MAD_WIDE_S16 : MAD_WIDE<"s16", mul_wide_signed_oneuse, I32RT, I16RT>;
934
+ defm MAD_WIDE_U32 : MAD_WIDE<"u32", mul_wide_unsigned_oneuse, I64RT, I32RT>;
935
+ defm MAD_WIDE_S32 : MAD_WIDE<"s32", mul_wide_signed_oneuse, I64RT, I32RT>;
936
+ }
937
+
994
938
foreach t = [I16RT, I32RT, I64RT] in {
995
939
def NEG_S # t.Size :
996
940
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
0 commit comments