Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions clang/docs/LanguageExtensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -757,12 +757,12 @@ elementwise to the input.

Unless specified otherwise, operation(±0) = ±0 and operation(±infinity) = ±infinity.

The integer elementwise intrinsics, including ``__builtin_elementwise_popcount``,
The elementwise intrinsics ``__builtin_elementwise_popcount``,
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
``__builtin_elementwise_sub_sat``, ``__builtin_elementwise_max``,
``__builtin_elementwise_min``, ``__builtin_elementwise_abs``,
``__builtin_elementwise_ctlz``, and ``__builtin_elementwise_cttz`` can be
called in a ``constexpr`` context.
``__builtin_elementwise_ctlz``, ``__builtin_elementwise_cttz``, and
``__builtin_elementwise_fma`` can be called in a ``constexpr`` context.

No implicit promotion of integer types takes place. The mixing of integer types
of different sizes and signs is forbidden in binary and ternary builtins.
Expand Down Expand Up @@ -4379,7 +4379,7 @@ fall into one of the specified floating-point classes.

if (__builtin_isfpclass(x, 448)) {
// `x` is positive finite value
...
...
}

**Description**:
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ def ElementwiseCopysign : Builtin {

// Declares the builtin spelled __builtin_elementwise_fma (elementwise fused
// multiply-add). The prototype is the variadic placeholder "void(...)"; the
// CustomTypeChecking attribute presumably means argument checking is done by
// hand-written Sema code rather than derived from the prototype.
def ElementwiseFma : Builtin {
let Spellings = ["__builtin_elementwise_fma"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Attributes = [NoThrow, Const, CustomTypeChecking, Constexpr];
let Prototype = "void(...)";
}

Expand Down
58 changes: 58 additions & 0 deletions clang/lib/AST/ByteCode/InterpBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,62 @@ static bool interp__builtin_ia32_pmul(InterpState &S, CodePtr OpPC,
return true;
}

/// Constant-evaluates a call to __builtin_elementwise_fma(x, y, z), i.e. the
/// fused multiply-add x * y + z, in the bytecode interpreter.
///
/// The three operands are popped from the interpreter stack in reverse order
/// (z, then y, then x). For scalar operands the result is pushed back onto the
/// stack as a Floating. For vector operands the result is written lane by lane
/// into the destination pointer that is left on top of the stack.
///
/// The rounding mode is derived from the floating-point options in effect at
/// \p Call.
///
/// \param S    Interpreter state providing the evaluation stack.
/// \param OpPC Current code pointer (not used by this builtin).
/// \param Call The builtin call expression; must have exactly three arguments.
/// \returns true; this function has no failure path.
static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
                                            const CallExpr *Call) {
  assert(Call->getNumArgs() == 3);

  FPOptions FPO = Call->getFPFeaturesInEffect(S.Ctx.getLangOpts());
  llvm::RoundingMode RM = getRoundingMode(FPO);
  const QualType Arg1Type = Call->getArg(0)->getType();
  // The second and third operand types are referenced only from assert(), so
  // mark them [[maybe_unused]] for builds where assertions are compiled out.
  [[maybe_unused]] const QualType Arg2Type = Call->getArg(1)->getType();
  [[maybe_unused]] const QualType Arg3Type = Call->getArg(2)->getType();

  // Non-vector floating point types.
  if (!Arg1Type->isVectorType()) {
    assert(!Arg2Type->isVectorType());
    assert(!Arg3Type->isVectorType());

    // Operands were pushed as x, y, z, so they come off the stack reversed.
    const Floating &Z = S.Stk.pop<Floating>();
    const Floating &Y = S.Stk.pop<Floating>();
    const Floating &X = S.Stk.pop<Floating>();
    APFloat F = X.getAPFloat();
    // F = F * Y + Z. The returned opStatus is deliberately discarded, matching
    // the vector path below.
    (void)F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
    Floating Result = S.allocFloat(X.getSemantics());
    Result.copy(F);
    S.Stk.push<Floating>(Result);
    return true;
  }

  // Vector type.
  assert(Arg1Type->isVectorType() && Arg2Type->isVectorType() &&
         Arg3Type->isVectorType());

  const VectorType *VecT = Arg1Type->castAs<VectorType>();
  // The element type is referenced only from assert().
  [[maybe_unused]] const QualType ElemT = VecT->getElementType();
  unsigned NumElems = VecT->getNumElements();

  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
  assert(ElemT->isRealFloatingType());

  // Same reversed pop order as the scalar case; the destination vector stays
  // on the stack and is only peeked at.
  const Pointer &VZ = S.Stk.pop<Pointer>();
  const Pointer &VY = S.Stk.pop<Pointer>();
  const Pointer &VX = S.Stk.pop<Pointer>();
  const Pointer &Dst = S.Stk.peek<Pointer>();
  for (unsigned I = 0; I != NumElems; ++I) {
    using T = PrimConv<PT_Float>::T;
    APFloat X = VX.elem<T>(I).getAPFloat();
    APFloat Y = VY.elem<T>(I).getAPFloat();
    APFloat Z = VZ.elem<T>(I).getAPFloat();
    // X = X * Y + Z for this lane.
    (void)X.fusedMultiplyAdd(Y, Z, RM);
    Dst.elem<Floating>(I) = Floating(X);
  }
  Dst.initializeAllElements();
  return true;
}

bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
uint32_t BuiltinID) {
if (!S.getASTContext().BuiltinInfo.isConstantEvaluated(BuiltinID))
Expand Down Expand Up @@ -3145,6 +3201,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
case clang::X86::BI__builtin_ia32_pmuludq128:
case clang::X86::BI__builtin_ia32_pmuludq256:
return interp__builtin_ia32_pmul(S, OpPC, Call, BuiltinID);
case Builtin::BI__builtin_elementwise_fma:
return interp__builtin_elementwise_fma(S, OpPC, Call);

default:
S.FFDiag(S.Current->getLocation(OpPC),
Expand Down
37 changes: 37 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11874,6 +11874,28 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {

return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}

case Builtin::BI__builtin_elementwise_fma: {
APValue SourceX, SourceY, SourceZ;
if (!EvaluateAsRValue(Info, E->getArg(0), SourceX) ||
!EvaluateAsRValue(Info, E->getArg(1), SourceY) ||
!EvaluateAsRValue(Info, E->getArg(2), SourceZ))
return false;

unsigned SourceLen = SourceX.getVectorLength();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that all the vectors have the same length, or is this verified before this point? Do we have a test that checks this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already verified; clang/test/Sema/builtins-elementwise-math.c has tests that check various bad inputs.

SmallVector<APValue> ResultElements;
ResultElements.reserve(SourceLen);
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
const APFloat &X = SourceX.getVectorElt(EltNum).getFloat();
const APFloat &Y = SourceY.getVectorElt(EltNum).getFloat();
const APFloat &Z = SourceZ.getVectorElt(EltNum).getFloat();
APFloat Result(X);
(void)Result.fusedMultiplyAdd(Y, Z, RM);
ResultElements.push_back(APValue(Result));
}
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
}
}

