Skip to content

Commit cb63cea

Browse files
committed
[CIR] Upstream handling for BaseToDerived casts
Upstream handling for BaseToDerived casts, adding the cir.derived_class_addr operation and lowering to LLVM IR.
1 parent 260df80 commit cb63cea

File tree

8 files changed

+247
-2
lines changed

8 files changed

+247
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3369,6 +3369,54 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
33693369
}];
33703370
}
33713371

3372+
//===----------------------------------------------------------------------===//
3373+
// DerivedClassAddrOp
3374+
//===----------------------------------------------------------------------===//
3375+
3376+
def CIR_DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
3377+
let summary = "Get the derived class address for a class/struct";
3378+
let description = [{
3379+
The `cir.derived_class_addr` operation gets the address of a particular
3380+
derived class given a non-virtual base class pointer. The offset in bytes
3381+
of the base class must be passed in, similar to `cir.base_class_addr`, but
3382+
going in the other direction. This means lowering to a negative offset.
3383+
3384+
The operation contains a flag for whether or not the operand may be nullptr.
3385+
That depends on the context and cannot be known by the operation, and that
3386+
information affects how the operation is lowered.
3387+
3388+
Example:
3389+
```c++
3390+
class A {};
3391+
class B : public A {};
3392+
3393+
B *getAsB(A *a) {
3394+
return static_cast<B*>(a);
3395+
}
3396+
```
3397+
3398+
leads to
3399+
```mlir
3400+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
3401+
%3 = cir.derived_class_addr %2 : !cir.ptr<!rec_A> [0] -> !cir.ptr<!rec_B>
3402+
```
3403+
}];
3404+
3405+
// The validity of the relationship of derived and base cannot yet be
3406+
// verified, currently not worth adding a verifier.
3407+
let arguments = (ins
3408+
Arg<CIR_PointerType, "base class pointer", [MemRead]>:$base_addr,
3409+
IndexAttr:$offset, UnitAttr:$assume_not_null);
3410+
3411+
let results = (outs Res<CIR_PointerType, "">:$derived_addr);
3412+
3413+
let assemblyFormat = [{
3414+
$base_addr `:` qualified(type($base_addr))
3415+
(`nonnull` $assume_not_null^)?
3416+
` ` `[` $offset `]` `->` qualified(type($derived_addr)) attr-dict
3417+
}];
3418+
}
3419+
33723420
//===----------------------------------------------------------------------===//
33733421
// ComplexCreateOp
33743422
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
403403
return Address(baseAddr, destType, addr.getAlignment());
404404
}
405405

