Skip to content

Commit 2f9fe2a

Browse files
committed
[CIR] Upstream handling for BaseToDerived casts
Upstream handling for BaseToDerived casts, adding the cir.base_class_addr operation and lowering to LLVM IR.
1 parent 518b38c commit 2f9fe2a

File tree

8 files changed

+248
-2
lines changed

8 files changed

+248
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,54 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
34143414
}];
34153415
}
34163416

3417+
//===----------------------------------------------------------------------===//
3418+
// DerivedClassAddrOp
3419+
//===----------------------------------------------------------------------===//
3420+
3421+
def CIR_DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
3422+
let summary = "Get the derived class address for a class/struct";
3423+
let description = [{
3424+
The `cir.derived_class_addr` operaration gets the address of a particular
3425+
derived class given a non-virtual base class pointer. The offset in bytes
3426+
of the base class must be passed in, similar to `cir.base_class_addr`, but
3427+
going into the other direction. This means lowering to a negative offset.
3428+
3429+
The operation contains a flag for whether or not the operand may be nullptr.
3430+
That depends on the context and cannot be known by the operation, and that
3431+
information affects how the operation is lowered.
3432+
3433+
Example:
3434+
```c++
3435+
class A {};
3436+
class B : public A {};
3437+
3438+
B *getAsB(A *a) {
3439+
return static_cast<B*>(a);
3440+
}
3441+
```
3442+
3443+
leads to
3444+
```mlir
3445+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
3446+
%3 = cir.base_class_addr %2 : !cir.ptr<!rec_B> [0] -> !cir.ptr<!rec_A>
3447+
```
3448+
}];
3449+
3450+
// The validity of the relationship of derived and base cannot yet be
3451+
// verified, currently not worth adding a verifier.
3452+
let arguments = (ins
3453+
Arg<CIR_PointerType, "base class pointer", [MemRead]>:$base_addr,
3454+
IndexAttr:$offset, UnitAttr:$assume_not_null);
3455+
3456+
let results = (outs Res<CIR_PointerType, "">:$derived_addr);
3457+
3458+
let assemblyFormat = [{
3459+
$base_addr `:` qualified(type($base_addr))
3460+
(`nonnull` $assume_not_null^)?
3461+
` ` `[` $offset `]` `->` qualified(type($derived_addr)) attr-dict
3462+
}];
3463+
}
3464+
34173465
//===----------------------------------------------------------------------===//
34183466
// ComplexCreateOp
34193467
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
403403
return Address(baseAddr, destType, addr.getAlignment());
404404
}
405405

406+
Address createDerivedClassAddr(mlir::Location loc, Address addr,
407+
mlir::Type destType, unsigned offset,
408+
bool assumeNotNull) {
409+
if (destType == addr.getElementType())
410+
return addr;
411+
412+
cir::PointerType ptrTy = getPointerTo(destType);
413+
auto derivedAddr =
414+
cir::DerivedClassAddrOp::create(*this, loc, ptrTy, addr.getPointer(),
415+
mlir::APInt(64, offset), assumeNotNull);
416+
return Address(derivedAddr, destType, addr.getAlignment());
417+
}
418+
406419
mlir::Value createVTTAddrPoint(mlir::Location loc, mlir::Type retTy,
407420
mlir::Value addr, uint64_t offset) {
408421
return cir::VTTAddrPointOp::create(*this, loc, retTy,

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,25 @@ mlir::Value CIRGenFunction::getVTTParameter(GlobalDecl gd, bool forVirtualBase,
11101110
}
11111111
}
11121112

1113+
Address CIRGenFunction::getAddressOfDerivedClass(
1114+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
1115+
llvm::iterator_range<CastExpr::path_const_iterator> path,
1116+
bool nullCheckValue) {
1117+
assert(!path.empty() && "Base path should not be empty!");
1118+
1119+
QualType derivedTy = getContext().getCanonicalTagType(derived);
1120+
mlir::Type derivedValueTy = convertType(derivedTy);
1121+
CharUnits nonVirtualOffset =
1122+
cgm.computeNonVirtualBaseClassOffset(derived, path);
1123+
1124+
// Note that in OG, no offset (nonVirtualOffset.getQuantity() == 0) means it
1125+
// just gives the address back. In CIR a `cir.derived_class` is created and
1126+
// made into a nop later on during lowering.
1127+
return builder.createDerivedClassAddr(loc, baseAddr, derivedValueTy,
1128+
nonVirtualOffset.getQuantity(),
1129+
/*assumeNotNull=*/!nullCheckValue);
1130+
}
1131+
11131132
Address CIRGenFunction::getAddressOfBaseClass(
11141133
Address value, const CXXRecordDecl *derived,
11151134
llvm::iterator_range<CastExpr::path_const_iterator> path,

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,6 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
12871287
case CK_NonAtomicToAtomic:
12881288
case CK_AtomicToNonAtomic:
12891289
case CK_ToUnion:
1290-
case CK_BaseToDerived:
12911290
case CK_ObjCObjectLValueCast:
12921291
case CK_VectorSplat:
12931292
case CK_ConstructorConversion:
@@ -1322,6 +1321,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
13221321
lv.getAddress().getAlignment()),
13231322
e->getType(), lv.getBaseInfo());
13241323
}
1324+
13251325
case CK_LValueBitCast: {
13261326
// This must be a reinterpret_cast (or c-style equivalent).
13271327
const auto *ce = cast<ExplicitCastExpr>(e);
@@ -1373,6 +1373,22 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
13731373
return makeAddrLValue(baseAddr, e->getType(), lv.getBaseInfo());
13741374
}
13751375

1376+
case CK_BaseToDerived: {
1377+
const auto *derivedClassDecl = e->getType()->castAsCXXRecordDecl();
1378+
LValue lv = emitLValue(e->getSubExpr());
1379+
1380+
// Perform the base-to-derived conversion
1381+
Address derived = getAddressOfDerivedClass(
1382+
getLoc(e->getSourceRange()), lv.getAddress(), derivedClassDecl,
1383+
e->path(), /*NullCheckValue=*/false);
1384+
// C++11 [expr.static.cast]p2: Behavior is undefined if a downcast is
1385+
// performed and the object is not of the derived type.
1386+
assert(!cir::MissingFeatures::sanitizers());
1387+
1388+
assert(!cir::MissingFeatures::opTBAA());
1389+
return makeAddrLValue(derived, e->getType(), lv.getBaseInfo());
1390+
}
1391+
13761392
case CK_ZeroToOCLOpaqueType:
13771393
llvm_unreachable("NULL to OpenCL opaque type lvalue cast is not valid");
13781394
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1992,14 +1992,27 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
19921992
return builder.createIntToPtr(middleVal, destCIRTy);
19931993
}
19941994

1995+
case CK_BaseToDerived: {
1996+
const CXXRecordDecl *derivedClassDecl = destTy->getPointeeCXXRecordDecl();
1997+
assert(derivedClassDecl && "BaseToDerived arg isn't a C++ object pointer!");
1998+
Address base = cgf.emitPointerWithAlignment(subExpr);
1999+
Address derived = cgf.getAddressOfDerivedClass(
2000+
cgf.getLoc(ce->getSourceRange()), base, derivedClassDecl, ce->path(),
2001+
cgf.shouldNullCheckClassCastValue(ce));
2002+
2003+
// C++11 [expr.static.cast]p11: Behavior is undefined if a downcast is
2004+
// performed and the object is not of the derived type.
2005+
assert(!cir::MissingFeatures::sanitizers());
2006+
2007+
return cgf.getAsNaturalPointerTo(derived, ce->getType()->getPointeeType());
2008+
}
19952009
case CK_UncheckedDerivedToBase:
19962010
case CK_DerivedToBase: {
19972011
// The EmitPointerWithAlignment path does this fine; just discard
19982012
// the alignment.
19992013
return cgf.getAsNaturalPointerTo(cgf.emitPointerWithAlignment(ce),
20002014
ce->getType()->getPointeeType());
20012015
}
2002-
20032016
case CK_Dynamic: {
20042017
Address v = cgf.emitPointerWithAlignment(subExpr);
20052018
const auto *dce = cast<CXXDynamicCastExpr>(ce);

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,11 @@ class CIRGenFunction : public CIRGenTypeCache {
823823
llvm::iterator_range<CastExpr::path_const_iterator> path,
824824
bool nullCheckValue, SourceLocation loc);
825825

826+
Address getAddressOfDerivedClass(
827+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
828+
llvm::iterator_range<CastExpr::path_const_iterator> path,
829+
bool nullCheckValue);
830+
826831
/// Return the VTT parameter that should be passed to a base
827832
/// constructor/destructor with virtual bases.
828833
/// FIXME: VTTs are Itanium ABI-specific, so the definition should move

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,41 @@ mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
13601360
return mlir::success();
13611361
}
13621362

