Skip to content

[clang] Enable constexpr handling for __builtin_elementwise_fma #152919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions clang/docs/LanguageExtensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,10 @@ elementwise to the input.

Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity

The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
The elementwise intrinsics ``__builtin_elementwise_popcount``,
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
``__builtin_elementwise_sub_sat`` can be called in a ``constexpr`` context.
``__builtin_elementwise_sub_sat``, and ``__builtin_elementwise_fma``
can be called in a ``constexpr`` context.

No implicit promotion of integer types takes place. The mixing of integer types
of different sizes and signs is forbidden in binary and ternary builtins.
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {

def ElementwiseFma : Builtin {
let Spellings = ["__builtin_elementwise_fma"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
let Prototype = "void(...)";
}

Expand Down
60 changes: 60 additions & 0 deletions clang/lib/AST/ByteCode/InterpBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2320,6 +2320,63 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
return true;
}

static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
const CallExpr *Call) {
assert(Call->getNumArgs() == 3);

FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts());
llvm::RoundingMode RM = getRoundingMode(FPO);
const QualType Arg1Type = Call->getArg(0)->getType();
const QualType Arg2Type = Call->getArg(1)->getType();
const QualType Arg3Type = Call->getArg(2)->getType();

// Non-vector floating point types.
if (!Arg1Type->isVectorType()) {
assert(!Arg2Type->isVectorType());
assert(!Arg3Type->isVectorType());

const Floating &Z = S.Stk.pop<Floating>();
const Floating &Y = S.Stk.pop<Floating>();
const Floating &X = S.Stk.pop<Floating>();
APFloat F = X.getAPFloat();
F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
Floating Result = S.allocFloat(X.getSemantics());
Result.copy(F);
S.Stk.push<Floating>(Result);
return true;
}

// Vector type.
assert(Arg1Type->isVectorType() &&
Arg2Type->isVectorType() &&
Arg3Type->isVectorType());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format would rather write this as:

-  assert(Arg1Type->isVectorType() &&
-         Arg2Type->isVectorType() &&
+  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
          Arg3Type->isVectorType());

I like it less but I can change it if required.


const VectorType *VecT = Arg1Type->castAs<VectorType>();
const QualType ElemT = VecT->getElementType();
unsigned NumElems = VecT->getNumElements();

assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
ElemT == Arg3Type->castAs<VectorType>()->getElementType());
assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
assert(ElemT->isRealFloatingType());

const Pointer &VZ = S.Stk.pop<Pointer>();
const Pointer &VY = S.Stk.pop<Pointer>();
const Pointer &VX = S.Stk.pop<Pointer>();
const Pointer &Dst = S.Stk.peek<Pointer>();
for (unsigned I = 0; I != NumElems; ++I) {
using T = PrimConv<PT_Float>::T;
APFloat X = VX.elem<T>(I).getAPFloat();
APFloat Y = VY.elem<T>(I).getAPFloat();
APFloat Z = VZ.elem<T>(I).getAPFloat();
(void)X.fusedMultiplyAdd(Y, Z, RM);
Dst.elem<Floating>(I) = Floating(X);
}
Dst.initializeAllElements();
return true;
}

bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
uint32_t BuiltinID) {
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
Expand Down Expand Up @@ -2727,6 +2784,9 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
case Builtin::BI__builtin_elementwise_sub_sat:
return interp__builtin_elementwise_sat(S, OpPC, Call, BuiltinID);

case Builtin::BI__builtin_elementwise_fma:
return interp__builtin_elementwise_fma(S, OpPC, Call);

default:
S.FFDiag(S.Current->getLocation(OpPC),
diag::note_invalid_subexpr_in_const_expr)
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11658,6 +11658,29 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {

return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
case Builtin::BI__builtin_elementwise_fma: {
APValue SourceX, SourceY, SourceZ;
if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
!EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
!EvaluateAsRValue(Info, E->getArg(2), SourceZ))
return false;

unsigned SourceLen = SourceX.getVectorLength();
SmallVector<APValue> ResultElements;
ResultElements.reserve(SourceLen);
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);

for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
APFloat X = SourceX.getVectorElt(EltNum).getFloat();
APFloat Y = SourceY.getVectorElt(EltNum).getFloat();
APFloat Z = SourceZ.getVectorElt(EltNum).getFloat();
APFloat Result(X);
(void)Result.fusedMultiplyAdd(Y, Z, RM);
ResultElements.push_back(APValue(Result));
}

return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
}
}

Expand Down Expand Up @@ -15878,6 +15901,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
Result = minimumnum(Result, RHS);
return true;
}

