Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions clang/include/clang/Basic/arm_neon.td
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,30 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
}

let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
def VMLALB_F16_F8 : VInst<"vmlalb_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
def VMLALT_F16_F8 : VInst<"vmlalt_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;

def VMLALLBB_F32_F8 : VInst<"vmlallbb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
def VMLALLBT_F32_F8 : VInst<"vmlallbt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
def VMLALLTB_F32_F8 : VInst<"vmlalltb_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
def VMLALLTT_F32_F8 : VInst<"vmlalltt_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;

def VMLALB_F16_F8_LANE : VInst<"vmlalb_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALB_F16_F8_LANEQ : VInst<"vmlalb_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
def VMLALT_F16_F8_LANE : VInst<"vmlalt_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALT_F16_F8_LANEQ : VInst<"vmlalt_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;

def VMLALLBB_F32_F8_LANE : VInst<"vmlallbb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltb_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltb_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
def VMLALLTT_F32_F8_LANE : VInst<"vmlalltt_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlalltt_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
}

let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
Expand Down
107 changes: 91 additions & 16 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6759,29 +6759,45 @@ Value *CodeGenFunction::EmitNeonCall(Function *F, SmallVectorImpl<Value*> &Ops,
return Builder.CreateCall(F, Ops, name);
}

Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
ArrayRef<llvm::Type *> Tys,
SmallVectorImpl<Value *> &Ops,
Value *FPM, const char *name) {
const CallExpr *E, const char *name) {
llvm::Value *FPM =
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr), FPM);
return EmitNeonCall(F, Ops, name);
return EmitNeonCall(CGM.getIntrinsic(IID, Tys), Ops, name);
}

llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {

const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
RetTy->getPrimitiveSizeInBits();
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
Ops[1]->getType()};
if (ExtendLane) {
if (ExtendLaneArg) {
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
Builder.getInt64(0));
}
llvm::Value *FPM =
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
}

llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {

if (ExtendLaneArg) {
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
Builder.getInt64(0));
}
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
RetTy->getPrimitiveSizeInBits();
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
Ops, E, name);
}

Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
Expand All @@ -6802,9 +6818,7 @@ Value *CodeGenFunction::EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
Tys[1] = llvm::FixedVectorType::get(Int8Ty, 8);
Ops[0] = Builder.CreateExtractVector(Tys[1], Ops[0], Builder.getInt64(0));
}
llvm::Value *FPM =
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
}

// Right-shift a vector by a constant.
Expand Down Expand Up @@ -12779,7 +12793,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,

unsigned Int;
bool ExtractLow = false;
bool ExtendLane = false;
bool ExtendLaneArg = false;
switch (BuiltinID) {
default: return nullptr;
case NEON::BI__builtin_neon_vbsl_v:
Expand Down Expand Up @@ -14054,24 +14068,85 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
Ops, E, "fdot2");
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
ExtendLane = true;
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
FloatTy, Ops, E, "fdot4");
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
ExtendLane = true;
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");

case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
"vmlal");
case NEON::BI__builtin_neon_vmlaltq_f16_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalt,
{llvm::FixedVectorType::get(HalfTy, 8)}, Ops, E,
"vmlal");
case NEON::BI__builtin_neon_vmlallbbq_f32_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbb,
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
"vmlall");
case NEON::BI__builtin_neon_vmlallbtq_f32_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlallbt,
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
"vmlall");
case NEON::BI__builtin_neon_vmlalltbq_f32_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltb,
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
"vmlall");
case NEON::BI__builtin_neon_vmlallttq_f32_mf8_fpm:
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
"vmlall");
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
ExtendLaneArg = true;
LLVM_FALLTHROUGH;
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
case NEON::BI__builtin_neon_vamin_f16:
case NEON::BI__builtin_neon_vaminq_f16:
case NEON::BI__builtin_neon_vamin_f32:
Expand Down
10 changes: 7 additions & 3 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4692,14 +4692,18 @@ class CodeGenFunction : public CodeGenTypeCache {
SmallVectorImpl<llvm::Value*> &O,
const char *name,
unsigned shift = 0, bool rightshift = false);
llvm::Value *EmitFP8NeonCall(llvm::Function *F,
llvm::Value *EmitFP8NeonCall(unsigned IID, ArrayRef<llvm::Type *> Tys,
SmallVectorImpl<llvm::Value *> &O,
llvm::Value *FPM, const char *name);
const CallExpr *E, const char *name);
llvm::Value *EmitFP8NeonCvtCall(unsigned IID, llvm::Type *Ty0,
llvm::Type *Ty1, bool Extract,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
llvm::Type *RetTy,
SmallVectorImpl<llvm::Value *> &Ops,
const CallExpr *E, const char *name);
Expand Down
Loading
Loading