Skip to content

Commit 0cd35e7

Browse files
authored
[CIR] Add cir.vtable.get_vptr operation (#153630)
This adds support for the cir.vtable.get_vptr operation and uses it to initialize the vptr member during constructors of dynamic classes.
1 parent b7d6f48 commit 0cd35e7

File tree

8 files changed

+96
-11
lines changed

8 files changed

+96
-11
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,39 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
17491749
}];
17501750
}
17511751

1752+
//===----------------------------------------------------------------------===//
1753+
// VTableGetVPtr
1754+
//===----------------------------------------------------------------------===//
1755+
1756+
def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
1757+
let summary = "Get a the address of the vtable pointer for an object";
1758+
let description = [{
1759+
The `vtable.get_vptr` operation retrieves the address of the vptr for a
1760+
C++ object. This operation requires that the object pointer points to
1761+
the start of a complete object. (TODO: Describe how we get that).
1762+
The vptr will always be at offset zero in the object, but this operation
1763+
is more explicit about what is being retrieved than a direct bitcast.
1764+
1765+
The return type is always `!cir.ptr<!cir.vptr>`.
1766+
1767+
Example:
1768+
```mlir
1769+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
1770+
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
1771+
```
1772+
}];
1773+
1774+
let arguments = (ins
1775+
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
1776+
);
1777+
1778+
let results = (outs CIR_PtrToVPtr:$result);
1779+
1780+
let assemblyFormat = [{
1781+
$src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
1782+
}];
1783+
}
1784+
17521785
//===----------------------------------------------------------------------===//
17531786
// SetBitfieldOp
17541787
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,14 @@ def CIR_AnyFloatOrVecOfFloatType
289289
let cppFunctionName = "isFPOrVectorOfFPType";
290290
}
291291

292+
//===----------------------------------------------------------------------===//
293+
// VPtr type predicates
294+
//===----------------------------------------------------------------------===//
295+
296+
def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;
297+
298+
def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;
299+
292300
//===----------------------------------------------------------------------===//
293301
// Scalar Type predicates
294302
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
296296
access to the vptr.
297297