Expand Down Expand Up @@ -16139,6 +16161,21 @@ bool FloatExprEvaluator::VisitCallExpr(const CallExpr *E) {
Result = minimumnum(Result, RHS);
return true;
}

case Builtin::BI__builtin_elementwise_fma: {
if (!E->getArg(0)->isPRValue() || !E->getArg(1)->isPRValue() ||
!E->getArg(2)->isPRValue()) {
return false;
}
APFloat SourceY(0.), SourceZ(0.);
if (!EvaluateFloat(E->getArg(0), Result, Info) ||
!EvaluateFloat(E->getArg(1), SourceY, Info) ||
!EvaluateFloat(E->getArg(2), SourceZ, Info))
return false;
llvm::RoundingMode RM = getActiveRoundingMode(getEvalInfo(), E);
(void)Result.fusedMultiplyAdd(SourceY, SourceZ, RM);
return true;
}
}
}

Expand Down
52 changes: 52 additions & 0 deletions clang/test/CodeGen/rounding-math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,55 @@ float V3 = func_01(1.0F, 2.0F);
// CHECK: @V1 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V2 = {{.*}}global float 1.000000e+00, align 4
// CHECK: @V3 = {{.*}}global float 3.000000e+00, align 4

// Constant-evaluates __builtin_elementwise_fma under FE_UPWARD rounding and
// requires it to fold to the same value as the plain `a * b + c` expression,
// for both float and double.
void test_builtin_elementwise_fma_round_upward() {
#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_UPWARD

// float: the builtin (f1) and the plain expression (f2) must store the same
// constant.
// CHECK: store float 0x4018000100000000, ptr %f1
// CHECK: store float 0x4018000100000000, ptr %f2
constexpr float f1 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f2 = 2.0F * 3.000001F + 0.000001F;
static_assert(f1 == f2);
static_assert(f1 == 6.00000381F);
// double: same comparison between the builtin (d1) and the plain expression
// (d2).
// CHECK: store double 0x40180000C9539B89, ptr %d1
// CHECK: store double 0x40180000C9539B89, ptr %d2
constexpr double d1 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d2 = 2.0 * 3.000001 + 0.000001;
static_assert(d1 == d2);
static_assert(d1 == 6.0000030000000004);
}

