Skip to content

Conversation

@AlexMaclean
Copy link
Member

Update the lowering rules for sin, cos, and frem to respect the instruction-level flags in addition to the global and function-level options. For sin and cos, the TableGen lowering has been updated to check the afn flag on the node. The lowering for frem has been pulled to custom instruction legalization in order to allow for DAG Combiner optimizations to operate over the expanded instructions.

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Update the lowering rules for sin, cos, and frem to respect the instruction-level flags in addition to the global and function-level options. For sin and cos, the TableGen lowering has been updated to check the afn flag on the node. The lowering for frem has been pulled to custom instruction legalization in order to allow for DAG Combiner optimizations to operate over the expanded instructions.


Patch is 23.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133121.diff

6 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+32)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+21-105)
  • (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/fast-math.ll (+14)
  • (added) llvm/test/CodeGen/NVPTX/frem.ll (+286)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 06e221777b7ea..8a4b83365ae84 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -18,6 +18,7 @@
 #include "NVPTXTargetMachine.h"
 #include "NVPTXTargetObjectFile.h"
 #include "NVPTXUtilities.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -932,6 +933,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(Op, MVT::bf16, Promote);
     AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
+  setOperationAction(ISD::FREM, {MVT::f32, MVT::f64}, Custom);
 
   setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
   if (STI.getPTXVersion() >= 65) {
@@ -2819,6 +2821,34 @@ static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
                      SDLoc(Op), Opcode, DAG);
 }
 
+static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
+                         bool AllowUnsafeFPMath) {
+  // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
+  // i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
+  // the semantics of LLVM's frem.
+  SDLoc DL(Op);
+  SDValue X = Op->getOperand(0);
+  SDValue Y = Op->getOperand(1);
+  EVT Ty = Op.getValueType();
+
+  SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y);
+  SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div);
+  SDValue Mul =
+      DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y, SDNodeFlags::AllowContract);
+  SDValue Sub =
+      DAG.getNode(ISD::FSUB, DL, Ty, X, Mul, SDNodeFlags::AllowContract);
+
+  if (AllowUnsafeFPMath || Op->getFlags().hasNoInfs())
+    return Sub;
+
+  // If Y is infinite, return X
+  SDValue AbsY = DAG.getNode(ISD::FABS, DL, Ty, Y);
+  SDValue Inf =
+      DAG.getConstantFP(APFloat::getInf(Ty.getFltSemantics()), DL, Ty);
+  SDValue IsInf = DAG.getSetCC(DL, MVT::i1, AbsY, Inf, ISD::SETEQ);
+  return DAG.getSelect(DL, Ty, IsInf, X, Sub);
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2913,6 +2943,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::CTPOP:
   case ISD::CTLZ:
     return lowerCTLZCTPOP(Op, DAG);
+  case ISD::FREM:
+    return lowerFREM(Op, DAG, allowUnsafeFPMath(DAG.getMachineFunction()));
 
   default:
     llvm_unreachable("Custom lowering not defined for operation");
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1786503a6dd4e..fe9bb621b481c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
 
 def doMulWide      : Predicate<"doMulWide">;
 
-def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
-def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
-
 def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
 def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
 
@@ -211,6 +208,12 @@ class ValueToRegClass<ValueType T> {
 // Some Common Instruction Class Templates
 //===----------------------------------------------------------------------===//
 
+class OneUse1<SDPatternOperator operator>
+    : PatFrag<(ops node:$A), (operator node:$A), [{ return N->hasOneUse(); }]>;
+
+class fpimm_pos_inf<ValueType vt>
+    : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
+
 // Utility class to wrap up information about a register and DAG type for more
 // convenient iteration and parameterization
 class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
@@ -442,7 +445,7 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
 class BinOpAllowsFMA<SDPatternOperator operator>
     : PatFrag<(ops node:$A, node:$B),
               (operator node:$A, node:$B), [{
-  return allowFMA() || N->getFlags().hasAllowContract();;
+  return allowFMA() || N->getFlags().hasAllowContract();
 }]>;
 
 multiclass F3_fma_component<string op_str, SDNode op_node> {
@@ -693,10 +696,7 @@ let hasSideEffects = false in {
   defm CVT_to_tf32_rz_relu_satf  : CVT_TO_TF32<"rz.relu.satfinite", [hasPTX<86>, hasSM<100>]>;
 }
 
-def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
-  return N->hasOneUse();
-}]>;
-
+def fpround_oneuse : OneUse1<fpround>;
 def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)),
                                 (bf16 (fpround_oneuse f32:$hi)))),
           (CVT_bf16x2_f32 $hi, $lo, CvtRN)>,
