Skip to content

Commit 4db1a8c

Browse files
committed
more cleanup
1 parent 0a9e248 commit 4db1a8c

File tree

3 files changed

+864
-2141
lines changed

3 files changed

+864
-2141
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 68 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161161

162162
def True : Predicate<"true">;
163-
def False : Predicate<"false">;
164163

165164
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
166165
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -965,31 +964,17 @@ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
965964
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
966965

967966
// Matchers for signed, unsigned mul.wide ISD nodes.
968-
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)),
969-
(MULWIDES32 $a, $b)>,
970-
Requires<[doMulWide]>;
971-
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)),
972-
(MULWIDES32Imm $a, imm:$b)>,
973-
Requires<[doMulWide]>;
974-
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)),
975-
(MULWIDEU32 $a, $b)>,
976-
Requires<[doMulWide]>;
977-
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)),
978-
(MULWIDEU32Imm $a, imm:$b)>,
979-
Requires<[doMulWide]>;
967+
let Predicates = [doMulWide] in {
968+
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
969+
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
970+
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
971+
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
980972

981-
def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)),
982-
(MULWIDES64 $a, $b)>,
983-
Requires<[doMulWide]>;
984-
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)),
985-
(MULWIDES64Imm $a, imm:$b)>,
986-
Requires<[doMulWide]>;
987-
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)),
988-
(MULWIDEU64 $a, $b)>,
989-
Requires<[doMulWide]>;
990-
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)),
991-
(MULWIDEU64Imm $a, imm:$b)>,
992-
Requires<[doMulWide]>;
973+
def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
974+
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
975+
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
976+
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
977+
}
993978

994979
// Predicates used for converting some patterns to mul.wide.
995980
def SInt32Const : PatLeaf<(imm), [{
@@ -1115,18 +1100,12 @@ defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
11151100
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
11161101
}
11171102

1118-
def INEG16 :
1119-
BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
1120-
"neg.s16",
1121-
[(set i16:$dst, (ineg i16:$src))]>;
1122-
def INEG32 :
1123-
BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src),
1124-
"neg.s32",
1125-
[(set i32:$dst, (ineg i32:$src))]>;
1126-
def INEG64 :
1127-
BasicNVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
1128-
"neg.s64",
1129-
[(set i64:$dst, (ineg i64:$src))]>;
1103+
foreach t = [I16RT, I32RT, I64RT] in {
1104+
def NEG_S # t.Size :
1105+
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
1106+
"neg.s" # t.Size,
1107+
[(set t.Ty:$dst, (ineg t.Ty:$src))]>;
1108+
}
11301109

11311110
//-----------------------------------
11321111
// Floating Point Arithmetic
@@ -2506,24 +2485,20 @@ def : Pat<(f16 (uint_to_fp i32:$a)), (CVT_f16_u32 $a, CvtRN)>;
25062485
def : Pat<(f16 (uint_to_fp i64:$a)), (CVT_f16_u64 $a, CvtRN)>;
25072486

25082487
// sint -> bf16
2509-
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2510-
Requires<[hasPTX<78>, hasSM<90>]>;
2511-
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>,
2512-
Requires<[hasPTX<78>, hasSM<90>]>;
2513-
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>,
2514-
Requires<[hasPTX<78>, hasSM<90>]>;
2515-
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>,
2516-
Requires<[hasPTX<78>, hasSM<90>]>;
2488+
let Predicates = [hasPTX<78>, hasSM<90>] in {
2489+
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2490+
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>;
2491+
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>;
2492+
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>;
2493+
}
25172494

25182495
// uint -> bf16
2519-
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2520-
Requires<[hasPTX<78>, hasSM<90>]>;
2521-
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>,
2522-
Requires<[hasPTX<78>, hasSM<90>]>;
2523-
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>,
2524-
Requires<[hasPTX<78>, hasSM<90>]>;
2525-
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>,
2526-
Requires<[hasPTX<78>, hasSM<90>]>;
2496+
let Predicates = [hasPTX<78>, hasSM<90>] in {
2497+
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2498+
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>;
2499+
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>;
2500+
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>;
2501+
}
25272502

25282503
// sint -> f32
25292504
def : Pat<(f32 (sint_to_fp i1:$a)), (CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
@@ -2574,27 +2549,25 @@ def : Pat<(i16 (fp_to_uint bf16:$a)), (CVT_u16_bf16 $a, CvtRZI)>;
25742549
def : Pat<(i32 (fp_to_uint bf16:$a)), (CVT_u32_bf16 $a, CvtRZI)>;
25752550
def : Pat<(i64 (fp_to_uint bf16:$a)), (CVT_u64_bf16 $a, CvtRZI)>;
25762551
// f32 -> sint
2577-
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2578-
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>,
2579-
Requires<[doF32FTZ]>;
2552+
let Predicates = [doF32FTZ] in {
2553+
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>;
2554+
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>;
2555+
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>;
2556+
}
2557+
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
25802558
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI)>;
2581-
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>,
2582-
Requires<[doF32FTZ]>;
25832559
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI)>;
2584-
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>,
2585-
Requires<[doF32FTZ]>;
25862560
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI)>;
25872561

