Skip to content

Commit 0fddaf0

Browse files
authored
[Clang] Refactor allocation type inference logic (#163636)
Refactor the logic for inferring allocated types out of `CodeGen` and into a new reusable component in `clang/AST/InferAlloc.h`. This is a preparatory step for implementing `__builtin_infer_alloc_token`. By moving the type inference heuristics into a common place, it can be shared between the existing allocation-call instrumentation and the new builtin's implementation.
1 parent ae6cb98 commit 0fddaf0

File tree

6 files changed

+267
-178
lines changed

6 files changed

+267
-178
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===--- InferAlloc.h - Allocation type inference ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines interfaces for allocation-related type inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CLANG_AST_INFERALLOC_H
14+
#define LLVM_CLANG_AST_INFERALLOC_H
15+
16+
#include "clang/AST/ASTContext.h"
17+
#include "clang/AST/Expr.h"
18+
#include "llvm/Support/AllocToken.h"
19+
#include <optional>
20+
21+
namespace clang {
22+
namespace infer_alloc {
23+
24+
/// Infer the possible allocated type from an allocation call expression.
25+
QualType inferPossibleType(const CallExpr *E, const ASTContext &Ctx,
26+
const CastExpr *CastE);
27+
28+
/// Get the information required for construction of an allocation token ID.
29+
std::optional<llvm::AllocTokenMetadata>
30+
getAllocTokenMetadata(QualType T, const ASTContext &Ctx);
31+
32+
} // namespace infer_alloc
33+
} // namespace clang
34+
35+
#endif // LLVM_CLANG_AST_INFERALLOC_H

clang/lib/AST/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ add_clang_library(clangAST
6666
ExternalASTMerger.cpp
6767
ExternalASTSource.cpp
6868
FormatString.cpp
69+
InferAlloc.cpp
6970
InheritViz.cpp
7071
ByteCode/BitcastBuffer.cpp
7172
ByteCode/ByteCodeEmitter.cpp

clang/lib/AST/InferAlloc.cpp

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
//===--- InferAlloc.cpp - Allocation type inference -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements allocation-related type inference.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "clang/AST/InferAlloc.h"
14+
#include "clang/AST/ASTContext.h"
15+
#include "clang/AST/Decl.h"
16+
#include "clang/AST/DeclCXX.h"
17+
#include "clang/AST/Expr.h"
18+
#include "clang/AST/Type.h"
19+
#include "clang/Basic/IdentifierTable.h"
20+
#include "llvm/ADT/SmallPtrSet.h"
21+
22+
using namespace clang;
23+
using namespace infer_alloc;
24+
25+
static bool
26+
typeContainsPointer(QualType T,
27+
llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD,
28+
bool &IncompleteType) {
29+
QualType CanonicalType = T.getCanonicalType();
30+
if (CanonicalType->isPointerType())
31+
return true; // base case
32+
33+
// Look through typedef chain to check for special types.
34+
for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>();
35+
CurrentT = TT->getDecl()->getUnderlyingType()) {
36+
const IdentifierInfo *II = TT->getDecl()->getIdentifier();
37+
// Special Case: Syntactically uintptr_t is not a pointer; semantically,
38+
// however, very likely used as such. Therefore, classify uintptr_t as a
39+
// pointer, too.
40+
if (II && II->isStr("uintptr_t"))
41+
return true;
42+
}
43+
44+
// The type is an array; check the element type.
45+
if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType))
46+
return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType);
47+
// The type is a struct, class, or union.
48+
if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) {
49+
if (!RD->isCompleteDefinition()) {
50+
IncompleteType = true;
51+
return false;
52+
}
53+
if (!VisitedRD.insert(RD).second)
54+
return false; // already visited
55+
// Check all fields.
56+
for (const FieldDecl *Field : RD->fields()) {
57+
if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType))
58+
return true;
59+
}
60+
// For C++ classes, also check base classes.
61+
if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
62+
// Polymorphic types require a vptr.
63+
if (CXXRD->isDynamicClass())
64+
return true;
65+
for (const CXXBaseSpecifier &Base : CXXRD->bases()) {
66+
if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType))
67+
return true;
68+
}
69+
}
70+
}
71+
return false;
72+
}
73+
74+
/// Infer type from a simple sizeof expression.
75+
static QualType inferTypeFromSizeofExpr(const Expr *E) {
76+
const Expr *Arg = E->IgnoreParenImpCasts();
77+
if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) {
78+
if (UET->getKind() == UETT_SizeOf) {
79+
if (UET->isArgumentType())
80+
return UET->getArgumentTypeInfo()->getType();
81+
else
82+
return UET->getArgumentExpr()->getType();
83+
}
84+
}
85+
return QualType();
86+
}
87+
88+
/// Infer type from an arithmetic expression involving a sizeof. For example:
89+
///
90+
/// malloc(sizeof(MyType) + padding); // infers 'MyType'
91+
/// malloc(sizeof(MyType) * 32); // infers 'MyType'
92+
/// malloc(32 * sizeof(MyType)); // infers 'MyType'
93+
/// malloc(sizeof(MyType) << 1); // infers 'MyType'
94+
/// ...
95+
///
96+
/// More complex arithmetic expressions are supported, but are a heuristic, e.g.
97+
/// when considering allocations for structs with flexible array members:
98+
///
99+
/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray'
100+
///
101+
static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) {
102+
const Expr *Arg = E->IgnoreParenImpCasts();
103+
// The argument is a lone sizeof expression.
104+
if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull())
105+
return T;
106+
if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) {
107+
// Argument is an arithmetic expression. Cover common arithmetic patterns
108+
// involving sizeof.
109+
switch (BO->getOpcode()) {
110+
case BO_Add:
111+
case BO_Div:
112+
case BO_Mul:
113+
case BO_Shl:
114+
case BO_Shr:
115+
case BO_Sub:
116+
if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS());
117+
!T.isNull())
118+
return T;
119+
if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS());
120+
!T.isNull())
121+
return T;
122+
break;
123+
default:
124+
break;
125+
}
126+
}
127+
return QualType();
128+
}
129+
130+
/// If the expression E is a reference to a variable, infer the type from a
131+
/// variable's initializer if it contains a sizeof. Beware, this is a heuristic
132+
/// and ignores if a variable is later reassigned. For example:
133+
///
134+
/// size_t my_size = sizeof(MyType);
135+
/// void *x = malloc(my_size); // infers 'MyType'
136+
///
137+
static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) {
138+
const Expr *Arg = E->IgnoreParenImpCasts();
139+
if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) {
140+
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
141+
if (const Expr *Init = VD->getInit())
142+
return inferPossibleTypeFromArithSizeofExpr(Init);
143+
}
144+
}
145+
return QualType();
146+
}
147+
148+
/// Deduces the allocated type by checking if the allocation call's result
149+
/// is immediately used in a cast expression. For example:
150+
///
151+
/// MyType *x = (MyType *)malloc(4096); // infers 'MyType'
152+
///
153+
static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE,
154+
const CastExpr *CastE) {
155+
if (!CastE)
156+
return QualType();
157+
QualType PtrType = CastE->getType();
158+
if (PtrType->isPointerType())
159+
return PtrType->getPointeeType();
160+
return QualType();
161+
}
162+
163+
QualType infer_alloc::inferPossibleType(const CallExpr *E,
164+
const ASTContext &Ctx,
165+
const CastExpr *CastE) {
166+
QualType AllocType;
167+
// First check arguments.
168+
for (const Expr *Arg : E->arguments()) {
169+
AllocType = inferPossibleTypeFromArithSizeofExpr(Arg);
170+
if (AllocType.isNull())
171+
AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg);
172+
if (!AllocType.isNull())
173+
break;
174+
}
175+
// Then check later casts.
176+
if (AllocType.isNull())
177+
AllocType = inferPossibleTypeFromCastExpr(E, CastE);
178+
return AllocType;
179+
}
180+
181+
std::optional<llvm::AllocTokenMetadata>
182+
infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) {
183+
llvm::AllocTokenMetadata ATMD;
184+
185+
// Get unique type name.
186+
PrintingPolicy Policy(Ctx.getLangOpts());
187+
Policy.SuppressTagKeyword = true;
188+
Policy.FullyQualifiedName = true;
189+
llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName);
190+
T.getCanonicalType().print(TypeNameOS, Policy);
191+
192+
// Check if QualType contains a pointer. Implements a simple DFS to
193+
// recursively check if a type contains a pointer type.
194+
llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD;
195+
bool IncompleteType = false;
196+
ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType);
197+
if (!ATMD.ContainsPointer && IncompleteType)
198+
return std::nullopt;
199+
200+
return ATMD;
201+
}

0 commit comments

Comments
 (0)