Skip to content

Commit 45066c2

Browse files
authored
[CIR] Add lowering for the cir.vtable.address_point operation (#153243)
This adds support for lowering the cir.vtable.address_point operation to the LLVM dialect, as well as type converter support for the cir.vptr type.
1 parent 6961139 commit 45066c2

File tree

3 files changed

+87
-2
lines changed

3 files changed

+87
-2
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2083,6 +2083,10 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
20832083

20842084
return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
20852085
});
2086+
converter.addConversion([&](cir::VPtrType type) -> mlir::Type {
2087+
assert(!cir::MissingFeatures::addressSpace());
2088+
return mlir::LLVM::LLVMPointerType::get(type.getContext());
2089+
});
20862090
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
20872091
mlir::Type ty =
20882092
convertTypeForMemory(converter, dataLayout, type.getElementType());
@@ -2314,6 +2318,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
23142318
CIRToLLVMSwitchFlatOpLowering,
23152319
CIRToLLVMTrapOpLowering,
23162320
CIRToLLVMUnaryOpLowering,
2321+
CIRToLLVMUnreachableOpLowering,
23172322
CIRToLLVMVecCmpOpLowering,
23182323
CIRToLLVMVecCreateOpLowering,
23192324
CIRToLLVMVecExtractOpLowering,
@@ -2322,7 +2327,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
23222327
CIRToLLVMVecShuffleOpLowering,
23232328
CIRToLLVMVecSplatOpLowering,
23242329
CIRToLLVMVecTernaryOpLowering,
2325-
CIRToLLVMUnreachableOpLowering
2330+
CIRToLLVMVTableAddrPointOpLowering
2331+
// clang-format on
23262332
>(converter, patterns.getContext());
23272333

23282334
processCIRAttrs(module);
@@ -2400,6 +2406,51 @@ mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite(
24002406
return mlir::success();
24012407
}
24022408

2409+
static mlir::Value
2410+
getValueForVTableSymbol(mlir::Operation *op,
2411+
mlir::ConversionPatternRewriter &rewriter,
2412+
const mlir::TypeConverter *converter,
2413+
mlir::FlatSymbolRefAttr nameAttr, mlir::Type &eltType) {
2414+
auto module = op->getParentOfType<mlir::ModuleOp>();
2415+
mlir::Operation *symbol = mlir::SymbolTable::lookupSymbolIn(module, nameAttr);
2416+
if (auto llvmSymbol = mlir::dyn_cast<mlir::LLVM::GlobalOp>(symbol)) {
2417+
eltType = llvmSymbol.getType();
2418+
} else if (auto cirSymbol = mlir::dyn_cast<cir::GlobalOp>(symbol)) {
2419+
eltType = converter->convertType(cirSymbol.getSymType());
2420+
} else {
2421+
op->emitError() << "unexpected symbol type for " << symbol;
2422+
return {};
2423+
}
2424+
2425+
return mlir::LLVM::AddressOfOp::create(
2426+
rewriter, op->getLoc(),
2427+
mlir::LLVM::LLVMPointerType::get(op->getContext()), nameAttr.getValue());
2428+
}
2429+
2430+
mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
2431+
cir::VTableAddrPointOp op, OpAdaptor adaptor,
2432+
mlir::ConversionPatternRewriter &rewriter) const {
2433+
const mlir::TypeConverter *converter = getTypeConverter();
2434+
mlir::Type targetType = converter->convertType(op.getType());
2435+
llvm::SmallVector<mlir::LLVM::GEPArg> offsets;
2436+
mlir::Type eltType;
2437+
mlir::Value symAddr = getValueForVTableSymbol(op, rewriter, converter,
2438+
op.getNameAttr(), eltType);
2439+
if (!symAddr)
2440+
return op.emitError() << "Unable to get value for vtable symbol";
2441+
2442+
offsets = llvm::SmallVector<mlir::LLVM::GEPArg>{
2443+
0, op.getAddressPointAttr().getIndex(),
2444+
op.getAddressPointAttr().getOffset()};
2445+
2446+
assert(eltType && "Shouldn't ever be missing an eltType here");
2447+
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2448+
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2449+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, targetType, eltType,
2450+
symAddr, offsets, inboundsNuw);
2451+
return mlir::success();
2452+
}
2453+
24032454
mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
24042455
cir::StackSaveOp op, OpAdaptor adaptor,
24052456
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,16 @@ class CIRToLLVMBaseClassAddrOpLowering
457457
mlir::ConversionPatternRewriter &) const override;
458458
};
459459

460+
class CIRToLLVMVTableAddrPointOpLowering
461+
: public mlir::OpConversionPattern<cir::VTableAddrPointOp> {
462+
public:
463+
using mlir::OpConversionPattern<cir::VTableAddrPointOp>::OpConversionPattern;
464+
465+
mlir::LogicalResult
466+
matchAndRewrite(cir::VTableAddrPointOp op, OpAdaptor,
467+
mlir::ConversionPatternRewriter &) const override;
468+
};
469+
460470
class CIRToLLVMStackSaveOpLowering
461471
: public mlir::OpConversionPattern<cir::StackSaveOp> {
462472
public:

clang/test/CIR/CodeGen/virtual-function-calls.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -mconstructor-aliases -fclangir -emit-cir %s -o %t.cir
22
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -mconstructor-aliases -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -mconstructor-aliases -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
37

48
struct A {
59
A();
@@ -14,11 +18,31 @@ A::A() {}
1418

1519
// CIR: cir.global "private" external @_ZTV1A : !rec_anon_struct
1620

17-
// CIR: cir.func dso_local @_ZN1AC2Ev(%arg0: !cir.ptr<!rec_A> {{.*}})
21+
// LLVM: @_ZTV1A = external global { [3 x ptr] }
22+
23+
// OGCG: @_ZTV1A = external unnamed_addr constant { [3 x ptr] }
24+
25+
// CIR: cir.func{{.*}} @_ZN1AC2Ev(%arg0: !cir.ptr<!rec_A> {{.*}})
1826
// CIR: %[[THIS_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>, ["this", init]
1927
// CIR: cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
2028
// CIR: %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
2129
// CIR: %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr
2230
// CIR: %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr>
2331
// CIR: cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
2432
// CIR: cir.return
33+
34+
// LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]])
35+
// LLVM: %[[THIS_ADDR:.*]] = alloca ptr
36+
// LLVM: store ptr %[[ARG0]], ptr %[[THIS_ADDR]]
37+
// LLVM: %[[THIS:.*]] = load ptr, ptr %[[THIS_ADDR]]
38+
// LLVM: store ptr getelementptr inbounds nuw (i8, ptr @_ZTV1A, i64 16), ptr %[[THIS]]
39+
40+
// OGCG: define{{.*}} void @_ZN1AC2Ev(ptr {{.*}} %[[ARG0:.*]])
41+
// OGCG: %[[THIS_ADDR:.*]] = alloca ptr
42+
// OGCG: store ptr %[[ARG0]], ptr %[[THIS_ADDR]]
43+
// OGCG: %[[THIS:.*]] = load ptr, ptr %[[THIS_ADDR]]
44+
// OGCG: store ptr getelementptr inbounds inrange(-16, 8) ({ [3 x ptr] }, ptr @_ZTV1A, i32 0, i32 0, i32 2), ptr %[[THIS]]
45+
46+
// NOTE: The GEP in OGCG looks very different from the one generated with CIR,
47+
// but it is equivalent. The OGCG GEP indexes by base pointer, then
48+
// structure, then array, whereas the CIR GEP indexes by byte offset.

0 commit comments

Comments
 (0)