Skip to content

Commit f649605

Browse files
authored
[clang] Enable constexpr handling for __builtin_elementwise_fma (#152919)
Fixes #152455.
1 parent 318b0dd commit f649605

File tree

6 files changed

+173
-5
lines changed

6 files changed

+173
-5
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -767,12 +767,12 @@ elementwise to the input.
767767

768768
Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
769769

770-
The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
770+
The elementwise intrinsics ``__builtin_elementwise_popcount``,
771771
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
772772
``__builtin_elementwise_sub_sat``, ``__builtin_elementwise_max``,
773773
``__builtin_elementwise_min``, ``__builtin_elementwise_abs``,
774-
``__builtin_elementwise_ctlz``, and ``__builtin_elementwise_cttz`` can be
775-
called in a ``constexpr`` context.
774+
``__builtin_elementwise_ctlz``, ``__builtin_elementwise_cttz``, and
775+
``__builtin_elementwise_fma`` can be called in a ``constexpr`` context.
776776

777777
No implicit promotion of integer types takes place. The mixing of integer types
778778
of different sizes and signs is forbidden in binary and ternary builtins.
@@ -4389,7 +4389,7 @@ fall into one of the specified floating-point classes.
43894389
43904390
if (__builtin_isfpclass(x, 448)) {
43914391
// `x` is positive finite value
4392-
...
4392+
...
43934393
}
43944394
43954395
**Description**:

clang/include/clang/Basic/Builtins.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
14981498

14991499
def ElementwiseFma : Builtin {
15001500
let Spellings = ["__builtin_elementwise_fma"];
1501-
let Attributes = [NoThrow, Const, CustomTypeChecking];
1501+
let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
15021502
let Prototype = "void(...)";
15031503
}
15041504

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2714,6 +2714,62 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC,
27142714
return true;
27152715
}
27162716

2717+
static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
2718+
const CallExpr *Call) {
2719+
assert(Call->getNumArgs() == 3);
2720+
2721+
FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts());
2722+
llvm::RoundingMode RM = getRoundingMode(FPO);
2723+
const QualType Arg1Type = Call->getArg(0)->getType();
2724+
const QualType Arg2Type = Call->getArg(1)->getType();
2725+
const QualType Arg3Type = Call->getArg(2)->getType();
2726+
2727+
// Non-vector floating point types.
2728+
if (!Arg1Type->isVectorType()) {
2729+
assert(!Arg2Type->isVectorType());
2730+
assert(!Arg3Type->isVectorType());
2731+
2732+
const Floating &Z = S.Stk.pop<Floating>();
2733+
const Floating &Y = S.Stk.pop<Floating>();
2734+
const Floating &X = S.Stk.pop<Floating>();
2735+
APFloat F = X.getAPFloat();
2736+
F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
2737+
Floating Result = S.allocFloat(X.getSemantics());
2738+
Result.copy(F);
2739+
S.Stk.push<Floating>(Result);
2740+
return true;
2741+
}
2742+
2743+
// Vector type.
2744+
assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
2745+
Arg3Type->isVectorType());
2746+
2747+
const VectorType *VecT = Arg1Type->castAs<VectorType>();
2748+
const QualType ElemT = VecT->getElementType();
2749+
unsigned NumElems = VecT->getNumElements();
2750+
2751+
assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
2752+
ElemT == Arg3Type->castAs<VectorType>()->getElementType());
2753+
assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
2754+
NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
2755+
assert(ElemT->isRealFloatingType());
2756+
2757+
const Pointer &VZ = S.Stk.pop<Pointer>();
2758+
const Pointer &VY = S.Stk.pop<Pointer>();
2759+
const Pointer &VX = S.Stk.pop<Pointer>();
2760+
const Pointer &Dst = S.Stk.peek<Pointer>();
2761+
for (unsigned I = 0; I != NumElems; ++I) {
2762+
using T = PrimConv<PT_Float>::T;
2763+
APFloat X = VX.elem<T>(I).getAPFloat();
2764+
APFloat Y = VY.elem<T>(I).getAPFloat();
2765+
APFloat Z = VZ.elem<T>(I).getAPFloat();
2766+
(void)X.fusedMultiplyAdd(Y, Z, RM);
2767+
Dst.elem<Floating>(I) = Floating(X);
2768+
}
2769+
Dst.initializeAllElements();
2770+
return true;
2771+
}
2772+
27172773
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
27182774
uint32_t BuiltinID) {
27192775
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -3145,6 +3201,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
31453201
case clang::X86::BI__builtin_ia32_pmuludq128:
31463202
case clang::X86::BI__builtin_ia32_pmuludq256:
31473203
return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID);
3204+
case Builtin::BI__builtin_elementwise_fma:
3205+
return interp__builtin_elementwise_fma(S, OpPC, Call);
31483206

31493207
default:
31503208
S.FFDiag(S.Current->getLocation(OpPC),

clang/lib/AST/ExprConstant.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11874,6 +11874,28 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1187411874

1187511875
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1187611876
}
11877+
11878+
case Builtin::BI__builtin_elementwise_fma: {
11879+
APValue SourceX, SourceY, SourceZ;
11880+
if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
11881+
!EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
11882+
!EvaluateAsRValue(Info, E->getArg(2), SourceZ))
11883+
return false;
11884+
11885+
unsigned SourceLen = SourceX.getVectorLength();
11886+
SmallVector<APValue> ResultElements;
11887+
ResultElements.reserve(SourceLen);
11888+
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
11889+
for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
11890+
const APFloat &X = SourceX.getVectorElt(EltNum).getFloat();
11891+
const APFloat &Y = SourceY.getVectorElt(EltNum).getFloat();
11892+
const APFloat &Z = SourceZ.getVectorElt(EltNum).getFloat();
11893+
APFloat Result(X);
11894+
(void)Result.fusedMultiplyAdd(Y, Z, RM);
11895+
ResultElements.push_back(APValue(Result));
11896+
}
11897+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
11898+
}
1187711899
}
1187811900
}
1187911901