// Constant-evaluates __builtin_elementwise_fma under FE_DOWNWARD rounding and
// requires it to fold to the same value as the plain `a * b + c` expression,
// for both float and double.
void test_builtin_elementwise_fma_round_downward() {
#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_DOWNWARD

// float: the builtin (f3) and the plain expression (f4) must store the same
// constant.
// CHECK: store float 0x40180000C0000000, ptr %f3
// CHECK: store float 0x40180000C0000000, ptr %f4
constexpr float f3 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f4 = 2.0F * 3.000001F + 0.000001F;
static_assert(f3 == f4);
// double: same comparison between the builtin (d3) and the plain expression
// (d4).
// CHECK: store double 0x40180000C9539B87, ptr %d3
// CHECK: store double 0x40180000C9539B87, ptr %d4
constexpr double d3 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d4 = 2.0 * 3.000001 + 0.000001;
static_assert(d3 == d4);
}

// Constant-evaluates __builtin_elementwise_fma under FE_TONEAREST rounding and
// requires it to fold to the same value as the plain `a * b + c` expression,
// for both float and double.
void test_builtin_elementwise_fma_round_nearest() {
#pragma STDC FENV_ACCESS ON
#pragma STDC FENV_ROUND FE_TONEAREST

// float: the builtin (f5) and the plain expression (f6) must store the same
// constant.
// CHECK: store float 0x40180000C0000000, ptr %f5
// CHECK: store float 0x40180000C0000000, ptr %f6
constexpr float f5 = __builtin_elementwise_fma(2.0F, 3.000001F, 0.000001F);
constexpr float f6 = 2.0F * 3.000001F + 0.000001F;
static_assert(f5 == f6);
static_assert(f5 == 6.00000286F);
// double: same comparison between the builtin (d5) and the plain expression
// (d6).
// CHECK: store double 0x40180000C9539B89, ptr %d5
// CHECK: store double 0x40180000C9539B89, ptr %d6
constexpr double d5 = __builtin_elementwise_fma(2.0, 3.000001, 0.000001);
constexpr double d6 = 2.0 * 3.000001 + 0.000001;
static_assert(d5 == d6);
static_assert(d5 == 6.0000030000000004);
}
21 changes: 21 additions & 0 deletions clang/test/Sema/constant-builtins-vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -936,3 +936,24 @@ constexpr vector4char ctz1 = __builtin_elementwise_cttz((vector4char){1, 0, 3, 4
// expected-note@-1 {{evaluation of __builtin_elementwise_cttz with a zero value is undefined}}
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){8, 0, 127, 0}, (vector4char){9, -1, 9, -2})) == (LITTLE_END ? 0xFE00FF03 : 0x03FF00FE));
static_assert(__builtin_bit_cast(unsigned, __builtin_elementwise_cttz((vector4char){0, 0, 0, 0}, (vector4char){0, 0, 0, 0})) == 0);

// __builtin_elementwise_fma(x, y, z) on scalar (non-vector) floating point
// operands: the call is usable in a constant expression and yields x * y + z.
static_assert(__builtin_elementwise_fma(2.0, 3.0, 4.0) == 10.0);
static_assert(__builtin_elementwise_fma(200.0, 300.0, 400.0) == 60400.0);
// __builtin_elementwise_fma on vector operands: each result lane i is
// x[i] * y[i] + z[i]. Checked for a 4-element float vector and a 4-element
// double vector.
constexpr vector4float fmaFloat1 =
__builtin_elementwise_fma((vector4float){1.0, 2.0, 3.0, 4.0},
(vector4float){2.0, 3.0, 4.0, 5.0},
(vector4float){3.0, 4.0, 5.0, 6.0});
static_assert(fmaFloat1[0] == 5.0);
static_assert(fmaFloat1[1] == 10.0);
static_assert(fmaFloat1[2] == 17.0);
static_assert(fmaFloat1[3] == 26.0);
constexpr vector4double fmaDouble1 =
__builtin_elementwise_fma((vector4double){1.0, 2.0, 3.0, 4.0},
(vector4double){2.0, 3.0, 4.0, 5.0},
(vector4double){3.0, 4.0, 5.0, 6.0});
static_assert(fmaDouble1[0] == 5.0);
static_assert(fmaDouble1[1] == 10.0);
static_assert(fmaDouble1[2] == 17.0);
static_assert(fmaDouble1[3] == 26.0);