298298
This type will be the element type of the 'vptr' member of structures that
299-
require a vtable pointer. A pointer to this type is returned by the
300-
`cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
301-
pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
302-
get the address of a virtual function pointer.
299+
require a vtable pointer. The `cir.vtable.address_point` operation returns
300+
this type. The `cir.vtable.get_vptr` operations returns a pointer to this
301+
type. This pointer may be passed to the `cir.vtable.get_virtual_fn_addr`
302+
operation to get the address of a virtual function pointer.
303303

304304
The pointer may also be cast to other pointer types in order to perform
305305
pointer arithmetic based on information encoded in the AST layout to get

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
8484
llvm_unreachable("Unsupported format for long double");
8585
}
8686

87+
mlir::Type getPtrToVPtrType() {
88+
return getPointerTo(cir::VPtrType::get(getContext()));
89+
}
90+
8791
/// Get a CIR record kind from a AST declaration tag.
8892
cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) {
8993
switch (kind) {

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
289289
}
290290

291291
// Apply the offsets.
292-
Address vtableField = loadCXXThisAddress();
292+
Address classAddr = loadCXXThisAddress();
293293
if (!nonVirtualOffset.isZero() || virtualOffset) {
294294
cgm.errorNYI(loc,
295295
"initializeVTablePointer: non-virtual and virtual offset");
@@ -300,9 +300,9 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
300300
// vtable field is derived from `this` pointer, therefore they should be in
301301
// the same addr space.
302302
assert(!cir::MissingFeatures::addressSpace());
303-
// TODO(cir): This should be cir.vtable.get_vptr.
304-
vtableField = builder.createElementBitCast(loc, vtableField,
305-
vtableAddressPoint.getType());
303+
auto vtablePtr = cir::VTableGetVPtrOp::create(
304+
builder, loc, builder.getPtrToVPtrType(), classAddr.getPointer());
305+
Address vtableField = Address(vtablePtr, classAddr.getAlignment());
306306
builder.createStore(loc, vtableAddressPoint, vtableField);
307307
assert(!cir::MissingFeatures::opTBAA());
308308
assert(!cir::MissingFeatures::createInvariantGroup());
@@ -657,6 +657,23 @@ Address CIRGenFunction::getAddressOfBaseClass(
657657
return value;
658658
}
659659

660+
mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
661+
const CXXRecordDecl *rd) {
662+
auto vtablePtr = cir::VTableGetVPtrOp::create(
663+
builder, loc, builder.getPtrToVPtrType(), thisAddr.getPointer());
664+
Address vtablePtrAddr = Address(vtablePtr, thisAddr.getAlignment());
665+
666+
auto vtable = builder.createLoad(loc, vtablePtrAddr);
667+
assert(!cir::MissingFeatures::opTBAA());
668+
669+
if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
670+
cgm.getCodeGenOpts().StrictVTablePointers) {
671+
assert(!cir::MissingFeatures::createInvariantGroup());
672+
}
673+
674+
return vtable;
675+
}
676+
660677
void CIRGenFunction::emitCXXConstructorCall(const clang::CXXConstructorDecl *d,
661678
clang::CXXCtorType type,
662679
bool forVirtualBase,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2344,7 +2344,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
23442344
CIRToLLVMVecShuffleOpLowering,
23452345
CIRToLLVMVecSplatOpLowering,
23462346
CIRToLLVMVecTernaryOpLowering,
2347-
CIRToLLVMVTableAddrPointOpLowering
2347+
CIRToLLVMVTableAddrPointOpLowering,
2348+
CIRToLLVMVTableGetVPtrOpLowering
23482349
// clang-format on
23492350
>(converter, patterns.getContext());
23502351

@@ -2468,6 +2469,18 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
24682469
return mlir::success();
24692470
}
24702471

2472+
mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
2473+
cir::VTableGetVPtrOp op, OpAdaptor adaptor,
2474+
mlir::ConversionPatternRewriter &rewriter) const {
2475+
// cir.vtable.get_vptr is equivalent to a bitcast from the source object
2476+
// pointer to the vptr type. Since the LLVM dialect uses opaque pointers
2477+
// we can just replace uses of this operation with the original pointer.
2478+
mlir::Value srcVal = adaptor.getSrc();
2479+
rewriter.replaceAllUsesWith(op, srcVal);
2480+
rewriter.eraseOp(op);
2481+
return mlir::success();
2482+
}
2483+
24712484
mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
24722485
cir::StackSaveOp op, OpAdaptor adaptor,
24732486
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,16 @@ class CIRToLLVMVTableAddrPointOpLowering
467467
mlir::ConversionPatternRewriter &) const override;
468468
};
469469

470+
class CIRToLLVMVTableGetVPtrOpLowering
471+
: public mlir::OpConversionPattern<cir::VTableGetVPtrOp> {
472+
public:
473+
using mlir::OpConversionPattern<cir::VTableGetVPtrOp>::OpConversionPattern;
474+
475+
mlir::LogicalResult
476+
matchAndRewrite(cir::VTableGetVPtrOp op, OpAdaptor,
477+
mlir::ConversionPatternRewriter &) const override;
478+
};
479+
470480
class CIRToLLVMStackSaveOpLowering
471481
: public mlir::OpConversionPattern<cir::StackSaveOp> {
472482
public:

clang/test/CIR/CodeGen/virtual-function-calls.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ A::A() {}
2727
// CIR: cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
2828
// CIR: %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
2929
// CIR: %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr
30-
// CIR: %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr>
31-
// CIR: cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
30+
// CIR: %[[THIS_VPTR_PTR:.*]] = cir.vtable.get_vptr %[[THIS]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr>
31+
// CIR: cir.store{{.*}} align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
3232
// CIR: cir.return
3333

3434
// LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]])

0 commit comments

Comments
 (0)