-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[clang] Enable constexpr handling for __builtin_elementwise_fma #152919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang Author: Chaitanya Koparkar (ckoparkar) Changes: Fixes #152455. /cc @RKSimon Full diff: https://github.com/llvm/llvm-project/pull/152919.diff 5 Files Affected:
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index b5bb198ca637a..e2aa2ad58a41e 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -757,9 +757,10 @@ elementwise to the input.
Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
-The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
+The elementwise intrinsics ``__builtin_elementwise_popcount``,
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
-``__builtin_elementwise_sub_sat`` can be called in a ``constexpr`` context.
+``__builtin_elementwise_sub_sat``, and ``__builtin_elementwise_fma``
+can be called in a ``constexpr`` context.
No implicit promotion of integer types takes place. The mixing of integer types
of different sizes and signs is forbidden in binary and ternary builtins.
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index c81714e9b009d..0e6a0af34b5da 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
def ElementwiseFma : Builtin {
let Spellings = ["__builtin_elementwise_fma"];
- let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
let Prototype = "void(...)";
}
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index c835bd4fb6088..b530980dd34f8 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -141,6 +141,16 @@ static void diagnoseNonConstexprBuiltin(InterpState &S, CodePtr OpPC,
S.CCEDiag(Loc, diag::note_invalid_subexpr_in_const_expr);
}
+// Same implementation as Compiler::getRoundingMode.
+static llvm::RoundingMode getRoundingMode(const InterpState &S, const Expr *E) {
+ FPOptions FPO = E->getFPFeaturesInEffect(S.Ctx.getLangOpts());
+
+ if (FPO.getRoundingMode() == llvm::RoundingMode::Dynamic)
+ return llvm::RoundingMode::NearestTiesToEven;
+
+ return FPO.getRoundingMode();
+}
+
static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
const InterpFrame *Frame,
const CallExpr *Call) {
@@ -2320,6 +2330,65 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
return true;
}
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+ const CallExpr *Call) {
+ assert(Call->getNumArgs() == 3);
+
+ llvm::RoundingMode RM = getRoundingMode(S, Call);
+
+ const QualType Arg1Type = Call->getArg(0)->getType();
+ const QualType Arg2Type = Call->getArg(1)->getType();
+ const QualType Arg3Type = Call->getArg(2)->getType();
+
+ // Non-vector floating point types.
+ if (!Arg1Type->isVectorType()) {
+ assert(!Arg2Type->isVectorType());
+ assert(!Arg3Type->isVectorType());
+
+ const Floating &Z = S.Stk.pop<Floating>();
+ const Floating &Y = S.Stk.pop<Floating>();
+ const Floating &X = S.Stk.pop<Floating>();
+
+ APFloat F = X.getAPFloat();
+ F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+ Floating Result = S.allocFloat(X.getSemantics());
+ Result.copy(F);
+ S.Stk.push<Floating>(Result);
+ return true;
+ }
+
+ // Vector type.
+ assert(Arg1Type->isVectorType() &&
+ Arg2Type->isVectorType() &&
+ Arg3Type->isVectorType());
+
+ const VectorType *VecT = Arg1Type->castAs<VectorType>();
+ const QualType ElemT = VecT->getElementType();
+ unsigned NumElems = VecT->getNumElements();
+
+ assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
+ ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+ assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
+ NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+ assert(ElemT->isRealFloatingType());
+
+ const Pointer &VZ = S.Stk.pop<Pointer>();
+ const Pointer &VY = S.Stk.pop<Pointer>();
+ const Pointer &VX = S.Stk.pop<Pointer>();
+ const Pointer &Dst = S.Stk.peek<Pointer>();
+
+ for (unsigned I = 0; I != NumElems; ++I) {
+ using T = PrimConv<PT_Float>::T;
+ APFloat X = VX.elem<T>(I).getAPFloat();
+ APFloat Y = VY.elem<T>(I).getAPFloat();
+ APFloat Z = VZ.elem<T>(I).getAPFloat();
+ (void)X.fusedMultiplyAdd(Y, Z, RM);
+ Dst.elem<T>(I) = static_cast<PrimConv<PT_Float>::T>(X);
+ }
+ Dst.initializeAllElements();
+ return true;
+}
+
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
uint32_t BuiltinID) {
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -2727,6 +2796,9 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
case Builtin::BI__builtin_elementwise_sub_sat:
return interp__builtin_elementwise_sat(S, OpPC, Call, BuiltinID);
+ case Builtin::BI__builtin_elementwise_fma:
+ return interp__builtin_elementwise_fma(S, OpPC, Call);
+
default:
S.FFDiag(S.Current->getLocation(OpPC),
diag::note_invalid_subexpr_in_const_expr)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 3679327da7b0c..a7293415af0ce 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11658,6 +11658,29 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
+ case Builtin::BI__builtin_elementwise_fma: {
+ APValue SourceX, SourceY, SourceZ;
+ if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
+ !EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
+ !EvaluateAsRValue(Info, E->getArg(2), SourceZ))
+ return false;
+
+ unsigned SourceLen = SourceX.getVectorLength();
+ SmallVector<APValue> ResultElements;
+ ResultElements.reserve(SourceLen);
+ llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+
+ for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+ APFloat X = SourceX.getVectorElt(EltNum).getFloat();
+ APFloat Y = SourceY.getVectorElt(EltNum).getFloat();
+ APFloat Z = SourceZ.getVectorElt(EltNum).getFloat();
+ APFloat Result(X);
+ (void)Result.fusedMultiplyAdd(Y, Z, RM);
+ ResultElements.push_back(APValue(Result));
+ }
+
+ return Success(APValue(ResultElements.data(), ResultElements.size()), E);
+ }
}
}
@@ -15878,6 +15901,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
Result = minimumnum(Result, RHS);
return true;
}
+
+ case Builtin::BI__builtin_elementwise_fma: {
+ if(!E->getArg(0)->isPRValue() ||
+ !E->getArg(1)->isPRValue() ||
+ !E->getArg(2)->isPRValue()) {
+ return false;
+ }
+ APFloat SourceY(0.), SourceZ(0.);
+ if (!EvaluateFloat(E->getArg(0), Result, Info) ||
+ !EvaluateFloat(E->getArg(1), SourceY, Info) ||
+ !EvaluateFloat(E->getArg(2), SourceZ, Info))
+ return false;
+ llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
+ (void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
+ return true;
+ }
}
}
diff --git a/clang/test/Sema/constant-builtins-vector.cpp b/clang/test/Sema/constant-builtins-vector.cpp
index bde5c478b2b6f..5fa0a7d447ebe 100644
--- a/clang/test/Sema/constant-builtins-vector.cpp
+++ b/clang/test/Sema/constant-builtins-vector.cpp
@@ -860,3 +860,25 @@ static_assert(__builtin_elementwise_sub_sat(0U, 1U) == 0U);
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4char){5, 4, 3, 2}, (vector4char){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304 : 0x04030201));
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4uchar){5, 4, 3, 2}, (vector4uchar){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304U : 0x04030201U));
static_assert(__builtin_bit_cast(unsigned long long, __builtin_elementwise_sub_sat((vector4short){(short)0x8000, (short)0x8001, (short)0x8002, (short)0x8003}, (vector4short){7, 8, 9, 10}) == (LITTLE_END ? 0x8000800080008000 : 0x8000800080008000)));
+
+
+// Non-vector floating point types.
+static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
+static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
+// Vector type.
+constexpr vector4float fmaFloat1 =
+ __builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
+ (vector4float){2.0, 3.0, 4.0, 5.0},
+ (vector4float){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaFloat1[0] == 5.0);
+static_assert(fmaFloat1[1] == 10.0);
+static_assert(fmaFloat1[2] == 17.0);
+static_assert(fmaFloat1[3] == 26.0);
+constexpr vector4double fmaDouble1 =
+ __builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
+ (vector4double){2.0, 3.0, 4.0, 5.0},
+ (vector4double){3.0, 4.0, 5.0, 6.0});
+static_assert(fmaDouble1[0] == 5.0);
+static_assert(fmaDouble1[1] == 10.0);
+static_assert(fmaDouble1[2] == 17.0);
+static_assert(fmaDouble1[3] == 26.0);
|
You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- clang/lib/AST/ByteCode/InterpBuiltin.cpp clang/lib/AST/ExprConstant.cpp clang/test/CodeGen/rounding-math.cpp clang/test/Sema/constant-builtins-vector.cpp
View the diff from clang-format here.
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 4f0263f7e..a019b40c8 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -2353,8 +2353,7 @@ static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
}
// Vector type.
- assert(Arg1Type->isVectorType() &&
- Arg2Type->isVectorType() &&
+ assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
Arg3Type->isVectorType());
const VectorType *VecT = Arg1Type->castAs<VectorType>();
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 79bb9daf7..a33a66f48 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15937,9 +15937,8 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
}
case Builtin::BI__builtin_elementwise_fma: {
- if(!E->getArg(0)->isPRValue() ||
- !E->getArg(1)->isPRValue() ||
- !E->getArg(2)->isPRValue()) {
+ if (!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() ||
+ !E->getArg(2)->isPRValue()) {
return false;
}
APFloat SourceY(0.), SourceZ(0.);
|
if(!E->getArg(0)->isPRValue() || | ||
!E->getArg(1)->isPRValue() || | ||
!E->getArg(2)->isPRValue()) { | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused about this check.
Without it, these statements in `clang/test/CodeGen/builtins-elementwise-math.c` lead to an assertion failure at `E->isPRValue()` in `EvaluateFloat`:
float f2 = __builtin_elementwise_fma(f32, f32, f32);
double d2 = __builtin_elementwise_fma(f64, f64, f64);
v2f64 = __builtin_elementwise_fma(f64, f64, f64);
half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);
When I dump the arguments, I see this; the first argument is not an r-value, leading to the assertion failure:
Arg 0: DeclRefExpr 0x5a72563dffe8 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'
Arg 1: ImplicitCastExpr 0x5a72563e00f0 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0008 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'
Arg 2: ImplicitCastExpr 0x5a72563e0108 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0028 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'
The arguments look different for some other builtins I checked, e.g. for fmax:
/*
Arg 0 : ImplicitCastExpr 0x5ffaa8eed648 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed630 'half':'_Float16' <LValueToRValue>
`-DeclRefExpr 0x5ffaa8eed568 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'
Arg 1: ImplicitCastExpr 0x5ffaa8eed678 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed660 'half':'_Float16' <LValueToRValue>
`-DeclRefExpr 0x5ffaa8eed588 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'
*/
half tmp2_f16 = __builtin_fmax(f16, f16);
I'm not sure what code is responsible for the `LValueToRValue` cast here. Is it a bug that it isn't happening for the first argument of `fma`? Is it related to the way that `APFloat::fusedMultiplyAdd` stores its result? I'd be happy to debug further if someone can point me in the right direction.
// Vector type. | ||
assert(Arg1Type->isVectorType() && | ||
Arg2Type->isVectorType() && | ||
Arg3Type->isVectorType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`clang-format` would rather write this as:
- assert(Arg1Type->isVectorType() &&
- Arg2Type->isVectorType() &&
+ assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
Arg3Type->isVectorType());
I like it less but I can change it if required.
What about integers and integer vectors? |
@tbaederr I believe |
48e3822
to
2413905
Compare
@@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F); | |||
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4 | |||
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4 | |||
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4 | |||
|
|||
void test_builtin_elementwise_fma_round_upward() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@arsenm Tests for different rounding modes. I chose this file rather than `constant-builtins-vector.cpp` because it's already testing rounding modes and already has the appropriate `RUN` command set up.
#pragma STDC FENV_ROUND FE_UPWARD | ||
|
||
// CHECK: store float 0x4018000100000000, ptr %f1 | ||
// CHECK: store float 0x4018000100000000, ptr %f2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if checking for a specific constant is a good idea. But I was using these to sanity check the result, so I put them here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should definitely be testing the specific constant
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 all constexpr testing must test exact fp constant results (this includes +/-0.0)
Fixes #152455.
/cc @RKSimon