@@ -16139,6 +16161,21 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
1613916161
Result = minimumnum(Result, RHS);
1614016162
return true;
1614116163
}
16164+
16165+
case Builtin::BI__builtin_elementwise_fma: {
16166+
if (!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() ||
16167+
!E->getArg(2)->isPRValue()) {
16168+
return false;
16169+
}
16170+
APFloat SourceY(0.), SourceZ(0.);
16171+
if (!EvaluateFloat(E->getArg(0), Result, Info) ||
16172+
!EvaluateFloat(E->getArg(1), SourceY, Info) ||
16173+
!EvaluateFloat(E->getArg(2), SourceZ, Info))
16174+
return false;
16175+
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
16176+
(void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
16177+
return true;
16178+
}
1614216179
}
1614316180
}
1614416181

clang/test/CodeGen/rounding-math.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
1111
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
1212
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
1313
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4
14+
15+
void test_builtin_elementwise_fma_round_upward() {
16+
#pragma STDC FENV_ACCESS ON
17+
#pragma STDC FENV_ROUND FE_UPWARD
18+
19+
// CHECK: store float 0x4018000100000000, ptr %f1
20+
// CHECK: store float 0x4018000100000000, ptr %f2
21+
constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
22+
constexpr float f2 = 2.0F * 3.000001F + 0.000001F;
23+
static_assert(f1 == f2);
24+
static_assert(f1 == 6.00000381F);
25+
// CHECK: store double 0x40180000C9539B89, ptr %d1
26+
// CHECK: store double 0x40180000C9539B89, ptr %d2
27+
constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
28+
constexpr double d2 = 2.0 * 3.000001 + 0.000001;
29+
static_assert(d1 == d2);
30+
static_assert(d1 == 6.0000030000000004);
31+
}
32+
33+
void test_builtin_elementwise_fma_round_downward() {
34+
#pragma STDC FENV_ACCESS ON
35+
#pragma STDC FENV_ROUND FE_DOWNWARD
36+
37+
// CHECK: store float 0x40180000C0000000, ptr %f3
38+
// CHECK: store float 0x40180000C0000000, ptr %f4
39+
constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
40+
constexpr float f4 = 2.0F * 3.000001F + 0.000001F;
41+
static_assert(f3 == f4);
42+
// CHECK: store double 0x40180000C9539B87, ptr %d3
43+
// CHECK: store double 0x40180000C9539B87, ptr %d4
44+
constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
45+
constexpr double d4 = 2.0 * 3.000001 + 0.000001;
46+
static_assert(d3 == d4);
47+
}
48+
49+
void test_builtin_elementwise_fma_round_nearest() {
50+
#pragma STDC FENV_ACCESS ON
51+
#pragma STDC FENV_ROUND FE_TONEAREST
52+
53+
// CHECK: store float 0x40180000C0000000, ptr %f5
54+
// CHECK: store float 0x40180000C0000000, ptr %f6
55+
constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
56+
constexpr float f6 = 2.0F * 3.000001F + 0.000001F;
57+
static_assert(f5 == f6);
58+
static_assert(f5 == 6.00000286F);
59+
// CHECK: store double 0x40180000C9539B89, ptr %d5
60+
// CHECK: store double 0x40180000C9539B89, ptr %d6
61+
constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
62+
constexpr double d6 = 2.0 * 3.000001 + 0.000001;
63+
static_assert(d5 == d6);
64+
static_assert(d5 == 6.0000030000000004);
65+
}

clang/test/Sema/constant-builtins-vector.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,3 +936,24 @@ constexpr vector4char ctz1 = __builtin_elementwise_cttz((vector4char){1, 0, 3, 4
936936
// expected-note@-1 {{evaluation of __builtin_elementwise_cttz with a zero value is undefined}}
937937
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){8, 0, 127, 0}, (vector4char){9, -1, 9, -2})) == (LITTLE_END ? 0xFE00FF03 : 0x03FF00FE));
938938
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){0, 0, 0, 0}, (vector4char){0, 0, 0, 0})) == 0);
939+
940+
// Non-vector floating point types.
941+
static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
942+
static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
943+
// Vector type.
944+
constexpr vector4float fmaFloat1 =
945+
__builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
946+
(vector4float){2.0, 3.0, 4.0, 5.0},
947+
(vector4float){3.0, 4.0, 5.0, 6.0});
948+
static_assert(fmaFloat1[0] == 5.0);
949+
static_assert(fmaFloat1[1] == 10.0);
950+
static_assert(fmaFloat1[2] == 17.0);
951+
static_assert(fmaFloat1[3] == 26.0);
952+
constexpr vector4double fmaDouble1 =
953+
__builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
954+
(vector4double){2.0, 3.0, 4.0, 5.0},
955+
(vector4double){3.0, 4.0, 5.0, 6.0});
956+
static_assert(fmaDouble1[0] == 5.0);
957+
static_assert(fmaDouble1[1] == 10.0);
958+
static_assert(fmaDouble1[2] == 17.0);
959+
static_assert(fmaDouble1[3] == 26.0);

0 commit comments

Comments
 (0)