406+
/// Create a `cir::DerivedClassAddrOp` that turns \p addr into an address whose
/// element type is \p destType.
///
/// \param loc           Location attached to the created operation.
/// \param addr          Input address; its pointer becomes the op operand.
/// \param destType      Element type of the resulting address.
/// \param offset        Offset stored on the op as a 64-bit integer.
/// \param assumeNotNull Forwarded to the op unchanged.
/// \returns \p addr itself when its element type already equals \p destType;
///          otherwise an Address wrapping the new op's result.
Address createDerivedClassAddr(mlir::Location loc, Address addr,
407+
mlir::Type destType, unsigned offset,
408+
bool assumeNotNull) {
409+
// Nothing to convert when the element type already matches.
if (destType == addr.getElementType())
410+
return addr;
411+
412+
cir::PointerType ptrTy = getPointerTo(destType);
413+
auto derivedAddr =
414+
cir::DerivedClassAddrOp::create(*this, loc, ptrTy, addr.getPointer(),
415+
mlir::APInt(64, offset), assumeNotNull);
416+
// NOTE(review): the input address's alignment is reused as-is for the result;
// confirm this is the intended alignment for the derived object.
return Address(derivedAddr, destType, addr.getAlignment());
417+
}
418+
406419
mlir::Value createVTTAddrPoint(mlir::Location loc, mlir::Type retTy,
407420
mlir::Value addr, uint64_t offset) {
408421
return cir::VTTAddrPointOp::create(*this, loc, retTy,

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,25 @@ mlir::Value CIRGenFunction::getVTTParameter(GlobalDecl gd, bool forVirtualBase,
11101110
}
11111111
}
11121112

1113+
/// Convert \p baseAddr into the address of the \p derived class object.
///
/// \param loc            Location used for the emitted operation.
/// \param baseAddr       Address to convert.
/// \param derived        Record declaration of the destination class.
/// \param path           Cast path; must be non-empty. Used together with
///                       \p derived to compute the non-virtual offset.
/// \param nullCheckValue Its negation is forwarded as the assume-not-null
///                       flag of the created operation.
/// \returns The Address produced by `builder.createDerivedClassAddr`.
Address CIRGenFunction::getAddressOfDerivedClass(
1114+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
1115+
llvm::iterator_range<CastExpr::path_const_iterator> path,
1116+
bool nullCheckValue) {
1117+
assert(!path.empty() && "Base path should not be empty!");
1118+
1119+
QualType derivedTy = getContext().getCanonicalTagType(derived);
1120+
mlir::Type derivedValueTy = convertType(derivedTy);
1121+
CharUnits nonVirtualOffset =
1122+
cgm.computeNonVirtualBaseClassOffset(derived, path);
1123+
1124+
// Note that in OG, no offset (nonVirtualOffset.getQuantity() == 0) means it
1125+
// just gives the address back. In CIR a `cir.derived_class_addr` is created
1126+
// and made into a nop later on during lowering.
1127+
return builder.createDerivedClassAddr(loc, baseAddr, derivedValueTy,
1128+
nonVirtualOffset.getQuantity(),
1129+
/*assumeNotNull=*/!nullCheckValue);
1130+
}
1131+
11131132
Address CIRGenFunction::getAddressOfBaseClass(
11141133
Address value, const CXXRecordDecl *derived,
11151134
llvm::iterator_range<CastExpr::path_const_iterator> path,

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,6 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
12041204
case CK_NonAtomicToAtomic:
12051205
case CK_AtomicToNonAtomic:
12061206
case CK_ToUnion:
1207-
case CK_BaseToDerived:
12081207
case CK_AddressSpaceConversion:
12091208
case CK_ObjCObjectLValueCast:
12101209
case CK_VectorSplat:
@@ -1220,6 +1219,22 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
12201219
return {};
12211220
}
12221221

1222+
case CK_BaseToDerived: {
1223+
const auto *derivedClassDecl = e->getType()->castAsCXXRecordDecl();
1224+
LValue lv = emitLValue(e->getSubExpr());
1225+
1226+
// Perform the base-to-derived conversion
1227+
Address derived = getAddressOfDerivedClass(
1228+
getLoc(e->getSourceRange()), lv.getAddress(), derivedClassDecl,
1229+
e->path(), /*NullCheckValue=*/false);
1230+
// C++11 [expr.static.cast]p2: Behavior is undefined if a downcast is
1231+
// performed and the object is not of the derived type.
1232+
assert(!cir::MissingFeatures::sanitizers());
1233+
1234+
assert(!cir::MissingFeatures::opTBAA());
1235+
return makeAddrLValue(derived, e->getType(), lv.getBaseInfo());
1236+
}
1237+
12231238
case CK_LValueBitCast: {
12241239
// This must be a reinterpret_cast (or c-style equivalent).
12251240
const auto *ce = cast<ExplicitCastExpr>(e);

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1933,14 +1933,27 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
19331933
return builder.createIntToPtr(middleVal, destCIRTy);
19341934
}
19351935

1936+
case CK_BaseToDerived: {
1937+
const CXXRecordDecl *derivedClassDecl = destTy->getPointeeCXXRecordDecl();
1938+
assert(derivedClassDecl && "BaseToDerived arg isn't a C++ object pointer!");
1939+
Address base = cgf.emitPointerWithAlignment(subExpr);
1940+
Address derived = cgf.getAddressOfDerivedClass(
1941+
cgf.getLoc(ce->getSourceRange()), base, derivedClassDecl, ce->path(),
1942+
cgf.shouldNullCheckClassCastValue(ce));
1943+
1944+
// C++11 [expr.static.cast]p11: Behavior is undefined if a downcast is
1945+
// performed and the object is not of the derived type.
1946+
assert(!cir::MissingFeatures::sanitizers());
1947+
1948+
return cgf.getAsNaturalPointerTo(derived, ce->getType()->getPointeeType());
1949+
}
19361950
case CK_UncheckedDerivedToBase:
19371951
case CK_DerivedToBase: {
19381952
// The EmitPointerWithAlignment path does this fine; just discard
19391953
// the alignment.
19401954
return cgf.getAsNaturalPointerTo(cgf.emitPointerWithAlignment(ce),
19411955
ce->getType()->getPointeeType());
19421956
}
1943-
19441957
case CK_Dynamic: {
19451958
Address v = cgf.emitPointerWithAlignment(subExpr);
19461959
const auto *dce = cast<CXXDynamicCastExpr>(ce);

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,11 @@ class CIRGenFunction : public CIRGenTypeCache {
816816
llvm::iterator_range<CastExpr::path_const_iterator> path,
817817
bool nullCheckValue, SourceLocation loc);
818818

819+
/// Convert \p baseAddr to the address of the \p derived class object, using
/// \p path as the cast path between the two classes.
/// \p nullCheckValue presumably indicates whether \p baseAddr may be null and
/// must be handled as such — confirm against the definition.
Address getAddressOfDerivedClass(
820+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
821+
llvm::iterator_range<CastExpr::path_const_iterator> path,
822+
bool nullCheckValue);
823+
819824
/// Return the VTT parameter that should be passed to a base
820825
/// constructor/destructor with virtual bases.
821826
/// FIXME: VTTs are Itanium ABI-specific, so the definition should move

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,41 @@ mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
13361336
return mlir::success();
13371337
}
13381338

1339+
/// Lower `cir::DerivedClassAddrOp` to the LLVM dialect.
///
/// - Offset of zero: the op is replaced by its (converted) base address.
/// - `assume_not_null` set: the op becomes a single inbounds GEP over `i8`
///   with the negated offset.
/// - Otherwise: the GEP is emitted unconditionally and a select picks the
///   original base address when it compares equal to the zero pointer.
///
/// Always returns success.
mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
1340+
cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor,
1341+
mlir::ConversionPatternRewriter &rewriter) const {
1342+
const mlir::Type resultType =
1343+
getTypeConverter()->convertType(derivedClassOp.getType());
1344+
mlir::Value baseAddr = adaptor.getBaseAddr();
1345+
// The offset is set in the operation as an unsigned value, but it must be
1346+
// applied as a negative offset.
1347+
// NOTE(review): the negation is applied before the conversion to int64_t;
// presumably getZExtValue() is unsigned, so this relies on wrap-around —
// confirm offsets always fit in the signed range.
int64_t offsetVal = -(adaptor.getOffset().getZExtValue());
1348+
if (offsetVal == 0) {
1349+
// If the offset is zero, we can just return the base address,
1350+
rewriter.replaceOp(derivedClassOp, baseAddr);
1351+
return mlir::success();
1352+
}
1353+
llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {offsetVal};
1354+
// Index in units of 8-bit integers so the GEP index is a byte offset.
mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
1355+
mlir::IntegerType::Signless);
1356+
if (derivedClassOp.getAssumeNotNull()) {
1357+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1358+
derivedClassOp, resultType, byteType, baseAddr, offset,
1359+
mlir::LLVM::GEPNoWrapFlags::inbounds);
1360+
} else {
1361+
// The operand may be null: compare against the zero pointer, compute the
// adjusted pointer anyway, and select the unadjusted base address when the
// comparison is true.
mlir::Location loc = derivedClassOp.getLoc();
1362+
mlir::Value isNull = mlir::LLVM::ICmpOp::create(
1363+
rewriter, loc, mlir::LLVM::ICmpPredicate::eq, baseAddr,
1364+
mlir::LLVM::ZeroOp::create(rewriter, loc, baseAddr.getType()));
1365+
mlir::Value adjusted =
1366+
mlir::LLVM::GEPOp::create(rewriter, loc, resultType, byteType, baseAddr,
1367+
offset, mlir::LLVM::GEPNoWrapFlags::inbounds);
1368+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(derivedClassOp, isNull,
1369+
baseAddr, adjusted);
1370+
}
1371+
return mlir::success();
1372+
}
1373+
13391374
mlir::LogicalResult CIRToLLVMATanOpLowering::matchAndRewrite(
13401375
cir::ATanOp op, OpAdaptor adaptor,
13411376
mlir::ConversionPatternRewriter &rewriter) const {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
7+
8+
class A {
9+
int a;
10+
};
11+
12+
class B {
13+
int b;
14+
public:
15+
A *getAsA();
16+
};
17+
18+
class X : public A, public B {
19+
int x;
20+
};
21+
22+
X *castAtoX(A *a) {
23+
return static_cast<X*>(a);
24+
}
25+
26+
// CIR: cir.func {{.*}} @_Z8castAtoXP1A(%[[ARG0:.*]]: !cir.ptr<!rec_A> {{.*}})
27+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>, ["a", init]
28+
// CIR: cir.store %[[ARG0]], %[[A_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
29+
// CIR: %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
30+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[A]] : !cir.ptr<!rec_A> [0] -> !cir.ptr<!rec_X>
31+
32+
// Note: Because the offset is 0, a null check is not needed.
33+
34+
// LLVM: define {{.*}} ptr @_Z8castAtoXP1A(ptr %[[ARG0:.*]])
35+
// LLVM: %[[A_ADDR:.*]] = alloca ptr
36+
// LLVM: store ptr %[[ARG0]], ptr %[[A_ADDR]]
37+
// LLVM: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
38+
39+
// OGCG: define {{.*}} ptr @_Z8castAtoXP1A(ptr {{.*}} %[[ARG0:.*]])
40+
// OGCG: %[[A_ADDR:.*]] = alloca ptr
41+
// OGCG: store ptr %[[ARG0]], ptr %[[A_ADDR]]
42+
// OGCG: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
43+
44+
X *castBtoX(B *b) {
45+
return static_cast<X*>(b);
46+
}
47+
48+
// CIR: cir.func {{.*}} @_Z8castBtoXP1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
49+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init]
50+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
51+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
52+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> [4] -> !cir.ptr<!rec_X>
53+
54+
// LLVM: define {{.*}} ptr @_Z8castBtoXP1B(ptr %[[ARG0:.*]])
55+
// LLVM: %[[B_ADDR:.*]] = alloca ptr, i64 1, align 8
56+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]], align 8
57+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]], align 8
58+
// LLVM: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
59+
// LLVM: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
60+
// LLVM: %[[X:.*]] = select i1 %[[IS_NULL]], ptr %[[B]], ptr %[[B_NON_NULL]]
61+
62+
// OGCG: define {{.*}} ptr @_Z8castBtoXP1B(ptr {{.*}} %[[ARG0:.*]])
63+
// OGCG: entry:
64+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
65+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
66+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
67+
// OGCG: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
68+
// OGCG: br i1 %[[IS_NULL]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
69+
// OGCG: [[LABEL_NOTNULL]]:
70+
// OGCG: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4
71+
// OGCG: br label %[[LABEL_END:.*]]
72+
// OGCG: [[LABEL_NULL]]:
73+
// OGCG: br label %[[LABEL_END:.*]]
74+
// OGCG: [[LABEL_END]]:
75+
// OGCG: %[[X:.*]] = phi ptr [ %[[B_NON_NULL]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
76+
77+
X &castBReftoXRef(B &b) {
78+
return static_cast<X&>(b);
79+
}
80+
81+
// CIR: cir.func {{.*}} @_Z14castBReftoXRefR1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
82+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init, const]
83+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
84+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
85+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> nonnull [4] -> !cir.ptr<!rec_X>
86+
87+
// LLVM: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr %[[ARG0:.*]])
88+
// LLVM: %[[B_ADDR:.*]] = alloca ptr
89+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]]
90+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
91+
// LLVM: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
92+
93+
// OGCG: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr {{.*}} %[[ARG0:.*]])
94+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
95+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
96+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
97+
// OGCG: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4

0 commit comments

Comments
 (0)