25882562
// f32 -> uint
2563+
let Predicates = [doF32FTZ] in {
2564+
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>;
2565+
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>;
2566+
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>;
2567+
}
25892568
def : Pat<(i1 (fp_to_uint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2590-
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>,
2591-
Requires<[doF32FTZ]>;
25922569
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI)>;
2593-
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>,
2594-
Requires<[doF32FTZ]>;
25952570
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI)>;
2596-
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>,
2597-
Requires<[doF32FTZ]>;
25982571
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI)>;
25992572

26002573
// f64 -> sint
@@ -2716,28 +2689,24 @@ let hasSideEffects = false in {
27162689

27172690
// PTX 7.1 lets you avoid a temp register and just use _ as a "sink" for the
27182691
// unused high/low part.
2719-
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high),
2720-
(ins Int32Regs:$s),
2721-
"mov.b32 \t{{_, $high}}, $s;",
2722-
[]>, Requires<[hasPTX<71>]>;
2723-
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low),
2724-
(ins Int32Regs:$s),
2725-
"mov.b32 \t{{$low, _}}, $s;",
2726-
[]>, Requires<[hasPTX<71>]>;
2727-
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high),
2728-
(ins Int64Regs:$s),
2729-
"mov.b64 \t{{_, $high}}, $s;",
2730-
[]>, Requires<[hasPTX<71>]>;
2731-
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low),
2732-
(ins Int64Regs:$s),
2733-
"mov.b64 \t{{$low, _}}, $s;",
2734-
[]>, Requires<[hasPTX<71>]>;
2692+
let Predicates = [hasPTX<71>] in {
2693+
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high), (ins Int32Regs:$s),
2694+
"mov.b32 \t{{_, $high}}, $s;", []>;
2695+
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low), (ins Int32Regs:$s),
2696+
"mov.b32 \t{{$low, _}}, $s;", []>;
2697+
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s),
2698+
"mov.b64 \t{{_, $high}}, $s;", []>;
2699+
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low), (ins Int64Regs:$s),
2700+
"mov.b64 \t{{$low, _}}, $s;", []>;
2701+
}
27352702
}
27362703

2737-
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2738-
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2739-
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2740-
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2704+
let Predicates = [hasPTX<71>] in {
2705+
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2706+
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2707+
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2708+
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2709+
}
27412710

27422711
// Fall back to the old way if we don't have PTX 7.1.
27432712
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H $s)>;
@@ -3070,29 +3039,19 @@ def stacksave :
30703039
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
30713040
[SDNPHasChain, SDNPSideEffect]>;
30723041

3073-
def STACKRESTORE_32 :
3074-
BasicNVPTXInst<(outs), (ins Int32Regs:$ptr),
3075-
"stackrestore.u32",
3076-
[(stackrestore i32:$ptr)]>,
3077-
Requires<[hasPTX<73>, hasSM<52>]>;
3078-
3079-
def STACKSAVE_32 :
3080-
BasicNVPTXInst<(outs Int32Regs:$dst), (ins),
3081-
"stacksave.u32",
3082-
[(set i32:$dst, (i32 stacksave))]>,
3083-
Requires<[hasPTX<73>, hasSM<52>]>;
3084-
3085-
def STACKRESTORE_64 :
3086-
BasicNVPTXInst<(outs), (ins Int64Regs:$ptr),
3087-
"stackrestore.u64",
3088-
[(stackrestore i64:$ptr)]>,
3089-
Requires<[hasPTX<73>, hasSM<52>]>;
3090-
3091-
def STACKSAVE_64 :
3092-
BasicNVPTXInst<(outs Int64Regs:$dst), (ins),
3093-
"stacksave.u64",
3094-
[(set i64:$dst, (i64 stacksave))]>,
3095-
Requires<[hasPTX<73>, hasSM<52>]>;
3042+
let Predicates = [hasPTX<73>, hasSM<52>] in {
3043+
foreach t = [I32RT, I64RT] in {
3044+
def STACKRESTORE_ # t.Size :
3045+
BasicNVPTXInst<(outs), (ins t.RC:$ptr),
3046+
"stackrestore.u" # t.Size,
3047+
[(stackrestore t.Ty:$ptr)]>;
3048+
3049+
def STACKSAVE_ # t.Size :
3050+
BasicNVPTXInst<(outs t.RC:$dst), (ins),
3051+
"stacksave.u" # t.Size,
3052+
[(set t.Ty:$dst, (t.Ty stacksave))]>;
3053+
}
3054+
}
30963055

30973056
include "NVPTXIntrinsics.td"
30983057

0 commit comments

Comments
 (0)