diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 89e2eace9120b..44b4e3ece8ee2 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5101,6 +5101,9 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   const FunctionDecl *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl);
   CGM.getTargetCodeGenInfo().checkFunctionCallABI(CGM, Loc, CallerDecl,
                                                   CalleeDecl, CallArgs, RetTy);
+  // 0. Allow the target to emit an additional prolog for the function call
+  CGM.getTargetCodeGenInfo().emitFunctionCallProlog(Builder, CallerDecl,
+                                                    CalleeDecl);
 
   // 1. Set up the arguments.
 
diff --git a/clang/lib/CodeGen/TargetInfo.h b/clang/lib/CodeGen/TargetInfo.h
index ab3142bdea684..5ba6fc4acc9f0 100644
--- a/clang/lib/CodeGen/TargetInfo.h
+++ b/clang/lib/CodeGen/TargetInfo.h
@@ -443,6 +443,13 @@ class TargetCodeGenInfo {
     return nullptr;
   }
 
+  /// Called by CodeGenFunction::EmitCall before the call's arguments are set
+  /// up so a target can emit extra IR ahead of the call. \p Caller and
+  /// \p Callee may be null; the default implementation emits nothing.
+  virtual void emitFunctionCallProlog(CGBuilderTy &Builder,
+                                      const FunctionDecl *Caller,
+                                      const FunctionDecl *Callee) const {}
+
   // Set the Branch Protection Attributes of the Function accordingly to the
   // BPI. Remove attributes that contradict with current BPI.
static void
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 7db67ecba07c8..4cfafcc6747fb 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -10,6 +10,7 @@
 #include "TargetInfo.h"
 #include "clang/AST/Decl.h"
 #include "clang/Basic/DiagnosticFrontend.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/TargetParser/AArch64TargetParser.h"
 
 using namespace clang;
@@ -181,6 +182,9 @@ class AArch64TargetCodeGenInfo : public TargetCodeGenInfo {
   bool wouldInliningViolateFunctionCallABI(
       const FunctionDecl *Caller, const FunctionDecl *Callee) const override;
 
+  void emitFunctionCallProlog(CGBuilderTy &Builder, const FunctionDecl *Caller,
+                              const FunctionDecl *Callee) const override;
+
 private:
   // Diagnose calls between functions with incompatible Streaming SVE
   // attributes.
@@ -1275,6 +1279,42 @@ bool AArch64TargetCodeGenInfo::wouldInliningViolateFunctionCallABI(
          GetArmSMEInlinability(Caller, Callee) != ArmSMEInlinability::Ok;
 }
 
+/// Emits an llvm.assume about the streaming mode before a call to a
+/// streaming-compatible \p Callee: unless \p Caller's type is itself
+/// streaming-compatible, the result of aarch64_sme_in_streaming_mode is
+/// assumed true for a streaming caller and false otherwise. Emits nothing if
+/// the target lacks the "sme" feature, either declaration is null, or
+/// \p Caller's type is not a FunctionProtoType.
+void AArch64TargetCodeGenInfo::emitFunctionCallProlog(
+    CGBuilderTy &Builder, const FunctionDecl *Caller,
+    const FunctionDecl *Callee) const {
+  const AArch64ABIInfo &ABIInfo = getABIInfo();
+  const TargetInfo &TI = ABIInfo.getContext().getTargetInfo();
+
+  if (!TI.hasFeature("sme"))
+    return;
+
+  // Either declaration may be null (no FunctionDecl is known for the callee
+  // or for the enclosing code), so check both before dereferencing them.
+  if (!Caller || !Callee || !isStreamingCompatible(Callee))
+    return;
+
+  if (const auto *FPT = Caller->getType()->getAs<FunctionProtoType>()) {
+    unsigned SMEAttrs = FPT->getAArch64SMEAttributes();
+    // A streaming-compatible caller may run in either mode: assume nothing.
+    if (!(SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)) {
+      // NOTE(review): only the caller's type is inspected here; confirm that
+      // an __arm_locally_streaming caller cannot get the negated assumption.
+      bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
+      llvm::Value *Call = Builder.CreateIntrinsic(
+          llvm::Intrinsic::aarch64_sme_in_streaming_mode, {}, {});
+      if (!IsStreaming)
+        Call = Builder.CreateNot(Call);
+      Builder.CreateAssumption(Call);
+    }
+  }
+}
+
 void AArch64ABIInfo::appendAttributeMangling(TargetClonesAttr *Attr,
                                              unsigned Index,
                                              raw_ostream &Out) const {
diff --git 
a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c index 72f2d17fc6dc1..80af80682d194 100644 --- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c +++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c @@ -29,7 +29,6 @@ bool test_in_streaming_mode_streaming_compatible(void) __arm_streaming_compatibl // CPP-CHECK-NEXT: ret i1 true // bool test_in_streaming_mode_streaming(void) __arm_streaming { -// return __arm_in_streaming_mode(); } @@ -47,12 +46,12 @@ bool test_in_streaming_mode_non_streaming(void) { // CHECK-LABEL: @test_za_disable( // CHECK-NEXT: entry: -// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]] +// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR8:[0-9]+]] // CHECK-NEXT: ret void // // CPP-CHECK-LABEL: @_Z15test_za_disablev( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]] +// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR8:[0-9]+]] // CPP-CHECK-NEXT: ret void // void test_za_disable(void) __arm_streaming_compatible { @@ -61,14 +60,14 @@ void test_za_disable(void) __arm_streaming_compatible { // CHECK-LABEL: @test_has_sme( // CHECK-NEXT: entry: -// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]] +// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR8]] // CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0 // CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0 // CHECK-NEXT: ret i1 [[TOBOOL_I]] // // CPP-CHECK-LABEL: @_Z12test_has_smev( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]] +// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR8]] // CPP-CHECK-NEXT: 
[[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0 // CPP-CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0 // CPP-CHECK-NEXT: ret i1 [[TOBOOL_I]] @@ -91,12 +90,12 @@ void test_svundef_za(void) __arm_streaming_compatible __arm_out("za") { // CHECK-LABEL: @test_sc_memcpy( // CHECK-NEXT: entry: -// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CHECK-NEXT: ret ptr [[CALL]] // // CPP-CHECK-LABEL: @_Z14test_sc_memcpyPvPKvm( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CPP-CHECK-NEXT: ret ptr [[CALL]] // void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_compatible { @@ -105,12 +104,12 @@ void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_comp // CHECK-LABEL: @test_sc_memmove( // CHECK-NEXT: entry: -// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CHECK-NEXT: ret ptr [[CALL]] // // CPP-CHECK-LABEL: @_Z15test_sc_memmovePvPKvm( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) 
#[[ATTR8]] // CPP-CHECK-NEXT: ret ptr [[CALL]] // void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_compatible { @@ -119,12 +118,12 @@ void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_com // CHECK-LABEL: @test_sc_memset( // CHECK-NEXT: entry: -// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CHECK-NEXT: ret ptr [[CALL]] // // CPP-CHECK-LABEL: @_Z14test_sc_memsetPvim( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CPP-CHECK-NEXT: ret ptr [[CALL]] // void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible { @@ -133,12 +132,12 @@ void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible { // CHECK-LABEL: @test_sc_memchr( // CHECK-NEXT: entry: -// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CHECK-NEXT: ret ptr [[CALL]] // // CPP-CHECK-LABEL: @_Z14test_sc_memchrPvim( // CPP-CHECK-NEXT: entry: -// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]] +// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]] // CPP-CHECK-NEXT: ret ptr [[CALL]] // void *test_sc_memchr(void *s, 
int c, size_t n) __arm_streaming_compatible {