@@ -418,25 +418,13 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
418418 !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
419419 [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
420420 Requires<[useFP16Math, allowFMA]>;
421- def bf16rr_ftz :
422- NVPTXInst<(outs Int16Regs:$dst),
423- (ins Int16Regs:$a, Int16Regs:$b),
424- !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
425- [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
426- Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
// Register-register bf16 variant: two Int16Regs inputs and an Int16Regs result,
// printed as "<OpcStr>.bf16". The pattern selects OpNode applied to two bf16
// operands; the def is gated by Requires<[hasBF16Math, allowFMA]>.
427421 def bf16rr :
428422 NVPTXInst<(outs Int16Regs:$dst),
429423 (ins Int16Regs:$a, Int16Regs:$b),
430424 !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
431425 [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
432426 Requires<[hasBF16Math, allowFMA]>;
433427
434- def bf16x2rr_ftz :
435- NVPTXInst<(outs Int32Regs:$dst),
436- (ins Int32Regs:$a, Int32Regs:$b),
437- !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
438- [(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
439- Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
440428 def bf16x2rr :
441429 NVPTXInst<(outs Int32Regs:$dst),
442430 (ins Int32Regs:$a, Int32Regs:$b),
@@ -1423,9 +1411,7 @@ defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
// FMA instantiations for f16, f16x2, bf16, bf16x2 and f32. Each defm passes
// the PTX opcode string followed by type/register-class arguments and, last,
// a predicate: doF32FTZ for the ".ftz" spellings and True for the plain ones.
// The two lines marked "-" are the bf16 ".ftz" instantiations removed by this
// change, leaving bf16/bf16x2 with only the non-FTZ form.
14231411defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
14241412defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
14251413defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
1426- defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
14271414defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
1428- defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
14291415defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
14301416defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
14311417defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
@@ -3959,3 +3945,54 @@ def atomic_thread_fence_seq_cst_cta :
// Operand-less instruction printed as "fence.acq_rel.cta;". It has no
// selection pattern of its own (empty pattern list) and is gated by
// Requires<[hasPTX<60>, hasSM<70>]>.
39593945def atomic_thread_fence_acq_rel_cta :
39603946 NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
39613947 Requires<[hasPTX<60>, hasSM<70>]>;
3948+
// Matches a floating-point immediate of any FP type (fAny) whose value is
// zero. Imm.isZero() ignores the sign, so both +0.0 and -0.0 are accepted.
3949+ def fpimm_any_zero : FPImmLeaf<fAny, [{
3950+ return Imm.isZero();
3951+ }]>;
3952+
// Operand-less fragments matching a packed zero vector expressed as a
// bitconvert of the i32 constant 0, i.e. every bit of both lanes is clear
// (+0.0 in each lane). One fragment per packed type: v2f16 and v2bf16.
3953+ def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
3954+ def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
3955+
3956+ // Matches an fma node only when it has exactly one use and NaNs may be
3957+ // ignored: the node carries the nnan flag or the TM sets NoNaNsFPMath.
// NOTE(review): the NaN condition is presumably needed because fmaxnum(NaN, 0)
// and the relu form of fma treat a NaN result differently - confirm against
// the PTX ISA description of fma.relu.
3958+ def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
3959+ (fma node:$a, node:$b, node:$c), [{
3960+ return N->hasOneUse() &&
3961+ (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
3962+ }]>;
3963+ // fmaxnum will distinguish between positive and negative zero soon, so this
3964+ // PatFrag only matches an fmaxnum node for which signed zeros may be ignored:
// the node carries the nsz flag or the TM sets NoSignedZerosFPMath.
3965+ def NVPTX_fmaxnum_nsz : PatFrag<(ops node:$a, node:$b),
3966+ (fmaxnum node:$a, node:$b), [{
3967+ return N->getFlags().hasNoSignedZeros() || TM.Options.NoSignedZerosFPMath;
3968+ }]>;
3969+
// Helper class for a three-input, one-output instruction in which all four
// operands use the same register class.
//   RC          - register class of $dst, $a, $b and $c.
//   Instruction - PTX mnemonic; "\t$dst, $a, $b, $c;" is appended to form the
//                 assembly string.
//   Preds       - predicate list forwarded to Requires.
// The pattern list is empty, so instructions derived from this class are only
// selected through separately written Pat definitions.
3970+ class NVPTXInst_rrr<RegisterClass RC, string Instruction, list<Predicate> Preds>
3971+ : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
3972+ !strconcat(Instruction, "\t$dst, $a, $b, $c;"), []>,
3973+ Requires<Preds>;
3974+
// fma.rn.relu instruction definitions. Scalar f16/bf16 forms use Int16Regs and
// packed f16x2/bf16x2 forms use Int32Regs. ".ftz" spellings are defined for
// f16 and f16x2 only; bf16 and bf16x2 have no FTZ variant here. All six require
// hasPTX<70> and hasSM<80>, plus useFP16Math (f16 forms) or hasBF16Math (bf16
// forms).
3975+ def FMARELU_F16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3976+ def FMARELU_F16_FTZ : NVPTXInst_rrr<Int16Regs, "fma.rn.ftz.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3977+ def FMARELU_BF16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.bf16", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3978+ def FMARELU_F16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3979+ def FMARELU_F16X2_FTZ : NVPTXInst_rrr<Int32Regs, "fma.rn.ftz.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3980+ def FMARELU_BF16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.bf16x2", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3981+
// Fold fmaxnum(fma(a, b, c), 0) into a single fma.rn{.ftz}.relu instruction.
// The fma must have one use and be NaN-free (NVPTX_fma_oneuse_and_nnan) and the
// fmaxnum must allow signed zeros to be ignored (NVPTX_fmaxnum_nsz).
//
// A standalone Pat is gated only by its own Requires list; the predicates on
// the instruction it emits are not consulted during selection. Each pattern
// therefore repeats the predicates of the FMARELU_* instruction it produces so
// that fma.relu is never selected for a target that cannot execute it.
//
// The FTZ patterns stay ahead of the otherwise identical non-FTZ ones, as in
// the original ordering, and additionally require doF32FTZ.
3982+ // FTZ
3983+ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3984+ (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
3985+ Requires<[useFP16Math, hasPTX<70>, hasSM<80>, doF32FTZ]>;
3986+ def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3987+ (FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>,
3988+ Requires<[useFP16Math, hasPTX<70>, hasSM<80>, doF32FTZ]>;
3989+
3990+ // NO FTZ
3991+ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3992+ (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>, Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
3993+ def : Pat<(bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3994+ (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>, Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
3995+ def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3996+ (FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>, Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
3997+ def : Pat<(v2bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2bf16)),
3998+ (FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>, Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
0 commit comments