@@ -151,6 +151,8 @@ class OneUse2<SDPatternOperator operator>
151
151
class fpimm_pos_inf<ValueType vt>
152
152
: FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
153
153
154
+ class zeroinitializer<ValueType vt> :
155
+ PatLeaf<(vt (bitconvert (!cast<ValueType>("i" # vt.Size) 0)))>;
154
156
155
157
156
158
// Operands which can hold a Register or an Immediate.
@@ -789,6 +791,23 @@ def UMAX16x2 : I16x2<"max.u", umax>;
789
791
def SMIN16x2 : I16x2<"min.s", smin>;
790
792
def UMIN16x2 : I16x2<"min.u", umin>;
791
793
794
+ let Predicates = [hasPTX<80>, hasSM<90>] in {
795
+
796
+ def MIN_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
797
+ "min.relu.s32",
798
+ [(set i32:$dst, (smax (smin i32:$a, i32:$b), 0))]>;
799
+ def MAX_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
800
+ "max.relu.s32",
801
+ [(set i32:$dst, (smax (smax i32:$a, i32:$b), 0))]>;
802
+ def MIN_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
803
+ "min.relu.s16x2",
804
+ [(set v2i16:$dst, (smax (smin v2i16:$a, v2i16:$b),
805
+ zeroinitializer<v2i16>))]>;
806
+ def MAX_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
807
+ "max.relu.s16x2",
808
+ [(set v2i16:$dst, (smax (smax v2i16:$a, v2i16:$b),
809
+ zeroinitializer<v2i16>))]>;
810
+ }
792
811
793
812
//
794
813
// Wide multiplication
@@ -2379,9 +2398,6 @@ def fpimm_any_zero : FPImmLeaf<fAny, [{
2379
2398
return Imm.isZero();
2380
2399
}]>;
2381
2400
2382
- def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
2383
- def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
2384
-
2385
2401
// Perform substitution if fma only has one use, and also if instruction has
2386
2402
// nnan instruction flag or if the TM has NoNaNsFPMath
2387
2403
def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
@@ -2404,10 +2420,10 @@ class FMARELUInst<RegTyInfo t, bit allow_ftz, PatFrag zero_pat>
2404
2420
2405
2421
let Predicates = [useFP16Math, hasPTX<70>, hasSM<80>] in {
2406
2422
def FMARELU_F16 : FMARELUInst<F16RT, true, fpimm_any_zero>;
2407
- def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, fpimm_positive_zero_v2f16 >;
2423
+ def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, zeroinitializer<v2f16> >;
2408
2424
}
2409
2425
2410
2426
let Predicates = [hasBF16Math, hasPTX<70>, hasSM<80>] in {
2411
2427
def FMARELU_BF16 : FMARELUInst<BF16RT, false, fpimm_any_zero>;
2412
- def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, fpimm_positive_zero_v2bf16 >;
2428
+ def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, zeroinitializer<v2bf16> >;
2413
2429
}
0 commit comments