Skip to content

Commit e7b61d9

Browse files
committed
[CHERI] Clean up the CHERI capability format support functions to be more ergonomic.
1 parent c0076e2 commit e7b61d9

File tree

9 files changed

+268
-275
lines changed

9 files changed

+268
-275
lines changed

clang/lib/StaticAnalyzer/Checkers/CHERI/SubObjectRepresentabilityChecker.cpp

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@
2323
#include "clang/StaticAnalyzer/Core/CheckerManager.h"
2424
#include "clang/StaticAnalyzer/Core/PathSensitive/AnalysisManager.h"
2525
#include "llvm/ADT/SmallString.h"
26-
#include "llvm/CHERI/CompressedCapability.h"
26+
#include "llvm/CHERI/CapabilityFormat.h"
2727
#include "llvm/Support/raw_ostream.h"
2828

2929
using namespace clang;
3030
using namespace ento;
3131

3232
namespace {
3333

34-
template <llvm::CompressedCapability::CapabilityFormat>
35-
std::unique_ptr<BugReport> checkFieldImpl(const FieldDecl *D, BugReporter &BR,
36-
const BugType &BT);
34+
std::unique_ptr<BugReport> checkField(const FieldDecl *D, BugReporter &BR,
35+
const BugType &BT,
36+
llvm::CHERICapabilityFormat CapFormat);
3737

3838
class SubObjectRepresentabilityChecker
3939
: public Checker<check::ASTDecl<RecordDecl>, check::ASTCodeBody> {
@@ -49,11 +49,8 @@ class SubObjectRepresentabilityChecker
4949
BugReporter &BR) const;
5050

5151
private:
52-
using CheckFieldFn = std::unique_ptr<BugReport> (*)(const FieldDecl *D,
53-
BugReporter &BR,
54-
const BugType &BT);
55-
56-
CheckFieldFn getCheckFieldFn(ASTContext &ASTCtx) const;
52+
std::optional<llvm::CHERICapabilityFormat>
53+
getCapabilityFormat(ASTContext &ASTCtx) const;
5754
};
5855

5956
} // namespace
@@ -97,9 +94,9 @@ reportExposedFields(const FieldDecl *D, ASTContext &ASTCtx, BugReporter &BR,
9794
return Report;
9895
}
9996

