Skip to content

[clang] Enable constexpr handling for __builtin_elementwise_fma #152919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

ckoparkar
Copy link
Contributor

Fixes #152455.

/cc @RKSimon

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:bytecode Issues for the clang bytecode constexpr interpreter labels Aug 10, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 10, 2025

@llvm/pr-subscribers-clang

Author: Chaitanya Koparkar (ckoparkar)

Changes

Fixes #152455.

/cc @RKSimon


Full diff: https://github.com/llvm/llvm-project/pull/152919.diff

5 Files Affected:

  • (modified) clang/docs/LanguageExtensions.rst (+3-2)
  • (modified) clang/include/clang/Basic/Builtins.td (+1-1)
  • (modified) clang/lib/AST/ByteCode/InterpBuiltin.cpp (+72)
  • (modified) clang/lib/AST/ExprConstant.cpp (+39)
  • (modified) clang/test/Sema/constant-builtins-vector.cpp (+22)
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index b5bb198ca637a..e2aa2ad58a41e 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -757,9 +757,10 @@ elementwise to the input.
 
 Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
 
-The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
+The elementwise intrinsics ``__builtin_elementwise_popcount``,
 ``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
-``__builtin_elementwise_sub_sat`` can be called in a ``constexpr`` context.
+``__builtin_elementwise_sub_sat``, and ``__builtin_elementwise_fma``
+can be called in a ``constexpr`` context.
 
 No implicit promotion of integer types takes place. The mixing of integer types
 of different sizes and signs is forbidden in binary and ternary builtins.
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index c81714e9b009d..0e6a0af34b5da 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
 
 def ElementwiseFma : Builtin {
   let Spellings = ["__builtin_elementwise_fma"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
   let Prototype = "void(...)";
 }
 
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index c835bd4fb6088..b530980dd34f8 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -141,6 +141,16 @@ static void diagnoseNonConstexprBuiltin(InterpState &S, CodePtr OpPC,
     S.CCEDiag(Loc, diag::note_invalid_subexpr_in_const_expr);
 }
 
+// Same implementation as Compiler::getRoundingMode.
+static llvm::RoundingMode getRoundingMode(const InterpState &S, const Expr *E) {
+  FPOptions FPO = E->getFPFeaturesInEffect(S.Ctx.getLangOpts());
+
+  if (FPO.getRoundingMode() == llvm::RoundingMode::Dynamic)
+    return llvm::RoundingMode::NearestTiesToEven;
+
+  return FPO.getRoundingMode();
+}
+
 static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
                                                   const InterpFrame *Frame,
                                                   const CallExpr *Call) {
@@ -2320,6 +2330,65 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+                                            const CallExpr *Call) {
+  assert(Call->getNumArgs() == 3);
+
+  llvm::RoundingMode RM = getRoundingMode(S, Call);
+
+  const QualType Arg1Type = Call->getArg(0)->getType();
+  const QualType Arg2Type = Call->getArg(1)->getType();
+  const QualType Arg3Type = Call->getArg(2)->getType();
+
+  // Non-vector floating point types.
+  if (!Arg1Type->isVectorType()) {
+    assert(!Arg2Type->isVectorType());
+    assert(!Arg3Type->isVectorType());
+
+    const Floating &Z = S.Stk.pop<Floating>();
+    const Floating &Y = S.Stk.pop<Floating>();
+    const Floating &X = S.Stk.pop<Floating>();
+
+    APFloat F = X.getAPFloat();
+    F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+    Floating Result = S.allocFloat(X.getSemantics());
+    Result.copy(F);
+    S.Stk.push<Floating>(Result);
+    return true;
+  }
+
+  // Vector type.
+  assert(Arg1Type->isVectorType() &&
+         Arg2Type->isVectorType() &&
+         Arg3Type->isVectorType());
+
+  const VectorType *VecT = Arg1Type->castAs<VectorType>();
+  const QualType ElemT = VecT->getElementType();
+  unsigned NumElems = VecT->getNumElements();
+
+  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
+         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
+         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+  assert(ElemT->isRealFloatingType());
+
+  const Pointer &VZ = S.Stk.pop<Pointer>();
+  const Pointer &VY = S.Stk.pop<Pointer>();
+  const Pointer &VX = S.Stk.pop<Pointer>();
+  const Pointer &Dst = S.Stk.peek<Pointer>();
+
+  for (unsigned I = 0; I != NumElems; ++I) {
+    using T = PrimConv<PT_Float>::T;
+    APFloat X = VX.elem<T>(I).getAPFloat();
+    APFloat Y = VY.elem<T>(I).getAPFloat();
+    APFloat Z = VZ.elem<T>(I).getAPFloat();
+    (void)X.fusedMultiplyAdd(Y, Z, RM);
+    Dst.elem<T>(I) = static_cast<PrimConv<PT_Float>::T>(X);
+  }
+  Dst.initializeAllElements();
+  return true;
+}
+
 bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
                       uint32_t BuiltinID) {
   if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -2727,6 +2796,9 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
   case Builtin::BI__builtin_elementwise_sub_sat:
     return interp__builtin_elementwise_sat(S, OpPC, Call, BuiltinID);
 
+  case Builtin::BI__builtin_elementwise_fma:
+    return interp__builtin_elementwise_fma(S, OpPC, Call);
+
   default:
     S.FFDiag(S.Current->getLocation(OpPC),
              diag::note_invalid_subexpr_in_const_expr)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 3679327da7b0c..a7293415af0ce 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11658,6 +11658,29 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
   }
