Skip to content

Commit 5678aef

Browse files
authored
[NVPTX] Add support for integer min/max ReLU idiom (#151727)
1 parent 203b35d commit 5678aef

File tree

2 files changed

+730
-109
lines changed

2 files changed

+730
-109
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class OneUse2<SDPatternOperator operator>
151151
class fpimm_pos_inf<ValueType vt>
152152
: FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
153153

154+
class zeroinitializer<ValueType vt> :
155+
PatLeaf<(vt (bitconvert (!cast<ValueType>("i" # vt.Size) 0)))>;
154156

155157

156158
// Operands which can hold a Register or an Immediate.
@@ -789,6 +791,23 @@ def UMAX16x2 : I16x2<"max.u", umax>;
789791
def SMIN16x2 : I16x2<"min.s", smin>;
790792
def UMIN16x2 : I16x2<"min.u", umin>;
791793

794+
let Predicates = [hasPTX<80>, hasSM<90>] in {
795+
796+
def MIN_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
797+
"min.relu.s32",
798+
[(set i32:$dst, (smax (smin i32:$a, i32:$b), 0))]>;
799+
def MAX_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
800+
"max.relu.s32",
801+
[(set i32:$dst, (smax (smax i32:$a, i32:$b), 0))]>;
802+
def MIN_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
803+
"min.relu.s16x2",
804+
[(set v2i16:$dst, (smax (smin v2i16:$a, v2i16:$b),
805+
zeroinitializer<v2i16>))]>;
806+
def MAX_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
807+
"max.relu.s16x2",
808+
[(set v2i16:$dst, (smax (smax v2i16:$a, v2i16:$b),
809+
zeroinitializer<v2i16>))]>;
810+
}
792811

793812
//
794813
// Wide multiplication
@@ -2379,9 +2398,6 @@ def fpimm_any_zero : FPImmLeaf<fAny, [{
23792398
return Imm.isZero();
23802399
}]>;
23812400

2382-
def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
2383-
def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
2384-
23852401
// Perform substitution if fma only has one use, and also if instruction has
23862402
// nnan instruction flag or if the TM has NoNaNsFPMath
23872403
def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
@@ -2404,10 +2420,10 @@ class FMARELUInst<RegTyInfo t, bit allow_ftz, PatFrag zero_pat>
24042420

24052421
let Predicates = [useFP16Math, hasPTX<70>, hasSM<80>] in {
24062422
def FMARELU_F16 : FMARELUInst<F16RT, true, fpimm_any_zero>;
2407-
def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, fpimm_positive_zero_v2f16>;
2423+
def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, zeroinitializer<v2f16>>;
24082424
}
24092425

24102426
let Predicates = [hasBF16Math, hasPTX<70>, hasSM<80>] in {
24112427
def FMARELU_BF16 : FMARELUInst<BF16RT, false, fpimm_any_zero>;
2412-
def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, fpimm_positive_zero_v2bf16>;
2428+
def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, zeroinitializer<v2bf16>>;
24132429
}

0 commit comments

Comments
 (0)