case Builtin::BI__builtin_elementwise_fma: {
if(!E->getArg(0)->isPRValue() ||
!E->getArg(1)->isPRValue() ||
!E->getArg(2)->isPRValue()) {
return false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused about this check.

Without it, these statements in clang/test/CodeGen/builtins-elementwise-math.c lead to an assertion failure at E->isPRValue() in EvaluateFloat:

float f2 = __builtin_elementwise_fma(f32, f32, f32);
double d2 = __builtin_elementwise_fma(f64, f64, f64);
v2f64 = __builtin_elementwise_fma(f64, f64, f64); 
half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);

When I dump the arguments, I see this; the first argument is not an r-value, leading to the assertion failure:

Arg 0: DeclRefExpr 0x5a72563dffe8 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

Arg 1: ImplicitCastExpr 0x5a72563e00f0 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0008 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

Arg 2: ImplicitCastExpr 0x5a72563e0108 'float' <LValueToRValue>
`-DeclRefExpr 0x5a72563e0028 'float' lvalue ParmVar 0x5a72563de468 'f32' 'float'

The arguments look different for some other builtins I checked, e.g. for fmax:

/*
Arg 0 : ImplicitCastExpr 0x5ffaa8eed648 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed630 'half':'_Float16' <LValueToRValue>
  `-DeclRefExpr 0x5ffaa8eed568 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'

Arg 1: ImplicitCastExpr 0x5ffaa8eed678 'double' <FloatingCast>
`-ImplicitCastExpr 0x5ffaa8eed660 'half':'_Float16' <LValueToRValue>
  `-DeclRefExpr 0x5ffaa8eed588 'half':'_Float16' lvalue ParmVar 0x5ffaa8eebde0 'f16' 'half':'_Float16'
*/

half tmp2_f16 = __builtin_fmax(f16, f16);

I'm not sure what code is responsible for the LValueToRValue here, and is it a bug that it isn't happening for the first argument of fma? Is it related to way that APFloat::fusedMultiplyAdd stores its result? I'd be happy to debug further if someone can point me the right direction.

}
APFloat SourceY(0.), SourceZ(0.);
if (!EvaluateFloat(E->getArg(0), Result, Info) ||
!EvaluateFloat(E->getArg(1), SourceY, Info) ||
!EvaluateFloat(E->getArg(2), SourceZ, Info))
return false;
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
(void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
return true;
}
}
}

Expand Down
52 changes: 52 additions & 0 deletions clang/test/CodeGen/rounding-math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4

void test_builtin_elementwise_fma_round_upward() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm Tests for different rounding modes. I chose this file rather than constant-builtins-vector.cpp because, it's already testing rounding modes and already has the appropriate RUN command set up.

#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_UPWARD

// CHECK: store float 0x4018000100000000, ptr %f1
// CHECK: store float 0x4018000100000000, ptr %f2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if checking for a specific constant is a good idea. But I was using these to sanity check the result, so I put them here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should definitely be testing the specific constant

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 all constexpr testing must test exact fp constant results (this includes +/-0.0)

constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f2 = 2.0F * 3.000001F + 0.000001F;
static_assert(f1 == f2);
static_assert(f1 == 6.00000381F);
// CHECK: store double 0x40180000C9539B89, ptr %d1
// CHECK: store double 0x40180000C9539B89, ptr %d2
constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d2 = 2.0 * 3.000001 + 0.000001;
static_assert(d1 == d2);
static_assert(d1 == 6.0000030000000004);
}

void test_builtin_elementwise_fma_round_downward() {
#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_DOWNWARD

// CHECK: store float 0x40180000C0000000, ptr %f3
// CHECK: store float 0x40180000C0000000, ptr %f4
constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f4 = 2.0F * 3.000001F + 0.000001F;
static_assert(f3 == f4);
// CHECK: store double 0x40180000C9539B87, ptr %d3
// CHECK: store double 0x40180000C9539B87, ptr %d4
constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d4 = 2.0 * 3.000001 + 0.000001;
static_assert(d3 == d4);
}

void test_builtin_elementwise_fma_round_nearest() {
#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_TONEAREST

// CHECK: store float 0x40180000C0000000, ptr %f5
// CHECK: store float 0x40180000C0000000, ptr %f6
constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f6 = 2.0F * 3.000001F + 0.000001F;
static_assert(f5 == f6);
static_assert(f5 == 6.00000286F);
// CHECK: store double 0x40180000C9539B89, ptr %d5
// CHECK: store double 0x40180000C9539B89, ptr %d6
constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d6 = 2.0 * 3.000001 + 0.000001;
static_assert(d5 == d6);
static_assert(d5 == 6.0000030000000004);
}
22 changes: 22 additions & 0 deletions clang/test/Sema/constant-builtins-vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,25 @@ static_assert(__builtin_elementwise_sub_sat(0U, 1U) == 0U);
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4char){5, 4, 3, 2}, (vector4char){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304 : 0x04030201));
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4uchar){5, 4, 3, 2}, (vector4uchar){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304U : 0x04030201U));
static_assert(__builtin_bit_cast(unsigned long long, __builtin_elementwise_sub_sat((vector4short){(short)0x8000, (short)0x8001, (short)0x8002, (short)0x8003}, (vector4short){7, 8, 9, 10}) == (LITTLE_END ? 0x8000800080008000 : 0x8000800080008000)));


// Non-vector floating point types.
static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
// Vector type.
constexpr vector4float fmaFloat1 =
__builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
(vector4float){2.0, 3.0, 4.0, 5.0},
(vector4float){3.0, 4.0, 5.0, 6.0});
static_assert(fmaFloat1[0] == 5.0);
static_assert(fmaFloat1[1] == 10.0);
static_assert(fmaFloat1[2] == 17.0);
static_assert(fmaFloat1[3] == 26.0);
constexpr vector4double fmaDouble1 =
__builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
(vector4double){2.0, 3.0, 4.0, 5.0},
(vector4double){3.0, 4.0, 5.0, 6.0});
static_assert(fmaDouble1[0] == 5.0);
static_assert(fmaDouble1[1] == 10.0);
static_assert(fmaDouble1[2] == 17.0);
static_assert(fmaDouble1[3] == 26.0);
Loading