+  case Builtin::BI__builtin_elementwise_fma: {
+    APValue SourceX, SourceY, SourceZ;
+    if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
+        !EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
+        !EvaluateAsRValue(Info, E->getArg(2), SourceZ))
+      return false;
+
+    unsigned SourceLen = SourceX.getVectorLength();
+    SmallVector<APValue> ResultElements;
+    ResultElements.reserve(SourceLen);
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+
+    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+      APFloat X = SourceX.getVectorElt(EltNum).getFloat();
+      APFloat Y = SourceY.getVectorElt(EltNum).getFloat();
+      APFloat Z = SourceZ.getVectorElt(EltNum).getFloat();
+      APFloat Result(X);
+      (void)Result.fusedMultiplyAdd(Y, Z, RM);
+      ResultElements.push_back(APValue(Result));
+    }
+
+    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+  }
   }
 }
 
@@ -15878,6 +15901,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
     Result = minimumnum(Result, RHS);
     return true;
   }
+
+  case Builtin::BI__builtin_elementwise_fma: {
+    if(!E->getArg(0)->isPRValue() ||
+       !E->getArg(1)->isPRValue() ||
+       !E->getArg(2)->isPRValue()) {
+      return false;
+    }
+    APFloat SourceY(0.), SourceZ(0.);
+    if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+        !EvaluateFloat(E->getArg(1), SourceY, Info) ||
+        !EvaluateFloat(E->getArg(2), SourceZ, Info))
+      return false;
+    llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+    (void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
+    return true;
+  }
   }
 }
 
diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp
index bde5c478b2b6f..5fa0a7d447ebe 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -860,3 +860,25 @@ static_assert(__builtin_elementwise_sub_sat(0U, 1U) == 0U);
 static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4char){5, 4, 3, 2}, (vector4char){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304 : 0x04030201));
 static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4uchar){5, 4, 3, 2}, (vector4uchar){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304U : 0x04030201U));
 static_assert(__builtin_bit_cast(unsigned long long, __builtin_elementwise_sub_sat((vector4short){(short)0x8000, (short)0x8001, (short)0x8002, (short)0x8003}, (vector4short){7, 8, 9, 10}) == (LITTLE_END ? 0x8000800080008000 : 0x8000800080008000)));
+
+
+// Non-vector floating point types.
+static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
+static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
+// Vector type.
+constexpr vector4float fmaFloat1 =
+  __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
+                            (vector4float){2.0, 3.0, 4.0, 5.0},
+                            (vector4float){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaFloat1[0] == 5.0);
+static_assert(fmaFloat1[1] == 10.0);
+static_assert(fmaFloat1[2] == 17.0);
+static_assert(fmaFloat1[3] == 26.0);
+constexpr vector4double fmaDouble1 =
+  __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
+                            (vector4double){2.0, 3.0, 4.0, 5.0},
+                            (vector4double){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaDouble1[0] == 5.0);
+static_assert(fmaDouble1[1] == 10.0);
+static_assert(fmaDouble1[2] == 17.0);
+static_assert(fmaDouble1[3] == 26.0);

Copy link

