@@ -3917,3 +3917,40 @@ def atomic_thread_fence_seq_cst_cta :
// CTA-scoped acquire-release fence; emitted verbatim as "fence.acq_rel.cta;".
// No operands and no selection pattern are attached to this record.
let Predicates = [hasPTX<60>, hasSM<70>] in
def atomic_thread_fence_acq_rel_cta :
  NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>;
3920+
// Floating-point immediate leaf (any FP type) that matches only when the
// constant compares exactly equal to +0.0 via APFloat::isExactlyValue.
// NOTE(review): isExactlyValue is expected to reject -0.0 - confirm against
// the APFloat documentation if negative zero must also be accepted.
def fpimm0 : FPImmLeaf<fAny, [{ return Imm.isExactlyValue(+0.0); }]>;
3924+
// fma.rn.relu: fused multiply-add (round-to-nearest-even) with a ReLU applied
// to the result. All operands live in Int16Regs; the records carry no inline
// selection pattern and are matched by the explicit Pat<> records below.
//
// PTX grammar (see the PTX ISA half-precision fma section):
//   fma.rnd{.ftz}.relu.f16   d, a, b, c;
//   fma.rnd{.relu}.bf16      d, a, b, c;
// i.e. ".ftz" precedes ".relu", and bf16 has no ".ftz" form at all. For that
// reason there is deliberately no FMARELU_BF16_FTZ record.
def FMARELU_F16 :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>,
  Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_F16_FTZ :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.ftz.relu.f16 \t$dst, $a, $b, $c;", []>,
  Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_BF16 :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>,
  Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;

// Fold fmaxnum(fma(a, b, c), +0.0) into a single fma.rn.relu.
//
// A standalone Pat<> is gated only by its own predicate list, not by the
// predicates of the instruction it produces, so each pattern repeats the
// type-support and hasSM<80> requirements of its target instruction.
// NOTE(review): the fold is restricted to allowUnsafeFPMath, presumably
// because relu and fmaxnum treat NaN results differently - confirm.

// f16, flush-to-zero. Listed before the non-FTZ pattern so it is preferred
// when doF32FTZ holds.
def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[useFP16Math, allowFMA, doF32FTZ, allowUnsafeFPMath,
                hasPTX<70>, hasSM<80>]>;
// f16, no flush-to-zero.
def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[useFP16Math, allowFMA, allowUnsafeFPMath,
                hasPTX<70>, hasSM<80>]>;
// bf16: a single pattern, used regardless of the FTZ setting because PTX
// provides no ".ftz" variant of fma.relu for bf16.
def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath,
                hasPTX<70>, hasSM<80>]>;
0 commit comments