@@ -786,18 +786,14 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
 // Test Instructions
 //-----------------------------------
 
+def fabs_oneuse : OneUse1<fabs>;
+
 def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
                              "testp.infinite.f32 \t$p, $a;",
-                             []>;
-def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
-                             "testp.infinite.f32 \t$p, $a;",
-                             []>;
+                             [(set i1:$p, (seteq (fabs_oneuse f32:$a), fpimm_pos_inf<f32>))]>;
 def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
                              "testp.infinite.f64 \t$p, $a;",
-                             []>;
-def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
-                             "testp.infinite.f64 \t$p, $a;",
-                             []>;
+                             [(set i1:$p, (seteq (fabs_oneuse f64:$a), fpimm_pos_inf<f64>))]>;
 
 //-----------------------------------
 // Integer Arithmetic
@@ -1362,99 +1358,19 @@ defm FMA32        : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
 defm FMA64        : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
 
 // sin/cos
+
+class UnaryOpAllowsApproxFn<SDPatternOperator operator>
+    : PatFrag<(ops node:$A),
+              (operator node:$A), [{
+  return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
+}]>;
+
 def SINF:  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
                       "sin.approx.f32 \t$dst, $src;",
-                      [(set f32:$dst, (fsin f32:$src))]>,
-                      Requires<[allowUnsafeFPMath]>;
+                      [(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
 def COSF:  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
                       "cos.approx.f32 \t$dst, $src;",
-                      [(set f32:$dst, (fcos f32:$src))]>,
-                      Requires<[allowUnsafeFPMath]>;
-
-// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
-// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
-// semantics of LLVM's frem.
-
-// frem - f32 FTZ
-def : Pat<(frem f32:$x, f32:$y),
-          (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
-            (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
-             $y))>,
-          Requires<[doF32FTZ, allowUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
-          (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
-            (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
-             fpimm:$y))>,
-          Requires<[doF32FTZ, allowUnsafeFPMath]>;
-
-def : Pat<(frem f32:$x, f32:$y),
-          (SELP_f32rr $x,
-            (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
-              (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
-              $y)),
-            (TESTINF_f32r $y))>,
-          Requires<[doF32FTZ, noUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
-          (SELP_f32rr $x,
-            (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
-              (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
-              fpimm:$y)),
-            (TESTINF_f32i fpimm:$y))>,
-          Requires<[doF32FTZ, noUnsafeFPMath]>;
-
-// frem - f32
-def : Pat<(frem f32:$x, f32:$y),
-          (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
-            (FDIV32rr_prec $x, $y), CvtRZI),
-             $y))>,
-          Requires<[allowUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
-          (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
-            (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
-             fpimm:$y))>,
-          Requires<[allowUnsafeFPMath]>;
-
-def : Pat<(frem f32:$x, f32:$y),
-          (SELP_f32rr $x,
-            (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
-              (FDIV32rr_prec $x, $y), CvtRZI),
-              $y)),
-            (TESTINF_f32r Float32Regs:$y))>,
-          Requires<[noUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
-          (SELP_f32rr $x,
-            (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
-              (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
-              fpimm:$y)),
-            (TESTINF_f32i fpimm:$y))>,
-          Requires<[noUnsafeFPMath]>;
-
-// frem - f64
-def : Pat<(frem f64:$x, f64:$y),
-          (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
-            (FDIV64rr $x, $y), CvtRZI),
-             $y))>,
-          Requires<[allowUnsafeFPMath]>;
-def : Pat<(frem f64:$x, fpimm:$y),
-          (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
-            (FDIV64ri $x, fpimm:$y), CvtRZI),
-             fpimm:$y))>,
-          Requires<[allowUnsafeFPMath]>;
-
-def : Pat<(frem f64:$x, f64:$y),
-          (SELP_f64rr $x,
-            (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
-              (FDIV64rr $x, $y), CvtRZI),
-               $y)),
-            (TESTINF_f64r Float64Regs:$y))>,
-          Requires<[noUnsafeFPMath]>;
-def : Pat<(frem f64:$x, fpimm:$y),
-          (SELP_f64rr $x,
-            (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
-              (FDIV64ri $x, fpimm:$y), CvtRZI),
-              fpimm:$y)),
-            (TESTINF_f64r $y))>,
-          Requires<[noUnsafeFPMath]>;
+                      [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
 
 //-----------------------------------
 // Bitwise operations
diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index 70d1167bbb6e2..b34dfc4e19766 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -200,14 +200,14 @@ define half @test_fdiv(half %a, half %b) #0 {
 ; CHECK-NOFTZ-DAG:  cvt.f32.f16     [[FB:%f[0-9]+]], [[B]];
 ; CHECK-NOFTZ-NEXT: div.rn.f32      [[D:%f[0-9]+]], [[FA]], [[FB]];
 ; CHECK-NOFTZ-NEXT: cvt.rzi.f32.f32 [[DI:%f[0-9]+]], [[D]];
-; CHECK-NOFTZ-NEXT: mul.f32         [[RI:%f[0-9]+]], [[DI]], [[FB]];
-; CHECK-NOFTZ-NEXT: sub.f32         [[RF:%f[0-9]+]], [[FA]], [[RI]];
+; CHECK-NOFTZ-NEXT: neg.f32         [[DNEG:%f[0-9]+]], [[DI]];
+; CHECK-NOFTZ-NEXT: fma.rn.f32      [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
 ; CHECK-F16-FTZ-DAG:  cvt.ftz.f32.f16     [[FA:%f[0-9]+]], [[A]];
 ; CHECK-F16-FTZ-DAG:  cvt.ftz.f32.f16     [[FB:%f[0-9]+]], [[B]];
 ; CHECK-F16-FTZ-NEXT: div.rn.ftz.f32      [[D:%f[0-9]+]], [[FA]], [[FB]];
 ; CHECK-F16-FTZ-NEXT: cvt.rzi.ftz.f32.f32 [[DI:%f[0-9]+]], [[D]];
-; CHECK-F16-FTZ-NEXT: mul.ftz.f32         [[RI:%f[0-9]+]], [[DI]], [[FB]];
-; CHECK-F16-FTZ-NEXT: sub.ftz.f32         [[RF:%f[0-9]+]], [[FA]], [[RI]];
+; CHECK-F16-FTZ-NEXT: neg.ftz.f32         [[DNEG:%f[0-9]+]], [[DI]];
+; CHECK-F16-FTZ-NEXT: fma.rn.ftz.f32      [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
 ; CHECK-NEXT: testp.infinite.f32 [[ISBINF:%p[0-9]+]], [[FB]];
 ; CHECK-NEXT: selp.f32           [[RESULT:%f[0-9]+]], [[FA]], [[RF]], [[ISBINF]];
 ; CHECK-NEXT: cvt.rn.f16.f32     [[R:%rs[0-9]+]], [[RESULT]];
diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
index 539e810c83cbd..d78b68dc501da 100644
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -362,8 +362,8 @@ define <2 x half> @test_frem(<2 x half> %a, <2 x half> %b) #0 {
 ; CHECK-NEXT:    cvt.f32.f16 %f2, %rs4;
 ; CHECK-NEXT:    div.rn.f32 %f3, %f2, %f1;
 ; CHECK-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
-; CHECK-NEXT:    mul.f32 %f5, %f4, %f1;
-; CHECK-NEXT:    sub.f32 %f6, %f2, %f5;
+; CHECK-NEXT:    neg.f32 %f5, %f4;
+; CHECK-NEXT:    fma.rn.f32 %f6, %f5, %f1, %f2;
 ; CHECK-NEXT:    testp.infinite.f32 %p1, %f1;
 ; CHECK-NEXT:    selp.f32 %f7, %f2, %f6, %p1;
 ; CHECK-NEXT:    cvt.rn.f16.f32 %rs5, %f7;
@@ -371,8 +371,8 @@ define <2 x half> @test_frem(<2 x half> %a, <2 x half> %b) #0 {
 ; CHECK-NEXT:    cvt.f32.f16 %f9, %rs3;
 ; CHECK-NEXT:    div.rn.f32 %f10, %f9, %f8;
 ; CHECK-NEXT:    cvt.rzi.f32.f32 %f11, %f10;
-; CHECK-NEXT:    mul.f32 %f12, %f11, %f8;
-; CHECK-NEXT:    sub.f32 %f13, %f9, %f12;
+; CHECK-NEXT:    neg.f32 %f12, %f11;
+; CHECK-NEXT:    fma.rn.f32 %f13, %f12, %f8, %f9;
 ; CHECK-NEXT:    testp.infinite.f32 %p2, %f8;
 ; CHECK-NEXT:    selp.f32 %f14, %f9, %f13, %p2;
 ; CHECK-NEXT:    cvt.rn.f16.f32 %rs6, %f14;
diff --git a/llvm/test/CodeGen/NVPTX/fast-math.ll b/llvm/test/CodeGen/NVPTX/fast-math.ll
index d45ce15298f9d..4cb6a35e796fb 100644
--- a/llvm/test/CodeGen/NVPTX/fast-math.ll
+++ b/llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -131,6 +131,20 @@ define float @fadd_ftz(float %a, float %b) #1 {
 declare float @llvm.sin.f32(float)
 declare float @llvm.cos.f32(float)
 
+; CHECK-LABEL: fsin_approx_afn
+; CHECK:       sin.approx.f32
+define float @fsin_approx_afn(float %a) {
+  %r = tail call afn float @llvm.sin.f32(float %a)
+  ret float %r
+}
+
+; CHECK-LABEL: fcos_approx_afn
+; CHECK:       cos.approx.f32
+define float @fcos_approx_afn(float %a) {
+  %r = tail call afn float @llvm.cos.f32(float %a)
+  ret float %r
+}
+
 ; CHECK-LABEL: fsin_approx
 ; CHECK:       sin.approx.f32
 define float @fsin_approx(float %a) #0 {
diff --git a/llvm/test/CodeGen/NVPTX/frem.ll b/llvm/test/CodeGen/NVPTX/frem.ll
new file mode 100644
index 0000000000000..89e1f2e4c0055
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/frem.ll
@@ -0,0 +1,286 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s --enable-unsafe-fp-math | FileCheck %s --check-prefixes=FAST
+; RUN: llc < %s | FileCheck %s --check-prefixes=NORMAL
+
+
+target triple = "nvptx64-unknown-cuda"
+
+define half @frem_f16(half %a, half %b) {
+; FAST-LABEL: frem_f16(
+; FAST:       {
+; FAST-NEXT:    .reg .b16 %rs<4>;
+; FAST-NEXT:    .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT:  // %bb.0:
+; FAST-NEXT:    ld.param.b16 %rs1, [frem_f16_param_0];
+; FAST-NEXT:    ld.param.b16 %rs2, [frem_f16_param_1];
+; FAST-NEXT:    cvt.f32.f16 %f1, %rs2;
+; FAST-NEXT:    cvt.f32.f16 %f2, %rs1;
+; FAST-NEXT:    div.approx.f32 %f3, %f2, %f1;
+; FAST-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT:    neg.f32 %f5, %f4;
+; FAST-NEXT:    fma.rn.f32 %f6, %f5, %f1, %f2;
+; FAST-NEXT:    cvt.rn.f16.f32 %rs3, %f6;
+; FAST-NEXT:    st.param.b16 [func_retval0], %rs3;
+; FAST-NEXT:    ret;
+;
+; NORMAL-LABEL: frem_f16(
+; NORMAL:       {
+; NORMAL-NEXT:    .reg .pred %p<2>;
+; NORMAL-NEXT:    .reg .b16 %rs<4>;
+; NORMAL-NEXT:    .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT:  // %bb.0:
+; NORMAL-NEXT:    ld.param.b16 %rs1, [frem_f16_param_0];
+; NORMAL-NEXT:    ld.param.b16 %rs2, [frem_f16_param_1];
+; NORMAL-NEXT:    cvt.f32.f16 %f1, %rs2;
+; NORMAL-NEXT:    cvt.f32.f16 %f2, %rs1;
+; NORMAL-NEXT:    div.rn.f32 %f3, %f2, %f1;
+; NORMAL-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT:    neg.f32 %f5, %f4;
+; NORMAL-NEXT:    fma.rn.f32 %f6, %f5, %f1, %f2;
+; NORMAL-NEXT:    testp.infinite.f32 %p1, %f1;
+; NORMAL-NEXT:    selp.f32 %f7, %f2, %f6, %p1;
+; NORMAL-NEXT:    cvt.rn.f16.f32 %rs3, %f7;
+; NORMAL-NEXT:    st.param.b16 [func_retval0], %rs3;
+; NORMAL-NEXT:    ret;
+  %r = frem half %a, %b
+  ret half %r
+}
+
+define float @frem_f32(float %a, float %b) {
+; FAST-LABEL: frem_f32(
+; FAST:       {
+; FAST-NEXT:    .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT:  // %bb.0:
+; FAST-NEXT:    ld.param.f32 %f1, [frem_f32_param_0];
+; FAST-NEXT:    ld.param.f32 %f2, [frem_f32_param_1];
+; FAST-NEXT:    div.approx.f32 %f3, %f1, %f2;
+; FAST-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT:    neg.f32 %f5, %f4;
+; FAST-NEXT:    fma.rn.f32 %f6, %f5, %f2, %f1;
+; FAST-NEXT:    st.param.f32 [func_retval0], %f6;
+; FAST-NEXT:    ret;
+;
+; NORMAL-LABEL: frem_f32(
+; NORMAL:       {
+; NORMAL-NEXT:    .reg .pred %p<2>;
+; NORMAL-NEXT:    .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT:  // %bb.0:
+; NORMAL-NEXT:    ld.param.f32 %f1, [frem_f32_param_0];
+; NORMAL-NEXT:    ld.param.f32 %f2, [frem_f32_param_1];
+; NORMAL-NEXT:    div.rn.f32 %f3, %f1, %f2;
+; NORMAL-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT:    neg.f32 %f5, %f4;
+; NORMAL-NEXT:    fma.rn.f32 %f6, %f5, %f2, %f1;
+; NORMAL-NEXT:    testp.infinite.f32 %p1, %f2;
+; NORMAL-NEXT:    selp.f32 %f7, %f1, %f6, %p1;
+; NORMAL-NEXT:    st.param.f32 [func_retval0], %f7;
+; NORMAL-NEXT:    ret;
+  %r = frem float %a, %b
+  ret float %r
+}
+
+define double @frem_f64(double %a, double %b) {
+; FAST-LABEL: frem_f64(
+; FAST:       {
+; FAST-NEXT:    .reg .f64 %fd<7>;
+; FAST-EMPTY:
+; FAST-NEXT:  // %bb.0:
+; FAST-NEXT:    ld.param.f64 %fd1, [frem_f64_param_0];
+; FAST-NEXT:    ld.param.f64 %fd2, [frem_f64_param_1];
+; FAST-NEXT:    div.rn.f64 %fd3, %fd1, %fd2;
+; FAST-NEXT:    cvt.rzi.f64.f64 %fd4, %fd3;
+; FAST-NEXT:    neg.f64 %fd5, %fd4;
+; FAST-NEXT:    fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
+; FAST-NEXT:    st.param.f64 [func_retval0], %fd6;
+; FAST-NEXT:    ret;
+;
+; NORMAL-LABEL: frem_f64(
+; NORMAL:       {
+; NORMAL-NEXT:    .reg .pred %p<2>;
+; NORMAL-NEXT:    .reg .f64 %fd<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT:  // %bb.0:
+; NORMAL-NEXT:    ld.param.f64 %fd1, [frem_f64_param_0];
+; NORMAL-NEXT:    ld.param.f64 %fd2, [frem_f64_param_1];
+; NORMAL-NEXT:    div.rn.f64 %fd3, %fd1, %fd2;
+; NORMAL-NEXT:    cvt.rzi.f64.f64 %fd4, %fd3;
+; NORMAL-NEXT:    neg.f64 %fd5, %fd4;
+; NORMAL-NEXT:    fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
+; NORMAL-NEXT:    testp.infinite.f64 %p1, %fd2;
+; NORMAL-NEXT:    selp.f64 %fd7, %fd1, %fd6, %p1;
+; NORMAL-NEXT:    st.param.f64 [func_retval0], %fd7;
+; NORMAL-NEXT:    ret;
+  %r = frem double %a, %b
+  ret double %r
+}
+
+define half @frem_f16_ninf(half %a, half %b) {
+; FAST-LABEL: frem_f16_ninf(
+; FAST:       {
+; FAST-NEXT:    .reg .b16 %rs<4>;
+; FAST-NEXT:    .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT:  // %bb.0:
+; FAST-NEXT:    ld.param.b16 %rs1, [frem_f16_ninf_param_0];
+; FAST-NEXT:    ld.param.b16 %rs2, [frem_f16_ninf_param_1];
+; FAST-NEXT:    cvt.f32.f16 %f1, %rs2;
+; FAST-NEXT:    cvt.f32.f16 %f2, %rs1;
+; FAST-NEXT:    div.approx.f32 %f3, %f2, %f1;
+; FAST-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT:    neg.f32 %f5, %f4;
+; FAST-NEXT:    fma.rn.f32 %f6, %f5, %f1, %f2;
+; FAST-NEXT:    cvt.rn.f16.f32 %rs3, %f6;
+; FAST-NEXT:    st.param.b16 [func_retval0], %rs3;
+; FAST-NEXT:    ret;
+;
+; NORMAL-LABEL: frem_f16_ninf(
+; NORMAL:       {
+; NORMAL-NEXT:    .reg .b16 %rs<4>;
+; NORMAL-NEXT:    .reg .f32 %f<7>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT:  // %bb.0:
+; NORMAL-NEXT:    ld.param.b16 %rs1, [frem_f16_ninf_param_0];
+; NORMAL-NEXT:    ld.param.b16 %rs2, [frem_f16_ninf_param_1];
+; NORMAL-NEXT:    cvt.f32.f16 %f1, %rs2;
+; NORMAL-NEXT:    cvt.f32.f16 %f2, %rs1;
+; NORMAL-NEXT:    div.rn.f32 %f3, %f2, %f1;
+; NORMAL-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT:    neg.f32 %f5, %f4;
+; NORMAL-NEXT:    fma.rn.f32 %f6, %f5, %f1, %f2;
+; NORMAL-NEXT:    cvt.rn.f16.f32 %rs3, %f6;
+; NORMAL-NEXT:    st.param.b16 [func_retval0], %rs3;
+; NORMAL-NEXT:    ret;
+  %r = frem ninf half %a, %b
+  ret half %r
+}
+
+define float @frem_f32_ninf(float %a, float %b) {
+; FAST-LABEL: frem_f32_ninf(
+; FAST:       {
+; FAST-NEXT:    .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT:  // %bb.0:
+; FAST-NEXT:    ld.param.f32 %f1, [frem_f32_ninf_param_0];
+; FAST-NEXT:    ld.param.f32 %f2, [frem_f32_ninf_param_1];
+; FAST-NEXT:    div.approx.f32 %f3, %f1, %f2;
+; FAST-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT:    neg.f32 %f5, %f4;
+; FAST-NEXT:    fma.rn.f32 %f6, %f5, %f2, %f1;
+; FAST-NEXT:    st.param.f32 [func_retval0], %f6;
+; FAST-NEXT:    ret;
+;
+; NORMAL-LABEL: frem_f32_ninf(
+; NORMAL:       {
+; NORMAL-NEXT:    .reg .f32 %f<7>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT:  // %bb.0:
+; NORMAL-NEXT:    ld.param.f32 %f1, [frem_f32_ninf_param_0];
+; NORMAL-NEXT:    ld.param.f32 %f2, [frem_f32_ninf_param_1];
+; NORMAL-NEXT:    div.rn.f32 %f3, %f1, %f2;
+; NORMAL-NEXT:    cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT:    neg.f32 %f5, %f4;
+; NORMAL-NEXT:    fma.rn.f32 %f6, %f5, %f2, %f1;
+; NORMAL-NEXT:    st.param.f32 [func_retval0], %f6;
+; NORMAL-NEX...
[truncated]

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a test nit.

@AlexMaclean AlexMaclean merged commit 812e02a into llvm:main Mar 28, 2025
9 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants