Skip to content

Commit 0c84643

Browse files
authored
[CIR] Upstream handling for BaseToDerived casts (#167769)
Upstream handling for BaseToDerived casts, adding the cir.base_class_addr operation and lowering to LLVM IR.
1 parent 6f5c8fe commit 0c84643

File tree

8 files changed

+254
-4
lines changed

8 files changed

+254
-4
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3386,6 +3386,10 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
33863386
cannot be known by the operation, and that information affects how the
33873387
operation is lowered.
33883388

3389+
The validity of the relationship of derived and base cannot yet be verified.
3390+
If the target class is not a valid base class for the object, the behavior
3391+
is undefined.
3392+
33893393
Example:
33903394
```c++
33913395
struct Base { };
@@ -3399,8 +3403,6 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
33993403
```
34003404
}];
34013405

3402-
// The validity of the relationship of derived and base cannot yet be
3403-
// verified, currently not worth adding a verifier.
34043406
let arguments = (ins
34053407
Arg<CIR_PointerType, "derived class pointer", [MemRead]>:$derived_addr,
34063408
IndexAttr:$offset, UnitAttr:$assume_not_null);
@@ -3414,6 +3416,56 @@ def CIR_BaseClassAddrOp : CIR_Op<"base_class_addr"> {
34143416
}];
34153417
}
34163418

3419+
//===----------------------------------------------------------------------===//
3420+
// DerivedClassAddrOp
3421+
//===----------------------------------------------------------------------===//
3422+
3423+
def CIR_DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
3424+
let summary = "Get the derived class address for a class/struct";
3425+
let description = [{
3426+
The `cir.derived_class_addr` operaration gets the address of a particular
3427+
derived class given a non-virtual base class pointer. The offset in bytes
3428+
of the base class must be passed in, similar to `cir.base_class_addr`, but
3429+
going into the other direction. This means lowering to a negative offset.
3430+
3431+
The operation contains a flag for whether or not the operand may be nullptr.
3432+
That depends on the context and cannot be known by the operation, and that
3433+
information affects how the operation is lowered.
3434+
3435+
The validity of the relationship of derived and base cannot yet be verified.
3436+
If the target class is not a valid derived class for the object, the
3437+
behavior is undefined.
3438+
3439+
Example:
3440+
```c++
3441+
class A {};
3442+
class B : public A {};
3443+
3444+
B *getAsB(A *a) {
3445+
return static_cast<B*>(a);
3446+
}
3447+
```
3448+
3449+
leads to
3450+
```mlir
3451+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
3452+
%3 = cir.base_class_addr %2 : !cir.ptr<!rec_B> [0] -> !cir.ptr<!rec_A>
3453+
```
3454+
}];
3455+
3456+
let arguments = (ins
3457+
Arg<CIR_PointerType, "base class pointer", [MemRead]>:$base_addr,
3458+
IndexAttr:$offset, UnitAttr:$assume_not_null);
3459+
3460+
let results = (outs Res<CIR_PointerType, "">:$derived_addr);
3461+
3462+
let assemblyFormat = [{
3463+
$base_addr `:` qualified(type($base_addr))
3464+
(`nonnull` $assume_not_null^)?
3465+
` ` `[` $offset `]` `->` qualified(type($derived_addr)) attr-dict
3466+
}];
3467+
}
3468+
34173469
//===----------------------------------------------------------------------===//
34183470
// ComplexCreateOp
34193471
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,19 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
405405
return Address(baseAddr, destType, addr.getAlignment());
406406
}
407407