github-actions bot commented Aug 10, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- clang/lib/AST/ByteCode/InterpBuiltin.cpp clang/lib/AST/ExprConstant.cpp clang/test/CodeGen/rounding-math.cpp clang/test/Sema/constant-builtins-vector.cpp
View the diff from clang-format here.
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 4f0263f7e..a019b40c8 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2353,8 +2353,7 @@ static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
   }
 
   // Vector type.
-  assert(Arg1Type->isVectorType() &&
-         Arg2Type->isVectorType() &&
+  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
          Arg3Type->isVectorType());
 
   const VectorType *VecT = Arg1Type->castAs<VectorType>();
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 79bb9daf7..a33a66f48 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15937,9 +15937,8 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
   }
 
   case Builtin::BI__builtin_elementwise_fma: {
-    if(!E->getArg(0)->isPRValue() ||
-       !E->getArg(1)->isPRValue() ||
-       !E->getArg(2)->isPRValue()) {
+    if (!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() ||
+        !E->getArg(2)->isPRValue()) {
       return false;
     }
     APFloat SourceY(0.), SourceZ(0.);

if(!E->getArg(0)->isPRValue() ||
!E->getArg(1)->isPRValue() ||
!E->getArg(2)->isPRValue()) {
return false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused about this check.

Without it, these statements in clang/test/CodeGen/builtins-elementwise-math.c lead to an assertion failure at E->isPRValue() in EvaluateFloat:

float f2 = __builtin_elementwise_fma(f32, f32, f32);
double d2 = __builtin_elementwise_fma(f64, f64, f64);
v2f64 = __builtin_elementwise_fma(f64, f64, f64); 
half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);

When I dump the arguments, I see this; the first argument is not an r-value, leading to the assertion failure:

Arg 0: DeclRefExpr 0x5a72563dffe8 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

Arg 1: ImplicitCastExpr 0x5a72563e00f0 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0008 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

Arg 2: ImplicitCastExpr 0x5a72563e0108 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0028 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

The arguments look different for some other builtins I checked, e.g. for fmax:

/*
Arg 0 : ImplicitCastExpr 0x5ffaa8eed648 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed630 'half':'_Float16' <LValueToRValue>
  `-DeclRefExpr 0x5ffaa8eed568 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'

Arg 1: ImplicitCastExpr 0x5ffaa8eed678 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed660 'half':'_Float16' <LValueToRValue>
  `-DeclRefExpr 0x5ffaa8eed588 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'
*/

half tmp2_f16 = __builtin_fmax(f16, f16);

I'm not sure what code is responsible for the LValueToRValue here, and is it a bug that it isn't happening for the first argument of fma? Is it related to way that APFloat::fusedMultiplyAdd stores its result? I'd be happy to debug further if someone can point me the right direction.

// Vector type.
assert(Arg1Type->isVectorType() &&
Arg2Type->isVectorType() &&
Arg3Type->isVectorType());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format would rather write this as:

-  assert(Arg1Type->isVectorType() &&
-         Arg2Type->isVectorType() &&
+  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
          Arg3Type->isVectorType());

I like it less but I can change it if required.

@tbaederr
Copy link
Contributor

What about integers and integer vectors?

@RKSimon RKSimon requested review from arsenm and RKSimon August 10, 2025 18:17
@ckoparkar
Copy link
Contributor Author

ckoparkar commented Aug 10, 2025

What about integers and integer vectors?

@tbaederr I believe __builtin_elementwise_fma only accepts direct floating point numbers and vectors of floating point numbers as arguments, so we don't need to handle integers.

@@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4

void test_builtin_elementwise_fma_round_upward() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm Tests for different rounding modes. I chose this file rather than constant-builtins-vector.cpp because, it's already testing rounding modes and already has the appropriate RUN command set up.

#pragma STDC FENV_ROUND FE_UPWARD

// CHECK: store float 0x4018000100000000, ptr %f1
// CHECK: store float 0x4018000100000000, ptr %f2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if checking for a specific constant is a good idea. But I was using these to sanity check the result, so I put them here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should definitely be testing the specific constant

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 all constexpr testing must test exact fp constant results (this includes +/-0.0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:bytecode Issues for the clang bytecode constexpr interpreter clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Clang] VectorExprEvaluator::VisitCallExpr - add __builtin_elementwise_fma constexpr handling
5 participants