Skip to content

Commit b3f2d93

Browse files
authored
[CIR] Add lowering support for dynamic cast (#162715)
This adds support for lowering cir.dyn_cast operations to a form that can be lowered to LLVM IR.
1 parent 312f1fa commit b3f2d93

File tree

9 files changed

+293
-10
lines changed

9 files changed

+293
-10
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
451451
return createBitcast(src, getPointerTo(newPointeeTy));
452452
}
453453

454+
mlir::Value createPtrIsNull(mlir::Value ptr) {
455+
mlir::Value nullPtr = getNullPtr(ptr.getType(), ptr.getLoc());
456+
return createCompare(ptr.getLoc(), cir::CmpOpKind::eq, ptr, nullPtr);
457+
}
458+
454459
//===--------------------------------------------------------------------===//
455460
// Binary Operators
456461
//===--------------------------------------------------------------------===//
@@ -644,6 +649,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
644649
return getI64IntegerAttr(size.getQuantity());
645650
}
646651

652+
// Creates constant nullptr for pointer type ty.
653+
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
654+
assert(!cir::MissingFeatures::targetCodeGenInfoGetNullPointer());
655+
return cir::ConstantOp::create(*this, loc, getConstPtrAttr(ty, 0));
656+
}
657+
647658
/// Create a loop condition.
648659
cir::ConditionOp createCondition(mlir::Value condition) {
649660
return cir::ConditionOp::create(*this, condition.getLoc(), condition);

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct MissingFeatures {
8585
static bool opFuncReadOnly() { return false; }
8686
static bool opFuncSection() { return false; }
8787
static bool opFuncWillReturn() { return false; }
88+
static bool opFuncNoReturn() { return false; }
8889
static bool setLLVMFunctionFEnvAttributes() { return false; }
8990
static bool setFunctionAttributes() { return false; }
9091

@@ -256,6 +257,8 @@ struct MissingFeatures {
256257
static bool loopInfoStack() { return false; }
257258
static bool lowerAggregateLoadStore() { return false; }
258259
static bool lowerModeOptLevel() { return false; }
260+
static bool loweringPrepareX86CXXABI() { return false; }
261+
static bool loweringPrepareAArch64XXABI() { return false; }
259262
static bool maybeHandleStaticInExternC() { return false; }
260263
static bool mergeAllConstants() { return false; }
261264
static bool metaDataNode() { return false; }

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,6 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
319319
return cir::ConstantOp::create(*this, loc, cir::IntAttr::get(sInt64Ty, c));
320320
}
321321

322-
// Creates constant nullptr for pointer type ty.
323-
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
324-
assert(!cir::MissingFeatures::targetCodeGenInfoGetNullPointer());
325-
return cir::ConstantOp::create(*this, loc, getConstPtrAttr(ty, 0));
326-
}
327-
328322
mlir::Value createNeg(mlir::Value value) {
329323

330324
if (auto intTy = mlir::dyn_cast<cir::IntType>(value.getType())) {

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_clang_library(MLIRCIRTransforms
44
FlattenCFG.cpp
55
HoistAllocas.cpp
66
LoweringPrepare.cpp
7+
LoweringPrepareItaniumCXXABI.cpp
78
GotoSolver.cpp
89

910
DEPENDS

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "LoweringPrepareCXXABI.h"
910
#include "PassDetail.h"
1011
#include "clang/AST/ASTContext.h"
1112
#include "clang/Basic/Module.h"
@@ -62,6 +63,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
6263
void lowerComplexMulOp(cir::ComplexMulOp op);
6364
void lowerUnaryOp(cir::UnaryOp op);
6465
void lowerGlobalOp(cir::GlobalOp op);
66+
void lowerDynamicCastOp(cir::DynamicCastOp op);
6567
void lowerArrayDtor(cir::ArrayDtor op);
6668
void lowerArrayCtor(cir::ArrayCtor op);
6769

@@ -91,6 +93,9 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
9193

9294
clang::ASTContext *astCtx;
9395

96+
// Helper for lowering C++ ABI specific operations.
97+
std::shared_ptr<cir::LoweringPrepareCXXABI> cxxABI;
98+
9499
/// Tracks current module.
95100
mlir::ModuleOp mlirModule;
96101

@@ -101,7 +106,24 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
101106
/// List of ctors and their priorities to be called before main()
102107
llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
103108

104-
void setASTContext(clang::ASTContext *c) { astCtx = c; }
109+
void setASTContext(clang::ASTContext *c) {
110+
astCtx = c;
111+
switch (c->getCXXABIKind()) {
112+
case clang::TargetCXXABI::GenericItanium:
113+
// We'll need X86-specific support for handling vaargs lowering, but for
114+
// now the Itanium ABI will work.
115+
assert(!cir::MissingFeatures::loweringPrepareX86CXXABI());
116+
cxxABI.reset(cir::LoweringPrepareCXXABI::createItaniumABI());
117+
break;
118+
case clang::TargetCXXABI::GenericAArch64:
119+
case clang::TargetCXXABI::AppleARM64:
120+
assert(!cir::MissingFeatures::loweringPrepareAArch64XXABI());
121+
cxxABI.reset(cir::LoweringPrepareCXXABI::createItaniumABI());
122+
break;
123+
default:
124+
llvm_unreachable("NYI");
125+
}
126+
}
105127
};
106128

107129
} // namespace
@@ -850,6 +872,17 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
850872
cir::ReturnOp::create(builder, f.getLoc());
851873
}
852874

875+
void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
876+
CIRBaseBuilderTy builder(getContext());
877+
builder.setInsertionPointAfter(op);
878+
879+
assert(astCtx && "AST context is not available during lowering prepare");
880+
auto loweredValue = cxxABI->lowerDynamicCast(builder, *astCtx, op);
881+
882+
op.replaceAllUsesWith(loweredValue);
883+
op.erase();
884+
}
885+
853886
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder,
854887
clang::ASTContext *astCtx,
855888
mlir::Operation *op, mlir::Type eltTy,
@@ -954,6 +987,8 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
954987
lowerComplexMulOp(complexMul);
955988
else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op))
956989
lowerGlobalOp(glob);
990+
else if (auto dynamicCast = mlir::dyn_cast<cir::DynamicCastOp>(op))
991+
lowerDynamicCastOp(dynamicCast);
957992
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
958993
lowerUnaryOp(unary);
959994
}
@@ -967,8 +1002,8 @@ void LoweringPreparePass::runOnOperation() {
9671002

9681003
op->walk([&](mlir::Operation *op) {
9691004
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
970-
cir::ComplexMulOp, cir::ComplexDivOp, cir::GlobalOp,
971-
cir::UnaryOp>(op))
1005+
cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
1006+
cir::GlobalOp, cir::UnaryOp>(op))
9721007
opsToTransform.push_back(op);
9731008
});
9741009

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides the LoweringPrepareCXXABI class, which is the base class
10+
// for ABI specific functionalities that are required during LLVM lowering
11+
// prepare.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef CIR_DIALECT_TRANSFORMS__LOWERINGPREPARECXXABI_H
16+
#define CIR_DIALECT_TRANSFORMS__LOWERINGPREPARECXXABI_H
17+
18+
#include "mlir/IR/Value.h"
19+
#include "clang/AST/ASTContext.h"
20+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
21+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
22+
23+
namespace cir {
24+
25+
class LoweringPrepareCXXABI {
26+
public:
27+
static LoweringPrepareCXXABI *createItaniumABI();
28+
29+
virtual ~LoweringPrepareCXXABI() {}
30+
31+
virtual mlir::Value lowerDynamicCast(CIRBaseBuilderTy &builder,
32+
clang::ASTContext &astCtx,
33+
cir::DynamicCastOp op) = 0;
34+
};
35+
36+
} // namespace cir
37+
38+
#endif // CIR_DIALECT_TRANSFORMS__LOWERINGPREPARECXXABI_H
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//===--------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with
4+
// LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===--------------------------------------------------------------------===//
9+
//
10+
// This file provides Itanium C++ ABI specific code
11+
// that is used during LLVMIR lowering prepare.
12+
//
13+
//===--------------------------------------------------------------------===//
14+
15+
#include "LoweringPrepareCXXABI.h"
16+
#include "mlir/IR/BuiltinAttributes.h"
17+
#include "mlir/IR/Value.h"
18+
#include "mlir/IR/ValueRange.h"
19+
#include "clang/Basic/TargetInfo.h"
20+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
21+
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
22+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
23+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
24+
#include "clang/CIR/MissingFeatures.h"
25+
26+
class LoweringPrepareItaniumCXXABI : public cir::LoweringPrepareCXXABI {
27+
public:
28+
mlir::Value lowerDynamicCast(cir::CIRBaseBuilderTy &builder,
29+
clang::ASTContext &astCtx,
30+
cir::DynamicCastOp op) override;
31+
};
32+
33+
cir::LoweringPrepareCXXABI *cir::LoweringPrepareCXXABI::createItaniumABI() {
34+
return new LoweringPrepareItaniumCXXABI();
35+
}
36+
37+
static void buildBadCastCall(cir::CIRBaseBuilderTy &builder, mlir::Location loc,
38+
mlir::FlatSymbolRefAttr badCastFuncRef) {
39+
builder.createCallOp(loc, badCastFuncRef, cir::VoidType(),
40+
mlir::ValueRange{});
41+
// TODO(cir): Set the 'noreturn' attribute on the function.
42+
assert(!cir::MissingFeatures::opFuncNoReturn());
43+
cir::UnreachableOp::create(builder, loc);
44+
builder.clearInsertionPoint();
45+
}
46+
47+
static mlir::Value
48+
buildDynamicCastAfterNullCheck(cir::CIRBaseBuilderTy &builder,
49+
cir::DynamicCastOp op) {
50+
mlir::Location loc = op->getLoc();
51+
mlir::Value srcValue = op.getSrc();
52+
cir::DynamicCastInfoAttr castInfo = op.getInfo().value();
53+
54+
// TODO(cir): consider address space
55+
assert(!cir::MissingFeatures::addressSpace());
56+
57+
mlir::Value srcPtr = builder.createBitcast(srcValue, builder.getVoidPtrTy());
58+
cir::ConstantOp srcRtti = builder.getConstant(loc, castInfo.getSrcRtti());
59+
cir::ConstantOp destRtti = builder.getConstant(loc, castInfo.getDestRtti());
60+
cir::ConstantOp offsetHint =
61+
builder.getConstant(loc, castInfo.getOffsetHint());
62+
63+
mlir::FlatSymbolRefAttr dynCastFuncRef = castInfo.getRuntimeFunc();
64+
mlir::Value dynCastFuncArgs[4] = {srcPtr, srcRtti, destRtti, offsetHint};
65+
66+
mlir::Value castedPtr =
67+
builder
68+
.createCallOp(loc, dynCastFuncRef, builder.getVoidPtrTy(),
69+
dynCastFuncArgs)
70+
.getResult();
71+
72+
assert(mlir::isa<cir::PointerType>(castedPtr.getType()) &&
73+
"the return value of __dynamic_cast should be a ptr");
74+
75+
/// C++ [expr.dynamic.cast]p9:
76+
/// A failed cast to reference type throws std::bad_cast
77+
if (op.isRefCast()) {
78+
// Emit a cir.if that checks the casted value.
79+
mlir::Value castedValueIsNull = builder.createPtrIsNull(castedPtr);
80+
builder.create<cir::IfOp>(
81+
loc, castedValueIsNull, false, [&](mlir::OpBuilder &, mlir::Location) {
82+
buildBadCastCall(builder, loc, castInfo.getBadCastFunc());
83+
});
84+
}
85+
86+
// Note that castedPtr is a void*. Cast it to a pointer to the destination
87+
// type before return.
88+
return builder.createBitcast(castedPtr, op.getType());
89+
}
90+
91+
static mlir::Value
92+
buildDynamicCastToVoidAfterNullCheck(cir::CIRBaseBuilderTy &builder,
93+
clang::ASTContext &astCtx,
94+
cir::DynamicCastOp op) {
95+
llvm_unreachable("dynamic cast to void is NYI");
96+
}
97+
98+
mlir::Value
99+
LoweringPrepareItaniumCXXABI::lowerDynamicCast(cir::CIRBaseBuilderTy &builder,
100+
clang::ASTContext &astCtx,
101+
cir::DynamicCastOp op) {
102+
mlir::Location loc = op->getLoc();
103+
mlir::Value srcValue = op.getSrc();
104+
105+
assert(!cir::MissingFeatures::emitTypeCheck());
106+
107+
if (op.isRefCast())
108+
return buildDynamicCastAfterNullCheck(builder, op);
109+
110+
mlir::Value srcValueIsNotNull = builder.createPtrToBoolCast(srcValue);
111+
return builder
112+
.create<cir::TernaryOp>(
113+
loc, srcValueIsNotNull,
114+
[&](mlir::OpBuilder &, mlir::Location) {
115+
mlir::Value castedValue =
116+
op.isCastToVoid()
117+
? buildDynamicCastToVoidAfterNullCheck(builder, astCtx, op)
118+
: buildDynamicCastAfterNullCheck(builder, op);
119+
builder.createYield(loc, castedValue);
120+
},
121+
[&](mlir::OpBuilder &, mlir::Location) {
122+
builder.createYield(
123+
loc, builder.getNullPtr(op.getType(), loc).getResult());
124+
})
125+
.getResult();
126+
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1739,7 +1739,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
17391739
const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
17401740
const StringRef symbol = op.getSymName();
17411741
SmallVector<mlir::NamedAttribute> attributes;
1742-
mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
17431742

17441743
if (init.has_value()) {
17451744
if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) {
@@ -1771,6 +1770,7 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
17711770
}
17721771

17731772
// Rewrite op.
1773+
mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
17741774
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
17751775
op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()),
17761776
alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes);

0 commit comments

Comments
 (0)