408+
Address createDerivedClassAddr(mlir::Location loc, Address addr,
409+
mlir::Type destType, unsigned offset,
410+
bool assumeNotNull) {
411+
if (destType == addr.getElementType())
412+
return addr;
413+
414+
cir::PointerType ptrTy = getPointerTo(destType);
415+
auto derivedAddr =
416+
cir::DerivedClassAddrOp::create(*this, loc, ptrTy, addr.getPointer(),
417+
mlir::APInt(64, offset), assumeNotNull);
418+
return Address(derivedAddr, destType, addr.getAlignment());
419+
}
420+
408421
mlir::Value createVTTAddrPoint(mlir::Location loc, mlir::Type retTy,
409422
mlir::Value addr, uint64_t offset) {
410423
return cir::VTTAddrPointOp::create(*this, loc, retTy,

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,25 @@ mlir::Value CIRGenFunction::getVTTParameter(GlobalDecl gd, bool forVirtualBase,
11101110
}
11111111
}
11121112

1113+
Address CIRGenFunction::getAddressOfDerivedClass(
1114+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
1115+
llvm::iterator_range<CastExpr::path_const_iterator> path,
1116+
bool nullCheckValue) {
1117+
assert(!path.empty() && "Base path should not be empty!");
1118+
1119+
QualType derivedTy = getContext().getCanonicalTagType(derived);
1120+
mlir::Type derivedValueTy = convertType(derivedTy);
1121+
CharUnits nonVirtualOffset =
1122+
cgm.computeNonVirtualBaseClassOffset(derived, path);
1123+
1124+
// Note that in OG, no offset (nonVirtualOffset.getQuantity() == 0) means it
1125+
// just gives the address back. In CIR a `cir.derived_class` is created and
1126+
// made into a nop later on during lowering.
1127+
return builder.createDerivedClassAddr(loc, baseAddr, derivedValueTy,
1128+
nonVirtualOffset.getQuantity(),
1129+
/*assumeNotNull=*/!nullCheckValue);
1130+
}
1131+
11131132
Address CIRGenFunction::getAddressOfBaseClass(
11141133
Address value, const CXXRecordDecl *derived,
11151134
llvm::iterator_range<CastExpr::path_const_iterator> path,

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,6 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
13011301
case CK_NonAtomicToAtomic:
13021302
case CK_AtomicToNonAtomic:
13031303
case CK_ToUnion:
1304-
case CK_BaseToDerived:
13051304
case CK_ObjCObjectLValueCast:
13061305
case CK_VectorSplat:
13071306
case CK_ConstructorConversion:
@@ -1336,6 +1335,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
13361335
lv.getAddress().getAlignment()),
13371336
e->getType(), lv.getBaseInfo());
13381337
}
1338+
13391339
case CK_LValueBitCast: {
13401340
// This must be a reinterpret_cast (or c-style equivalent).
13411341
const auto *ce = cast<ExplicitCastExpr>(e);
@@ -1387,6 +1387,22 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
13871387
return makeAddrLValue(baseAddr, e->getType(), lv.getBaseInfo());
13881388
}
13891389

1390+
case CK_BaseToDerived: {
1391+
const auto *derivedClassDecl = e->getType()->castAsCXXRecordDecl();
1392+
LValue lv = emitLValue(e->getSubExpr());
1393+
1394+
// Perform the base-to-derived conversion
1395+
Address derived = getAddressOfDerivedClass(
1396+
getLoc(e->getSourceRange()), lv.getAddress(), derivedClassDecl,
1397+
e->path(), /*NullCheckValue=*/false);
1398+
// C++11 [expr.static.cast]p2: Behavior is undefined if a downcast is
1399+
// performed and the object is not of the derived type.
1400+
assert(!cir::MissingFeatures::sanitizers());
1401+
1402+
assert(!cir::MissingFeatures::opTBAA());
1403+
return makeAddrLValue(derived, e->getType(), lv.getBaseInfo());
1404+
}
1405+
13901406
case CK_ZeroToOCLOpaqueType:
13911407
llvm_unreachable("NULL to OpenCL opaque type lvalue cast is not valid");
13921408
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,14 +1972,27 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
19721972
return builder.createIntToPtr(middleVal, destCIRTy);
19731973
}
19741974