1363+
mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
1364+
cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor,
1365+
mlir::ConversionPatternRewriter &rewriter) const {
1366+
const mlir::Type resultType =
1367+
getTypeConverter()->convertType(derivedClassOp.getType());
1368+
mlir::Value baseAddr = adaptor.getBaseAddr();
1369+
// The offset is set in the operation as an unsigned value, but it must be
1370+
// applied as a negative offset.
1371+
int64_t offsetVal = -(adaptor.getOffset().getZExtValue());
1372+
if (offsetVal == 0) {
1373+
// If the offset is zero, we can just return the base address,
1374+
rewriter.replaceOp(derivedClassOp, baseAddr);
1375+
return mlir::success();
1376+
}
1377+
llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {offsetVal};
1378+
mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
1379+
mlir::IntegerType::Signless);
1380+
if (derivedClassOp.getAssumeNotNull()) {
1381+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1382+
derivedClassOp, resultType, byteType, baseAddr, offset,
1383+
mlir::LLVM::GEPNoWrapFlags::inbounds);
1384+
} else {
1385+
mlir::Location loc = derivedClassOp.getLoc();
1386+
mlir::Value isNull = mlir::LLVM::ICmpOp::create(
1387+
rewriter, loc, mlir::LLVM::ICmpPredicate::eq, baseAddr,
1388+
mlir::LLVM::ZeroOp::create(rewriter, loc, baseAddr.getType()));
1389+
mlir::Value adjusted =
1390+
mlir::LLVM::GEPOp::create(rewriter, loc, resultType, byteType, baseAddr,
1391+
offset, mlir::LLVM::GEPNoWrapFlags::inbounds);
1392+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(derivedClassOp, isNull,
1393+
baseAddr, adjusted);
1394+
}
1395+
return mlir::success();
1396+
}
1397+
13631398
mlir::LogicalResult CIRToLLVMATanOpLowering::matchAndRewrite(
13641399
cir::ATanOp op, OpAdaptor adaptor,
13651400
mlir::ConversionPatternRewriter &rewriter) const {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
7+
8+
class A {
9+
int a;
10+
};
11+
12+
class B {
13+
int b;
14+
public:
15+
A *getAsA();
16+
};
17+
18+
class X : public A, public B {
19+
int x;
20+
};
21+
22+
X *castAtoX(A *a) {
23+
return static_cast<X*>(a);
24+
}
25+
26+
// CIR: cir.func {{.*}} @_Z8castAtoXP1A(%[[ARG0:.*]]: !cir.ptr<!rec_A> {{.*}})
27+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>, ["a", init]
28+
// CIR: cir.store %[[ARG0]], %[[A_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
29+
// CIR: %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
30+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[A]] : !cir.ptr<!rec_A> [0] -> !cir.ptr<!rec_X>
31+
32+
// Note: Because the offset is 0, a null check is not needed.
33+
34+
// LLVM: define {{.*}} ptr @_Z8castAtoXP1A(ptr %[[ARG0:.*]])
35+
// LLVM: %[[A_ADDR:.*]] = alloca ptr
36+
// LLVM: store ptr %[[ARG0]], ptr %[[A_ADDR]]
37+
// LLVM: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
38+
39+
// OGCG: define {{.*}} ptr @_Z8castAtoXP1A(ptr {{.*}} %[[ARG0:.*]])
40+
// OGCG: %[[A_ADDR:.*]] = alloca ptr
41+
// OGCG: store ptr %[[ARG0]], ptr %[[A_ADDR]]
42+
// OGCG: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
43+
44+
X *castBtoX(B *b) {
45+
return static_cast<X*>(b);
46+
}
47+
48+
// CIR: cir.func {{.*}} @_Z8castBtoXP1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
49+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init]
50+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
51+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
52+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> [4] -> !cir.ptr<!rec_X>
53+
54+
// LLVM: define {{.*}} ptr @_Z8castBtoXP1B(ptr %[[ARG0:.*]])
55+
// LLVM: %[[B_ADDR:.*]] = alloca ptr, i64 1, align 8
56+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]], align 8
57+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]], align 8
58+
// LLVM: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
59+
// LLVM: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
60+
// LLVM: %[[X:.*]] = select i1 %[[IS_NULL]], ptr %[[B]], ptr %[[B_NON_NULL]]
61+
62+
// OGCG: define {{.*}} ptr @_Z8castBtoXP1B(ptr {{.*}} %[[ARG0:.*]])
63+
// OGCG: entry:
64+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
65+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
66+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
67+
// OGCG: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
68+
// OGCG: br i1 %[[IS_NULL]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
69+
// OGCG: [[LABEL_NOTNULL]]:
70+
// OGCG: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4
71+
// OGCG: br label %[[LABEL_END:.*]]
72+
// OGCG: [[LABEL_NULL]]:
73+
// OGCG: br label %[[LABEL_END:.*]]
74+
// OGCG: [[LABEL_END]]:
75+
// OGCG: %[[X:.*]] = phi ptr [ %[[B_NON_NULL]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
76+
77+
X &castBReftoXRef(B &b) {
78+
return static_cast<X&>(b);
79+
}
80+
81+
// CIR: cir.func {{.*}} @_Z14castBReftoXRefR1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
82+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init, const]
83+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
84+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
85+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> nonnull [4] -> !cir.ptr<!rec_X>
86+
87+
// LLVM: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr %[[ARG0:.*]])
88+
// LLVM: %[[B_ADDR:.*]] = alloca ptr
89+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]]
90+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
91+
// LLVM: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
92+
93+
// OGCG: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr {{.*}} %[[ARG0:.*]])
94+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
95+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
96+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
97+
// OGCG: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4

0 commit comments

Comments
 (0)