100-
template <llvm::CompressedCapability::CapabilityFormat CapFormat>
101-
std::unique_ptr<BugReport> checkFieldImpl(const FieldDecl *D, BugReporter &BR,
102-
const BugType &BT) {
97+
std::unique_ptr<BugReport> checkField(const FieldDecl *D, BugReporter &BR,
98+
const BugType &BT,
99+
llvm::CHERICapabilityFormat CapFormat) {
103100
QualType T = D->getType();
104101

105102
// If the parent struct is explicitly marked as packed, then don't emit a
@@ -111,9 +108,7 @@ std::unique_ptr<BugReport> checkFieldImpl(const FieldDecl *D, BugReporter &BR,
111108
uint64_t Offset = ASTCtx.getFieldOffset(D) / 8;
112109
if (Offset > 0) {
113110
uint64_t Size = ASTCtx.getTypeSize(T) / 8;
114-
uint64_t ReqAlign =
115-
llvm::CompressedCapability::GetRequiredAlignment(Size, CapFormat)
116-
.value();
111+
uint64_t ReqAlign = CapFormat.getRequiredAlignment(Size).value();
117112
uint64_t CurAlign = 1 << llvm::countr_zero(Offset);
118113
if (CurAlign < ReqAlign) {
119114
/* Emit warning */
@@ -131,9 +126,8 @@ std::unique_ptr<BugReport> checkFieldImpl(const FieldDecl *D, BugReporter &BR,
131126
uint64_t OffsetToAlign = Offset % ReqAlign;
132127
uint64_t Base = Offset - OffsetToAlign;
133128
uint64_t AlignedSize = Size + OffsetToAlign;
134-
uint64_t TailPadding = static_cast<uint64_t>(
135-
llvm::CompressedCapability::GetRequiredTailPadding(AlignedSize,
136-
CapFormat));
129+
uint64_t TailPadding =
130+
static_cast<uint64_t>(CapFormat.getRequiredTailPadding(AlignedSize));
137131
uint64_t Top = Base + AlignedSize + TailPadding;
138132
OS << " Current bounds: " << Base << "-" << Top;
139133

@@ -156,40 +150,40 @@ std::unique_ptr<BugReport> checkFieldImpl(const FieldDecl *D, BugReporter &BR,
156150

157151
} // namespace
158152

159-
SubObjectRepresentabilityChecker::CheckFieldFn
160-
SubObjectRepresentabilityChecker::getCheckFieldFn(ASTContext &ASTCtx) const {
153+
std::optional<llvm::CHERICapabilityFormat>
154+
SubObjectRepresentabilityChecker::getCapabilityFormat(
155+
ASTContext &ASTCtx) const {
161156
const TargetInfo &TI = ASTCtx.getTargetInfo();
162157
if (!TI.areAllPointersCapabilities())
163-
return nullptr;
158+
return std::nullopt;
164159

165160
const auto &T = TI.getTriple();
166161
if (T.getArch() == llvm::Triple::riscv32 && TI.hasFeature("xcheriot")) {
167-
return &checkFieldImpl<llvm::CompressedCapability::Cheriot64>;
162+
return llvm::CHERICapabilityFormat::Cheriot64;
168163
}
169164

170-
static constexpr std::array CheckFieldFnMap = {
171-
std::make_pair(llvm::Triple::mips,
172-
&checkFieldImpl<llvm::CompressedCapability::Cheri64>),
165+
static constexpr std::array CapabilityFormatMap = {
166+
std::make_pair(llvm::Triple::mips, &llvm::CHERICapabilityFormat::Cheri64),
173167
std::make_pair(llvm::Triple::mips64,
174-
&checkFieldImpl<llvm::CompressedCapability::Cheri128>),
168+
&llvm::CHERICapabilityFormat::Cheri128),
175169
std::make_pair(llvm::Triple::riscv32,
176-
&checkFieldImpl<llvm::CompressedCapability::Cheri64>),
170+
&llvm::CHERICapabilityFormat::Cheri64),
177171
std::make_pair(llvm::Triple::riscv64,
178-
&checkFieldImpl<llvm::CompressedCapability::Cheri128>),
172+
&llvm::CHERICapabilityFormat::Cheri128),
179173
};
180174

181-
auto It = std::find_if(CheckFieldFnMap.begin(), CheckFieldFnMap.end(),
175+
auto It = std::find_if(CapabilityFormatMap.begin(), CapabilityFormatMap.end(),
182176
[A = T.getArch()](auto p) { return p.first == A; });
183-
if (It == CheckFieldFnMap.end())
184-
return nullptr;
185-
return It->second;
177+
if (It == CapabilityFormatMap.end())
178+
return std::nullopt;
179+
return *It->second;
186180
}
187181

188182
void SubObjectRepresentabilityChecker::checkASTDecl(const RecordDecl *R,
189183
AnalysisManager &mgr,
190184
BugReporter &BR) const {
191-
CheckFieldFn checkField = getCheckFieldFn(mgr.getASTContext());
192-
if (!checkField)
185+
auto CapFormat = getCapabilityFormat(mgr.getASTContext());
186+
if (!CapFormat)
193187
return; // skip this target
194188

195189
if (!R->isCompleteDefinition() || R->isDependentType())
@@ -199,7 +193,7 @@ void SubObjectRepresentabilityChecker::checkASTDecl(const RecordDecl *R,
199193
return;
200194

201195
for (FieldDecl *D : R->fields()) {
202-
auto Report = checkField(D, BR, BT_1);
196+
auto Report = checkField(D, BR, BT_1, *CapFormat);
203197
if (Report)
204198
BR.emitReport(std::move(Report));
205199
}
@@ -208,8 +202,8 @@ void SubObjectRepresentabilityChecker::checkASTDecl(const RecordDecl *R,
208202
void SubObjectRepresentabilityChecker::checkASTCodeBody(const Decl *D,
209203
AnalysisManager &mgr,
210204
BugReporter &BR) const {
211-
CheckFieldFn checkField = getCheckFieldFn(mgr.getASTContext());
212-
if (!checkField)
205+
auto CapFormat = getCapabilityFormat(mgr.getASTContext());
206+
if (!CapFormat)
213207
return; // skip this target
214208

215209
using namespace ast_matchers;
@@ -228,7 +222,7 @@ void SubObjectRepresentabilityChecker::checkASTCodeBody(const Decl *D,
228222
if (const MemberExpr *ME = Match.getNodeAs<MemberExpr>("member")) {
229223
ValueDecl *VD = ME->getMemberDecl();
230224
if (FieldDecl *FD = dyn_cast<FieldDecl>(VD)) {
231-
auto Report = checkField(FD, BR, BT_2);
225+
auto Report = checkField(FD, BR, BT_2, *CapFormat);
232226
if (Report) {
233227
PathDiagnosticLocation LN = PathDiagnosticLocation::createBegin(
234228
CE, BR.getSourceManager(), mgr.getAnalysisDeclContext(D));
@@ -242,7 +236,7 @@ void SubObjectRepresentabilityChecker::checkASTCodeBody(const Decl *D,
242236
if (const MemberExpr *ME = Match.getNodeAs<MemberExpr>("member")) {
243237
ValueDecl *VD = ME->getMemberDecl();
244238
if (FieldDecl *FD = dyn_cast<FieldDecl>(VD)) {
245-
auto Report = checkField(FD, BR, BT_2);
239+
auto Report = checkField(FD, BR, BT_2, *CapFormat);
246240
if (Report) {
247241
PathDiagnosticLocation LN = PathDiagnosticLocation::createBegin(
248242
UO, BR.getSourceManager(), mgr.getAnalysisDeclContext(D));

lld/ELF/Arch/RISCV.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include "Symbols.h"
1414
#include "SyntheticSections.h"
1515
#include "Target.h"
16-
#include "llvm/CHERI/CompressedCapability.h"
16+
#include "llvm/CHERI/CapabilityFormat.h"
1717
#include "llvm/Support/ELFAttributes.h"
1818
#include "llvm/Support/LEB128.h"
1919
#include "llvm/Support/RISCVAttributeParser.h"
@@ -761,14 +761,13 @@ static void tlsdescToLe(uint8_t *loc, const Relocation &rel, uint64_t val) {
761761
}
762762

763763
uint64_t RISCV::cheriRequiredAlignment(uint64_t size) const {
764-
auto CapFormat = llvm::CompressedCapability::Cheri128;
764+
auto CapFormat = llvm::CHERICapabilityFormat::Cheri128;
765765
if (ctx.arg.isCheriot)
766-
CapFormat = llvm::CompressedCapability::Cheriot64;
766+
CapFormat = llvm::CHERICapabilityFormat::Cheriot64;
767767
else if (!ctx.arg.is64)
768-
CapFormat = llvm::CompressedCapability::Cheri64;
768+
CapFormat = llvm::CHERICapabilityFormat::Cheri64;
769769

770-
return llvm::CompressedCapability::GetRequiredAlignment(size, CapFormat)
771-
.value();
770+
return CapFormat.getRequiredAlignment(size).value();
772771
}
773772

774773
void RISCV::relocateAlloc(InputSectionBase &sec, uint8_t *buf) const {
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===--- CHERICompressedCapability.h ----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_COMPRESSED_CAPABILITY_H
10+
#define LLVM_COMPRESSED_CAPABILITY_H
11+
12+
#include "llvm/ADT/ArrayRef.h"
13+
#include "llvm/MC/MCTargetOptions.h"
14+
#include "llvm/Support/Alignment.h"
15+
16+
#include <algorithm>
17+
#include <cstdint>
18+
19+
namespace llvm {
20+
21+
class CHERICapabilityFormat {
22+
constexpr CHERICapabilityFormat(uint64_t AM,
23+
ArrayRef<std::pair<uint64_t, uint64_t>> L)
24+
: AddressMask(AM), LUT(L) {}
25+
26+
uint64_t AddressMask;
27+
ArrayRef<std::pair<uint64_t, uint64_t>> LUT;
28+
29+
public:
30+
inline uint64_t getAddressMask() const { return AddressMask; }
31+
uint64_t getAlignmentMask(uint64_t Length) const {
32+
auto el = std::find_if(LUT.begin(), LUT.end(),
33+
[=](const auto &p) { return Length <= p.first; });
34+
assert(el != LUT.end());
35+
return el->second;
36+
}
37+
38+
inline uint64_t getRepresentableLength(uint64_t Length) const {
39+
uint64_t Mask = getAlignmentMask(Length);
40+
return (Length + ~Mask) & Mask;
41+
}
42+
43+
inline Align getRequiredAlignment(uint64_t Length) const {
44+
return Align((~getAlignmentMask(Length) + 1) & getAddressMask());
45+
}
46+
47+
inline TailPaddingAmount getRequiredTailPadding(uint64_t Length) const {
48+
return static_cast<TailPaddingAmount>(
49+
llvm::alignTo(Length, getRequiredAlignment(Length)) - Length);
50+
}
51+
52+
static const CHERICapabilityFormat Cheriot64;
53+
static const CHERICapabilityFormat Cheri64;
54+
static const CHERICapabilityFormat Cheri128;
55+
};
56+
57+
} // namespace llvm
58+
59+
#endif

0 commit comments

Comments
 (0)