Skip to content

Commit 48e3822

Browse files
committed
[clang] Enable constexpr handling for __builtin_elementwise_fma
1 parent dbfc3ed commit 48e3822

File tree

5 files changed

+137
-3
lines changed

5 files changed

+137
-3
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,9 +757,10 @@ elementwise to the input.
757757

758758
Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±infinity
759759

760-
The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
760+
The elementwise intrinsics ``__builtin_elementwise_popcount``,
761761
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
762-
``__builtin_elementwise_sub_sat`` can be called in a ``constexpr`` context.
762+
``__builtin_elementwise_sub_sat``, and ``__builtin_elementwise_fma``
763+
can be called in a ``constexpr`` context.
763764

764765
No implicit promotion of integer types takes place. The mixing of integer types
765766
of different sizes and signs is forbidden in binary and ternary builtins.

clang/include/clang/Basic/Builtins.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {
14981498

14991499
def ElementwiseFma : Builtin {
15001500
let Spellings = ["__builtin_elementwise_fma"];
1501-
let Attributes = [NoThrow, Const, CustomTypeChecking];
1501+
let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
15021502
let Prototype = "void(...)";
15031503
}
15041504

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ static void diagnoseNonConstexprBuiltin(InterpState &S, CodePtr OpPC,
141141
S.CCEDiag(Loc, diag::note_invalid_subexpr_in_const_expr);
142142
}
143143

144+
// Same implementation as Compiler::getRoundingMode.
//
// Returns the rounding mode to use when evaluating the expression E.
// The mode is read from the floating-point options in effect for E, given the
// language options of the interpreter state S. A dynamic rounding mode is
// replaced by round-to-nearest, ties-to-even; any other mode is returned
// unchanged.
145+
static llvm::RoundingMode getRoundingMode(const InterpState &S, const Expr *E) {
146+
FPOptions FPO = E->getFPFeaturesInEffect(S.Ctx.getLangOpts());
147+
148+
// RoundingMode::Dynamic is mapped to NearestTiesToEven here.
if (FPO.getRoundingMode() == llvm::RoundingMode::Dynamic)
149+
return llvm::RoundingMode::NearestTiesToEven;
150+
151+
return FPO.getRoundingMode();
152+
}
153+
144154
static bool interp__builtin_is_constant_evaluated(InterpState &S, CodePtr OpPC,
145155
const InterpFrame *Frame,
146156
const CallExpr *Call) {
@@ -2320,6 +2330,65 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
23202330
return true;
23212331
}
23222332

2333+
// Evaluates a call to __builtin_elementwise_fma(x, y, z) in the bytecode
// interpreter, computing x * y + z as a fused multiply-add using the rounding
// mode obtained from getRoundingMode() for the call.
//
// Two shapes of call are handled, selected by the type of the first argument:
//  - scalar: three Floating operands are popped from the stack and a single
//    Floating result is pushed;
//  - vector: three Pointer operands are popped and the per-element results are
//    stored through the Pointer that remains on top of the stack.
//
// Every path shown returns true. The status returned by
// APFloat::fusedMultiplyAdd is not inspected on either path.
static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
2334+
const CallExpr *Call) {
2335+
assert(Call->getNumArgs() == 3);
2336+
2337+
llvm::RoundingMode RM = getRoundingMode(S, Call);
2338+
2339+
const QualType Arg1Type = Call->getArg(0)->getType();
2340+
const QualType Arg2Type = Call->getArg(1)->getType();
2341+
const QualType Arg3Type = Call->getArg(2)->getType();
2342+
2343+
// Non-vector floating point types.
2344+
if (!Arg1Type->isVectorType()) {
2345+
assert(!Arg2Type->isVectorType());
2346+
assert(!Arg3Type->isVectorType());
2347+
2348+
// Operands are popped last-to-first: Z, then Y, then X.
// NOTE(review): presumably the arguments were pushed in call order, so the
// third argument is on top of the stack -- confirm against the caller.
const Floating &Z = S.Stk.pop<Floating>();
2349+
const Floating &Y = S.Stk.pop<Floating>();
2350+
const Floating &X = S.Stk.pop<Floating>();
2351+
2352+
// Work on a copy of X; fusedMultiplyAdd updates F in place to X * Y + Z.
APFloat F = X.getAPFloat();
2353+
// NOTE(review): the returned status is dropped here without the explicit
// (void) cast that the vector path below uses.
F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
2354+
// The result is allocated with the same floating-point semantics as X.
Floating Result = S.allocFloat(X.getSemantics());
2355+
Result.copy(F);
2356+
S.Stk.push<Floating>(Result);
2357+
return true;
2358+
}
2359+
2360+
// Vector type.
// All three arguments must be vectors with the same element type and the
// same element count, and the element type must be a real floating type.
// These are only checked by assertions.
2361+
assert(Arg1Type->isVectorType() &&
2362+
Arg2Type->isVectorType() &&
2363+
Arg3Type->isVectorType());
2364+
2365+
const VectorType *VecT = Arg1Type->castAs<VectorType>();
2366+
const QualType ElemT = VecT->getElementType();
2367+
unsigned NumElems = VecT->getNumElements();
2368+
2369+
assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
2370+
ElemT == Arg3Type->castAs<VectorType>()->getElementType());
2371+
assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
2372+
NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
2373+
assert(ElemT->isRealFloatingType());
2374+
2375+
// Operand pointers are popped in the same last-to-first order as the scalar
// path. The destination is only peeked, so it stays on the stack.
const Pointer &VZ = S.Stk.pop<Pointer>();
2376+
const Pointer &VY = S.Stk.pop<Pointer>();
2377+
const Pointer &VX = S.Stk.pop<Pointer>();
2378+
const Pointer &Dst = S.Stk.peek<Pointer>();
2379+
2380+
// Compute X[I] * Y[I] + Z[I] for each element and store it in Dst[I].
for (unsigned I = 0; I != NumElems; ++I) {
2381+
using T = PrimConv<PT_Float>::T;
2382+
APFloat X = VX.elem<T>(I).getAPFloat();
2383+
APFloat Y = VY.elem<T>(I).getAPFloat();
2384+
APFloat Z = VZ.elem<T>(I).getAPFloat();
2385+
// X is a local copy and is overwritten with the fused result; the status is
// deliberately discarded.
(void)X.fusedMultiplyAdd(Y, Z, RM);
2386+
Dst.elem<T>(I) = static_cast<PrimConv<PT_Float>::T>(X);
2387+
}
2388+
// NOTE(review): presumably marks every element of Dst as initialized now that
// all of them have been written -- confirm against Pointer's definition.
Dst.initializeAllElements();
2389+
return true;
2390+
}
2391+
23232392
bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
23242393
uint32_t BuiltinID) {
23252394
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
@@ -2727,6 +2796,9 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
27272796
case Builtin::BI__builtin_elementwise_sub_sat:
27282797
return interp__builtin_elementwise_sat(S, OpPC, Call, BuiltinID);
27292798

2799+
case Builtin::BI__builtin_elementwise_fma:
2800+
return interp__builtin_elementwise_fma(S, OpPC, Call);
2801+
27302802
default:
27312803
S.FFDiag(S.Current->getLocation(OpPC),
27322804
diag::note_invalid_subexpr_in_const_expr)

clang/lib/AST/ExprConstant.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11658,6 +11658,29 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1165811658

1165911659
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1166011660
}
11661+
case Builtin::BI__builtin_elementwise_fma: {
11662+
APValue SourceX, SourceY, SourceZ;
11663+
if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
11664+
!EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
11665+
!EvaluateAsRValue(Info, E->getArg(2), SourceZ))
11666+
return false;
11667+
11668+
unsigned SourceLen = SourceX.getVectorLength();
11669+
SmallVector<APValue> ResultElements;
11670+
ResultElements.reserve(SourceLen);
11671+
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
11672+
11673+
for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
11674+
APFloat X = SourceX.getVectorElt(EltNum).getFloat();
11675+
APFloat Y = SourceY.getVectorElt(EltNum).getFloat();
11676+
APFloat Z = SourceZ.getVectorElt(EltNum).getFloat();
11677+
APFloat Result(X);
11678+
(void)Result.fusedMultiplyAdd(Y, Z, RM);
11679+
ResultElements.push_back(APValue(Result));
11680+
}
11681+
11682+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
11683+
}
1166111684
}
1166211685
}
1166311686