1975+
case CK_BaseToDerived: {
1976+
const CXXRecordDecl *derivedClassDecl = destTy->getPointeeCXXRecordDecl();
1977+
assert(derivedClassDecl && "BaseToDerived arg isn't a C++ object pointer!");
1978+
Address base = cgf.emitPointerWithAlignment(subExpr);
1979+
Address derived = cgf.getAddressOfDerivedClass(
1980+
cgf.getLoc(ce->getSourceRange()), base, derivedClassDecl, ce->path(),
1981+
cgf.shouldNullCheckClassCastValue(ce));
1982+
1983+
// C++11 [expr.static.cast]p11: Behavior is undefined if a downcast is
1984+
// performed and the object is not of the derived type.
1985+
assert(!cir::MissingFeatures::sanitizers());
1986+
1987+
return cgf.getAsNaturalPointerTo(derived, ce->getType()->getPointeeType());
1988+
}
19751989
case CK_UncheckedDerivedToBase:
19761990
case CK_DerivedToBase: {
19771991
// The EmitPointerWithAlignment path does this fine; just discard
19781992
// the alignment.
19791993
return cgf.getAsNaturalPointerTo(cgf.emitPointerWithAlignment(ce),
19801994
ce->getType()->getPointeeType());
19811995
}
1982-
19831996
case CK_Dynamic: {
19841997
Address v = cgf.emitPointerWithAlignment(subExpr);
19851998
const auto *dce = cast<CXXDynamicCastExpr>(ce);

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,11 @@ class CIRGenFunction : public CIRGenTypeCache {
823823
llvm::iterator_range<CastExpr::path_const_iterator> path,
824824
bool nullCheckValue, SourceLocation loc);
825825

826+
Address getAddressOfDerivedClass(
827+
mlir::Location loc, Address baseAddr, const CXXRecordDecl *derived,
828+
llvm::iterator_range<CastExpr::path_const_iterator> path,
829+
bool nullCheckValue);
830+
826831
/// Return the VTT parameter that should be passed to a base
827832
/// constructor/destructor with virtual bases.
828833
/// FIXME: VTTs are Itanium ABI-specific, so the definition should move

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,41 @@ mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
13601360
return mlir::success();
13611361
}
13621362

