Skip to content

Commit cb8936f

Browse files
committed
Fix interaction with streaming-compatible functions.
1 parent c28804a commit cb8936f

File tree

13 files changed

+92
-39
lines changed

13 files changed

+92
-39
lines changed

clang/include/clang/Basic/TargetInfo.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1034,9 +1034,16 @@ class TargetInfo : public TransferrableTargetInfo,
10341034
/// set of primary and secondary targets.
10351035
virtual llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const = 0;
10361036

1037+
enum class ArmStreamingKind {
1038+
NotStreaming,
1039+
StreamingCompatible,
1040+
Streaming,
1041+
};
1042+
10371043
/// Returns target-specific min and max values VScale_Range.
10381044
virtual std::optional<std::pair<unsigned, unsigned>>
1039-
getVScaleRange(const LangOptions &LangOpts, bool IsArmStreamingFunction,
1045+
getVScaleRange(const LangOptions &LangOpts,
1046+
ArmStreamingKind IsArmStreamingFunction,
10401047
llvm::StringMap<bool> *FeatureMap = nullptr) const {
10411048
return std::nullopt;
10421049
}

clang/lib/AST/ASTContext.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10446,8 +10446,8 @@ bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
1044610446
/// getRVVTypeSize - Return RVV vector register size.
1044710447
static uint64_t getRVVTypeSize(ASTContext &Context, const BuiltinType *Ty) {
1044810448
assert(Ty->isRVVVLSBuiltinType() && "Invalid RVV Type");
10449-
auto VScale =
10450-
Context.getTargetInfo().getVScaleRange(Context.getLangOpts(), false);
10449+
auto VScale = Context.getTargetInfo().getVScaleRange(
10450+
Context.getLangOpts(), TargetInfo::ArmStreamingKind::NotStreaming);
1045110451
if (!VScale)
1045210452
return 0;
1045310453

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4273,7 +4273,8 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
42734273

42744274
// Apend the LMUL suffix.
42754275
auto VScale = getASTContext().getTargetInfo().getVScaleRange(
4276-
getASTContext().getLangOpts(), false);
4276+
getASTContext().getLangOpts(),
4277+
TargetInfo::ArmStreamingKind::NotStreaming);
42774278
unsigned VLen = VScale->first * llvm::RISCV::RVVBitsPerBlock;
42784279

42794280
if (T->getVectorKind() == VectorKind::RVVFixedLengthData) {

clang/lib/Basic/Targets/AArch64.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,13 +820,14 @@ AArch64TargetInfo::getTargetBuiltins() const {
820820

821821
std::optional<std::pair<unsigned, unsigned>>
822822
AArch64TargetInfo::getVScaleRange(const LangOptions &LangOpts,
823-
bool IsArmStreamingFunction,
823+
ArmStreamingKind IsArmStreamingFunction,
824824
llvm::StringMap<bool> *FeatureMap) const {
825-
if (!IsArmStreamingFunction && (LangOpts.VScaleMin || LangOpts.VScaleMax))
825+
if (IsArmStreamingFunction == ArmStreamingKind::NotStreaming &&
826+
(LangOpts.VScaleMin || LangOpts.VScaleMax))
826827
return std::pair<unsigned, unsigned>(
827828
LangOpts.VScaleMin ? LangOpts.VScaleMin : 1, LangOpts.VScaleMax);
828829

829-
if (IsArmStreamingFunction &&
830+
if (IsArmStreamingFunction == ArmStreamingKind::Streaming &&
830831
(LangOpts.VScaleStreamingMin || LangOpts.VScaleStreamingMax))
831832
return std::pair<unsigned, unsigned>(
832833
LangOpts.VScaleStreamingMin ? LangOpts.VScaleStreamingMin : 1,
@@ -835,7 +836,7 @@ AArch64TargetInfo::getVScaleRange(const LangOptions &LangOpts,
835836
if (hasFeature("sve") || (FeatureMap && (FeatureMap->lookup("sve"))))
836837
return std::pair<unsigned, unsigned>(1, 16);
837838

838-
if (IsArmStreamingFunction &&
839+
if (IsArmStreamingFunction == ArmStreamingKind::Streaming &&
839840
(hasFeature("sme") || (FeatureMap && (FeatureMap->lookup("sme")))))
840841
return std::pair<unsigned, unsigned>(1, 16);
841842

clang/lib/Basic/Targets/AArch64.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ class LLVM_LIBRARY_VISIBILITY AArch64TargetInfo : public TargetInfo {
197197
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
198198

199199
std::optional<std::pair<unsigned, unsigned>>
200-
getVScaleRange(const LangOptions &LangOpts, bool IsArmStreamingFunction,
200+
getVScaleRange(const LangOptions &LangOpts,
201+
ArmStreamingKind IsArmStreamingFunction,
201202
llvm::StringMap<bool> *FeatureMap = nullptr) const override;
202203
bool doesFeatureAffectCodeGen(StringRef Name) const override;
203204
bool validateCpuSupports(StringRef FeatureStr) const override;

clang/lib/Basic/Targets/RISCV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ void RISCVTargetInfo::getTargetDefines(const LangOptions &Opts,
222222
// Currently we support the v1.0 RISC-V V intrinsics.
223223
Builder.defineMacro("__riscv_v_intrinsic", Twine(getVersionValue(1, 0)));
224224

225-
auto VScale = getVScaleRange(Opts, false);
225+
auto VScale = getVScaleRange(Opts, ArmStreamingKind::NotStreaming);
226226
if (VScale && VScale->first && VScale->first == VScale->second)
227227
Builder.defineMacro("__riscv_v_fixed_vlen",
228228
Twine(VScale->first * llvm::RISCV::RVVBitsPerBlock));
@@ -367,7 +367,7 @@ bool RISCVTargetInfo::initFeatureMap(
367367

368368
std::optional<std::pair<unsigned, unsigned>>
369369
RISCVTargetInfo::getVScaleRange(const LangOptions &LangOpts,
370-
bool IsArmStreamingFunction,
370+
ArmStreamingKind IsArmStreamingFunction,
371371
llvm::StringMap<bool> *FeatureMap) const {
372372
// RISCV::RVVBitsPerBlock is 64.
373373
unsigned VScaleMin = ISAInfo->getMinVLen() / llvm::RISCV::RVVBitsPerBlock;

clang/lib/Basic/Targets/RISCV.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ class RISCVTargetInfo : public TargetInfo {
9999
const std::vector<std::string> &FeaturesVec) const override;
100100

101101
std::optional<std::pair<unsigned, unsigned>>
102-
getVScaleRange(const LangOptions &LangOpts, bool IsArmStreamingFunction,
102+
getVScaleRange(const LangOptions &LangOpts,
103+
ArmStreamingKind IsArmStreamingFunction,
103104
llvm::StringMap<bool> *FeatureMap = nullptr) const override;
104105

105106
bool hasFeature(StringRef Feature) const override;

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,10 +1108,16 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
11081108

11091109
// Add vscale_range attribute if appropriate.
11101110
llvm::StringMap<bool> FeatureMap;
1111-
bool IsArmStreaming = false;
1111+
auto IsArmStreaming = TargetInfo::ArmStreamingKind::NotStreaming;
11121112
if (FD) {
11131113
getContext().getFunctionFeatureMap(FeatureMap, FD);
1114-
IsArmStreaming = IsArmStreamingFunction(FD, true);
1114+
if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
1115+
if (T->getAArch64SMEAttributes() &
1116+
FunctionType::SME_PStateSMCompatibleMask)
1117+
IsArmStreaming = TargetInfo::ArmStreamingKind::StreamingCompatible;
1118+
1119+
if (IsArmStreamingFunction(FD, true))
1120+
IsArmStreaming = TargetInfo::ArmStreamingKind::Streaming;
11151121
}
11161122
std::optional<std::pair<unsigned, unsigned>> VScaleRange =
11171123
getContext().getTargetInfo().getVScaleRange(getLangOpts(), IsArmStreaming,

clang/lib/CodeGen/Targets/RISCV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty, unsigned ABIVLen) const {
544544
assert(VT->getElementType()->isBuiltinType() && "expected builtin type!");
545545

546546
auto VScale = getContext().getTargetInfo().getVScaleRange(
547-
getContext().getLangOpts(), false);
547+
getContext().getLangOpts(), TargetInfo::ArmStreamingKind::NotStreaming);
548548

549549
unsigned NumElts = VT->getNumElements();
550550
llvm::Type *EltType = llvm::Type::getInt1Ty(getVMContext());

clang/lib/Sema/SemaARM.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,9 +1423,18 @@ static uint64_t getSVETypeSize(ASTContext &Context, const BuiltinType *Ty,
14231423

14241424
bool SemaARM::areCompatibleSveTypes(QualType FirstType, QualType SecondType) {
14251425
bool IsStreaming = false;
1426-
if (const FunctionDecl *FD = SemaRef.getCurFunctionDecl(/*AllowLambda=*/true))
1426+
if (const FunctionDecl *FD =
1427+
SemaRef.getCurFunctionDecl(/*AllowLambda=*/true)) {
1428+
// For streaming-compatible functions, we don't know vector length.
1429+
if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
1430+
if (T->getAArch64SMEAttributes() &
1431+
FunctionType::SME_PStateSMCompatibleMask)
1432+
return false;
1433+
14271434
if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
14281435
IsStreaming = true;
1436+
}
1437+
14291438
auto IsValidCast = [&](QualType FirstType, QualType SecondType) {
14301439
if (const auto *BT = FirstType->getAs<BuiltinType>()) {
14311440
if (const auto *VT = SecondType->getAs<VectorType>()) {
@@ -1455,9 +1464,17 @@ bool SemaARM::areCompatibleSveTypes(QualType FirstType, QualType SecondType) {
14551464
bool SemaARM::areLaxCompatibleSveTypes(QualType FirstType,
14561465
QualType SecondType) {
14571466
bool IsStreaming = false;
1458-
if (const FunctionDecl *FD = SemaRef.getCurFunctionDecl(/*AllowLambda=*/true))
1467+
if (const FunctionDecl *FD =
1468+
SemaRef.getCurFunctionDecl(/*AllowLambda=*/true)) {
1469+
// For streaming-compatible functions, we don't know vector length.
1470+
if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
1471+
if (T->getAArch64SMEAttributes() &
1472+
FunctionType::SME_PStateSMCompatibleMask)
1473+
return false;
1474+
14591475
if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
14601476
IsStreaming = true;
1477+
}
14611478

14621479
auto IsLaxCompatible = [&](QualType FirstType, QualType SecondType) {
14631480
const auto *BT = FirstType->getAs<BuiltinType>();

0 commit comments

Comments
 (0)