@@ -15878,6 +15901,22 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
1587815901
Result = minimumnum(Result, RHS);
1587915902
return true;
1588015903
}
15904+
15905+
case Builtin::BI__builtin_elementwise_fma: {
15906+
if(!E->getArg(0)->isPRValue() ||
15907+
!E->getArg(1)->isPRValue() ||
15908+
!E->getArg(2)->isPRValue()) {
15909+
return false;
15910+
}
15911+
APFloat SourceY(0.), SourceZ(0.);
15912+
if (!EvaluateFloat(E->getArg(0), Result, Info) ||
15913+
!EvaluateFloat(E->getArg(1), SourceY, Info) ||
15914+
!EvaluateFloat(E->getArg(2), SourceZ, Info))
15915+
return false;
15916+
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
15917+
(void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
15918+
return true;
15919+
}
1588115920
}
1588215921
}
1588315922

clang/test/Sema/constant-builtins-vector.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,3 +860,25 @@ static_assert(__builtin_elementwise_sub_sat(0U, 1U) == 0U);
860860
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4char){5, 4, 3, 2}, (vector4char){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304 : 0x04030201));
861861
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_sub_sat((vector4uchar){5, 4, 3, 2}, (vector4uchar){1, 1, 1, 1})) == (LITTLE_END ? 0x01020304U : 0x04030201U));
862862
static_assert(__builtin_bit_cast(unsigned long long, __builtin_elementwise_sub_sat((vector4short){(short)0x8000, (short)0x8001, (short)0x8002, (short)0x8003}, (vector4short){7, 8, 9, 10}) == (LITTLE_END ? 0x8000800080008000 : 0x8000800080008000)));
863+
864+
865+
// Non-vector floating point types.
866+
static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
867+
static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
868+
// Vector type.
869+
constexpr vector4float fmaFloat1 =
870+
__builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
871+
(vector4float){2.0, 3.0, 4.0, 5.0},
872+
(vector4float){3.0, 4.0, 5.0, 6.0});
873+
static_assert(fmaFloat1[0] == 5.0);
874+
static_assert(fmaFloat1[1] == 10.0);
875+
static_assert(fmaFloat1[2] == 17.0);
876+
static_assert(fmaFloat1[3] == 26.0);
877+
constexpr vector4double fmaDouble1 =
878+
__builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
879+
(vector4double){2.0, 3.0, 4.0, 5.0},
880+
(vector4double){3.0, 4.0, 5.0, 6.0});
881+
static_assert(fmaDouble1[0] == 5.0);
882+
static_assert(fmaDouble1[1] == 10.0);
883+
static_assert(fmaDouble1[2] == 17.0);
884+
static_assert(fmaDouble1[3] == 26.0);

0 commit comments

Comments
 (0)