1363+
mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
1364+
cir::DerivedClassAddrOp derivedClassOp, OpAdaptor adaptor,
1365+
mlir::ConversionPatternRewriter &rewriter) const {
1366+
const mlir::Type resultType =
1367+
getTypeConverter()->convertType(derivedClassOp.getType());
1368+
mlir::Value baseAddr = adaptor.getBaseAddr();
1369+
// The offset is set in the operation as an unsigned value, but it must be
1370+
// applied as a negative offset.
1371+
int64_t offsetVal = -(adaptor.getOffset().getZExtValue());
1372+
if (offsetVal == 0) {
1373+
// If the offset is zero, we can just return the base address,
1374+
rewriter.replaceOp(derivedClassOp, baseAddr);
1375+
return mlir::success();
1376+
}
1377+
llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {offsetVal};
1378+
mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
1379+
mlir::IntegerType::Signless);
1380+
if (derivedClassOp.getAssumeNotNull()) {
1381+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
1382+
derivedClassOp, resultType, byteType, baseAddr, offset,
1383+
mlir::LLVM::GEPNoWrapFlags::inbounds);
1384+
} else {
1385+
mlir::Location loc = derivedClassOp.getLoc();
1386+
mlir::Value isNull = mlir::LLVM::ICmpOp::create(
1387+
rewriter, loc, mlir::LLVM::ICmpPredicate::eq, baseAddr,
1388+
mlir::LLVM::ZeroOp::create(rewriter, loc, baseAddr.getType()));
1389+
mlir::Value adjusted =
1390+
mlir::LLVM::GEPOp::create(rewriter, loc, resultType, byteType, baseAddr,
1391+
offset, mlir::LLVM::GEPNoWrapFlags::inbounds);
1392+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(derivedClassOp, isNull,
1393+
baseAddr, adjusted);
1394+
}
1395+
return mlir::success();
1396+
}
1397+
13631398
mlir::LogicalResult CIRToLLVMATanOpLowering::matchAndRewrite(
13641399
cir::ATanOp op, OpAdaptor adaptor,
13651400
mlir::ConversionPatternRewriter &rewriter) const {
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-android21 -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
7+
8+
class A {
9+
int a;
10+
};
11+
12+
class B {
13+
int b;
14+
public:
15+
A *getAsA();
16+
};
17+
18+
class X : public A, public B {
19+
int x;
20+
};
21+
22+
X *castAtoX(A *a) {
23+
return static_cast<X*>(a);
24+
}
25+
26+
// CIR: cir.func {{.*}} @_Z8castAtoXP1A(%[[ARG0:.*]]: !cir.ptr<!rec_A> {{.*}})
27+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>, ["a", init]
28+
// CIR: cir.store %[[ARG0]], %[[A_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
29+
// CIR: %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
30+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[A]] : !cir.ptr<!rec_A> [0] -> !cir.ptr<!rec_X>
31+
32+
// Note: Because the offset is 0, a null check is not needed.
33+
34+
// LLVM: define {{.*}} ptr @_Z8castAtoXP1A(ptr %[[ARG0:.*]])
35+
// LLVM: %[[A_ADDR:.*]] = alloca ptr
36+
// LLVM: store ptr %[[ARG0]], ptr %[[A_ADDR]]
37+
// LLVM: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
38+
39+
// OGCG: define {{.*}} ptr @_Z8castAtoXP1A(ptr {{.*}} %[[ARG0:.*]])
40+
// OGCG: %[[A_ADDR:.*]] = alloca ptr
41+
// OGCG: store ptr %[[ARG0]], ptr %[[A_ADDR]]
42+
// OGCG: %[[X:.*]] = load ptr, ptr %[[A_ADDR]]
43+
44+
X *castBtoX(B *b) {
45+
return static_cast<X*>(b);
46+
}
47+
48+
// CIR: cir.func {{.*}} @_Z8castBtoXP1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
49+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init]
50+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
51+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
52+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> [4] -> !cir.ptr<!rec_X>
53+
54+
// LLVM: define {{.*}} ptr @_Z8castBtoXP1B(ptr %[[ARG0:.*]])
55+
// LLVM: %[[B_ADDR:.*]] = alloca ptr, i64 1, align 8
56+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]], align 8
57+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]], align 8
58+
// LLVM: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
59+
// LLVM: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
60+
// LLVM: %[[X:.*]] = select i1 %[[IS_NULL]], ptr %[[B]], ptr %[[B_NON_NULL]]
61+
62+
// OGCG: define {{.*}} ptr @_Z8castBtoXP1B(ptr {{.*}} %[[ARG0:.*]])
63+
// OGCG: entry:
64+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
65+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
66+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
67+
// OGCG: %[[IS_NULL:.*]] = icmp eq ptr %[[B]], null
68+
// OGCG: br i1 %[[IS_NULL]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
69+
// OGCG: [[LABEL_NOTNULL]]:
70+
// OGCG: %[[B_NON_NULL:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4
71+
// OGCG: br label %[[LABEL_END:.*]]
72+
// OGCG: [[LABEL_NULL]]:
73+
// OGCG: br label %[[LABEL_END:.*]]
74+
// OGCG: [[LABEL_END]]:
75+
// OGCG: %[[X:.*]] = phi ptr [ %[[B_NON_NULL]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
76+
77+
X &castBReftoXRef(B &b) {
78+
return static_cast<X&>(b);
79+
}
80+
81+
// CIR: cir.func {{.*}} @_Z14castBReftoXRefR1B(%[[ARG0:.*]]: !cir.ptr<!rec_B> {{.*}})
82+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>, ["b", init, const]
83+
// CIR: cir.store %[[ARG0]], %[[B_ADDR]] : !cir.ptr<!rec_B>, !cir.ptr<!cir.ptr<!rec_B>>
84+
// CIR: %[[B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.ptr<!rec_B>>, !cir.ptr<!rec_B>
85+
// CIR: %[[X:.*]] = cir.derived_class_addr %[[B]] : !cir.ptr<!rec_B> nonnull [4] -> !cir.ptr<!rec_X>
86+
87+
// LLVM: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr %[[ARG0:.*]])
88+
// LLVM: %[[B_ADDR:.*]] = alloca ptr
89+
// LLVM: store ptr %[[ARG0]], ptr %[[B_ADDR]]
90+
// LLVM: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
91+
// LLVM: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i32 -4
92+
93+
// OGCG: define {{.*}} ptr @_Z14castBReftoXRefR1B(ptr {{.*}} %[[ARG0:.*]])
94+
// OGCG: %[[B_ADDR:.*]] = alloca ptr
95+
// OGCG: store ptr %[[ARG0]], ptr %[[B_ADDR]]
96+
// OGCG: %[[B:.*]] = load ptr, ptr %[[B_ADDR]]
97+
// OGCG: %[[X:.*]] = getelementptr inbounds i8, ptr %[[B]], i64 -4

0 commit comments

Comments
 (0)