@@ -160,7 +160,6 @@ def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161161
162162def True : Predicate<"true">;
163- def False : Predicate<"false">;
164163
165164class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
166165class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -257,6 +256,11 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
257256// "prmt.b32${mode}">;
258257// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
259258//
259+ // * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
260+ // (ins ADDR:$addr),
261+ // "mbarrier.arrive.b64">;
262+ // ---> "mbarrier.arrive.b64 \t$state, [$addr];"
263+ //
260264class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
261265 list<dag> pattern = []>
262266 : NVPTXInst<
@@ -274,7 +278,11 @@ class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmst
274278 !if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
275279 !interleave(
276280 !foreach(i, !range(!size(ins_dag)),
277- "$" # !getdagname(ins_dag, i)),
281+ !if(!eq(!cast<string>(!getdagarg<DAGOperand>(ins_dag, i)), "ADDR"),
282+ "[$" # !getdagname(ins_dag, i) # "]",
283+ "$" # !getdagname(ins_dag, i)
284+ )
285+ ),
278286 ", "))),
279287 ";"),
280288 pattern>;
@@ -956,31 +964,17 @@ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
956964def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
957965
958966// Matchers for signed, unsigned mul.wide ISD nodes.
959- def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)),
960- (MULWIDES32 $a, $b)>,
961- Requires<[doMulWide]>;
962- def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)),
963- (MULWIDES32Imm $a, imm:$b)>,
964- Requires<[doMulWide]>;
965- def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)),
966- (MULWIDEU32 $a, $b)>,
967- Requires<[doMulWide]>;
968- def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)),
969- (MULWIDEU32Imm $a, imm:$b)>,
970- Requires<[doMulWide]>;
967+ let Predicates = [doMulWide] in {
968+ def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
969+ def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
970+ def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
971+ def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
971972
972- def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)),
973- (MULWIDES64 $a, $b)>,
974- Requires<[doMulWide]>;
975- def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)),
976- (MULWIDES64Imm $a, imm:$b)>,
977- Requires<[doMulWide]>;
978- def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)),
979- (MULWIDEU64 $a, $b)>,
980- Requires<[doMulWide]>;
981- def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)),
982- (MULWIDEU64Imm $a, imm:$b)>,
983- Requires<[doMulWide]>;
973+ def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
974+ def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
975+ def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
976+ def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
977+ }
984978
985979// Predicates used for converting some patterns to mul.wide.
986980def SInt32Const : PatLeaf<(imm), [{
@@ -1106,18 +1100,12 @@ defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
11061100defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
11071101}
11081102
1109- def INEG16 :
1110- BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
1111- "neg.s16",
1112- [(set i16:$dst, (ineg i16:$src))]>;
1113- def INEG32 :
1114- BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src),
1115- "neg.s32",
1116- [(set i32:$dst, (ineg i32:$src))]>;
1117- def INEG64 :
1118- BasicNVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
1119- "neg.s64",
1120- [(set i64:$dst, (ineg i64:$src))]>;
1103+ foreach t = [I16RT, I32RT, I64RT] in {
1104+ def NEG_S # t.Size :
1105+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
1106+ "neg.s" # t.Size,
1107+ [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
1108+ }
11211109
11221110//-----------------------------------
11231111// Floating Point Arithmetic
@@ -1538,7 +1526,7 @@ def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;
15381526
15391527def SDTPRMT :
15401528 SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
1541- SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>, ]>;
1529+ SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
15421530def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
15431531
15441532multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
@@ -1961,15 +1949,15 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
19611949 // f16 -> pred
19621950 def : Pat<(i1 (OpNode f16:$a, f16:$b)),
19631951 (SETP_f16rr $a, $b, ModeFTZ)>,
1964- Requires<[useFP16Math,doF32FTZ]>;
1952+ Requires<[useFP16Math, doF32FTZ]>;
19651953 def : Pat<(i1 (OpNode f16:$a, f16:$b)),
19661954 (SETP_f16rr $a, $b, Mode)>,
19671955 Requires<[useFP16Math]>;
19681956
19691957 // bf16 -> pred
19701958 def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
19711959 (SETP_bf16rr $a, $b, ModeFTZ)>,
1972- Requires<[hasBF16Math,doF32FTZ]>;
1960+ Requires<[hasBF16Math, doF32FTZ]>;
19731961 def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
19741962 (SETP_bf16rr $a, $b, Mode)>,
19751963 Requires<[hasBF16Math]>;
@@ -2497,24 +2485,20 @@ def : Pat<(f16 (uint_to_fp i32:$a)), (CVT_f16_u32 $a, CvtRN)>;
24972485def : Pat<(f16 (uint_to_fp i64:$a)), (CVT_f16_u64 $a, CvtRN)>;
24982486
24992487// sint -> bf16
2500- def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2501- Requires<[hasPTX<78>, hasSM<90>]>;
2502- def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>,
2503- Requires<[hasPTX<78>, hasSM<90>]>;
2504- def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>,
2505- Requires<[hasPTX<78>, hasSM<90>]>;
2506- def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>,
2507- Requires<[hasPTX<78>, hasSM<90>]>;
2488+ let Predicates = [hasPTX<78>, hasSM<90>] in {
2489+ def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2490+ def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>;
2491+ def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>;
2492+ def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>;
2493+ }
25082494
25092495// uint -> bf16
2510- def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2511- Requires<[hasPTX<78>, hasSM<90>]>;
2512- def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>,
2513- Requires<[hasPTX<78>, hasSM<90>]>;
2514- def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>,
2515- Requires<[hasPTX<78>, hasSM<90>]>;
2516- def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>,
2517- Requires<[hasPTX<78>, hasSM<90>]>;
2496+ let Predicates = [hasPTX<78>, hasSM<90>] in {
2497+ def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2498+ def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>;
2499+ def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>;
2500+ def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>;
2501+ }
25182502
25192503// sint -> f32
25202504def : Pat<(f32 (sint_to_fp i1:$a)), (CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
@@ -2565,27 +2549,25 @@ def : Pat<(i16 (fp_to_uint bf16:$a)), (CVT_u16_bf16 $a, CvtRZI)>;
25652549def : Pat<(i32 (fp_to_uint bf16:$a)), (CVT_u32_bf16 $a, CvtRZI)>;
25662550def : Pat<(i64 (fp_to_uint bf16:$a)), (CVT_u64_bf16 $a, CvtRZI)>;
25672551// f32 -> sint
2568- def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2569- def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>,
2570- Requires<[doF32FTZ]>;
2552+ let Predicates = [doF32FTZ] in {
2553+ def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>;
2554+ def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>;
2555+ def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>;
2556+ }
2557+ def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
25712558def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI)>;
2572- def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>,
2573- Requires<[doF32FTZ]>;
25742559def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI)>;
2575- def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>,
2576- Requires<[doF32FTZ]>;
25772560def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI)>;
25782561
25792562// f32 -> uint
2563+ let Predicates = [doF32FTZ] in {
2564+ def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>;
2565+ def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>;
2566+ def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>;
2567+ }
25802568def : Pat<(i1 (fp_to_uint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2581- def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>,
2582- Requires<[doF32FTZ]>;
25832569def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI)>;
2584- def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>,
2585- Requires<[doF32FTZ]>;
25862570def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI)>;
2587- def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>,
2588- Requires<[doF32FTZ]>;
25892571def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI)>;
25902572
25912573// f64 -> sint
@@ -2707,28 +2689,24 @@ let hasSideEffects = false in {
27072689
27082690 // PTX 7.1 lets you avoid a temp register and just use _ as a "sink" for the
27092691 // unused high/low part.
2710- def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high),
2711- (ins Int32Regs:$s),
2712- "mov.b32 \t{{_, $high}}, $s;",
2713- []>, Requires<[hasPTX<71>]>;
2714- def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low),
2715- (ins Int32Regs:$s),
2716- "mov.b32 \t{{$low, _}}, $s;",
2717- []>, Requires<[hasPTX<71>]>;
2718- def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high),
2719- (ins Int64Regs:$s),
2720- "mov.b64 \t{{_, $high}}, $s;",
2721- []>, Requires<[hasPTX<71>]>;
2722- def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low),
2723- (ins Int64Regs:$s),
2724- "mov.b64 \t{{$low, _}}, $s;",
2725- []>, Requires<[hasPTX<71>]>;
2692+ let Predicates = [hasPTX<71>] in {
2693+ def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high), (ins Int32Regs:$s),
2694+ "mov.b32 \t{{_, $high}}, $s;", []>;
2695+ def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low), (ins Int32Regs:$s),
2696+ "mov.b32 \t{{$low, _}}, $s;", []>;
2697+ def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s),
2698+ "mov.b64 \t{{_, $high}}, $s;", []>;
2699+ def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low), (ins Int64Regs:$s),
2700+ "mov.b64 \t{{$low, _}}, $s;", []>;
2701+ }
27262702}
27272703
2728- def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2729- def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2730- def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2731- def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2704+ let Predicates = [hasPTX<71>] in {
2705+ def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2706+ def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2707+ def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2708+ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2709+ }
27322710
27332711// Fall back to the old way if we don't have PTX 7.1.
27342712def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H $s)>;
@@ -3061,29 +3039,19 @@ def stacksave :
30613039 SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
30623040 [SDNPHasChain, SDNPSideEffect]>;
30633041
3064- def STACKRESTORE_32 :
3065- BasicNVPTXInst<(outs), (ins Int32Regs:$ptr),
3066- "stackrestore.u32",
3067- [(stackrestore i32:$ptr)]>,
3068- Requires<[hasPTX<73>, hasSM<52>]>;
3069-
3070- def STACKSAVE_32 :
3071- BasicNVPTXInst<(outs Int32Regs:$dst), (ins),
3072- "stacksave.u32",
3073- [(set i32:$dst, (i32 stacksave))]>,
3074- Requires<[hasPTX<73>, hasSM<52>]>;
3075-
3076- def STACKRESTORE_64 :
3077- BasicNVPTXInst<(outs), (ins Int64Regs:$ptr),
3078- "stackrestore.u64",
3079- [(stackrestore i64:$ptr)]>,
3080- Requires<[hasPTX<73>, hasSM<52>]>;
3081-
3082- def STACKSAVE_64 :
3083- BasicNVPTXInst<(outs Int64Regs:$dst), (ins),
3084- "stacksave.u64",
3085- [(set i64:$dst, (i64 stacksave))]>,
3086- Requires<[hasPTX<73>, hasSM<52>]>;
3042+ let Predicates = [hasPTX<73>, hasSM<52>] in {
3043+ foreach t = [I32RT, I64RT] in {
3044+ def STACKRESTORE_ # t.Size :
3045+ BasicNVPTXInst<(outs), (ins t.RC:$ptr),
3046+ "stackrestore.u" # t.Size,
3047+ [(stackrestore t.Ty:$ptr)]>;
3048+
3049+ def STACKSAVE_ # t.Size :
3050+ BasicNVPTXInst<(outs t.RC:$dst), (ins),
3051+ "stacksave.u" # t.Size,
3052+ [(set t.Ty:$dst, (t.Ty stacksave))]>;
3053+ }
3054+ }
30873055
30883056include "NVPTXIntrinsics.td"
30893057
@@ -3124,7 +3092,7 @@ def : Pat <
31243092////////////////////////////////////////////////////////////////////////////////
31253093
31263094class NVPTXFenceInst<string scope, string sem, Predicate ptx>:
3127- NVPTXInst <(outs), (ins), "fence."#sem#"."#scope#";", [] >,
3095+ BasicNVPTXInst <(outs), (ins), "fence."#sem#"."#scope>,
31283096 Requires<[ptx, hasSM<70>]>;
31293097
31303098foreach scope = ["sys", "gpu", "cluster", "cta"] in {
0 commit comments