Skip to content

Commit aeac352

Browse files
andykaylorlanza
authored andcommitted
[CIR] Add special type and new operations for vptrs (#1745)
This change introduces a new type, cir.vptr, and two new operations, `cir.vtable.get_vptr` and `cir.vtable.get_virtual_fn_addr` to make operations involving vptrs more explicit. This also replaces cases where `cir.vtable.address_point` was being used as a general GEP-like operation and not actually returning the address point of a vtable.
1 parent 6495149 commit aeac352

25 files changed

+411
-174
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,7 +2594,7 @@ def CIR_GetGlobalOp : CIR_Op<"get_global", [
25942594
// VTableAddrPointOp
25952595
//===----------------------------------------------------------------------===//
25962596

2597-
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
2597+
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
25982598
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>
25992599
]> {
26002600
let summary = "Get the vtable (global variable) address point";
@@ -2603,39 +2603,116 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
26032603
(address point) of a C++ virtual table. An object internal `__vptr`
26042604
gets initializated on top of the value returned by this operation.
26052605

2606-
`address_point.index` (vtable index) provides the appropriate vtable within the vtable group
2607-
(as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address
2608-
point within that vtable.
2606+
`address_point.index` (vtable index) provides the appropriate vtable within
2607+
the vtable group (as specified by Itanium ABI), and `address_point.offset`
2608+
(address point index) the actual address point within that vtable.
26092609

2610-
The return type is always a `!cir.ptr<!cir.ptr<() -> i32>>`.
2610+
The return type is always `!cir.vptr`.
26112611

26122612
Example:
26132613
```mlir
26142614
cir.global linkonce_odr @_ZTV1B = ...
26152615
...
2616-
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<() -> i32>>
2616+
%3 = cir.vtable.address_point(@_ZTV1B,
2617+
address_point = <index = 0, offset = 2>) : !cir.vptr
26172618
```
26182619
}];
26192620

26202621
let arguments = (ins
2621-
OptionalAttr<FlatSymbolRefAttr>:$name,
2622-
Optional<CIR_AnyType>:$sym_addr,
2622+
FlatSymbolRefAttr:$name,
26232623
CIR_AddressPointAttr:$address_point
26242624
);
26252625

2626-
let results = (outs Res<CIR_PointerType, "", []>:$addr);
2626+
let results = (outs Res<CIR_VPtrType, "", []>:$addr);
26272627

26282628
let assemblyFormat = [{
26292629
`(`
2630-
($name^)?
2631-
($sym_addr^ `:` type($sym_addr))?
2632-
`,`
2633-
`address_point` `=` $address_point
2630+
$name `,` `address_point` `=` $address_point
26342631
`)`
26352632
`:` qualified(type($addr)) attr-dict
26362633
}];
2634+
}
26372635

2638-
let hasVerifier = 1;
2636+
//===----------------------------------------------------------------------===//
2637+
// VTableGetVPtr
2638+
//===----------------------------------------------------------------------===//
2639+
2640+
def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
2641+
let summary = "Get a the address of the vtable pointer for an object";
2642+
let description = [{
2643+
The `vtable.get_vptr` operation retrieves the address of the vptr for a
2644+
C++ object. This operation requires that the object pointer points to
2645+
the start of a complete object. (TODO: Describe how we get that).
2646+
The vptr will always be at offset zero in the object, but this operation
2647+
is more explicit about what is being retrieved than a direct bitcast.
2648+
2649+
The return type is always `!cir.ptr<!cir.vptr>`.
2650+
2651+
Example:
2652+
```mlir
2653+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
2654+
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
2655+
```
2656+
}];
2657+
2658+
let arguments = (ins
2659+
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
2660+
);
2661+
2662+
let results = (outs CIR_PtrToVPtr:$result);
2663+
2664+
let assemblyFormat = [{
2665+
$src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
2666+
}];
2667+
2668+
}
2669+
2670+
//===----------------------------------------------------------------------===//
2671+
// VTableGetVirtualFnAddrOp
2672+
//===----------------------------------------------------------------------===//
2673+
2674+
def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
2675+
Pure
2676+
]> {
2677+
let summary = "Get a the address of a virtual function pointer";
2678+
let description = [{
2679+
The `vtable.get_virtual_fn_addr` operation retrieves the address of a
2680+
virtual function pointer from an object's vtable (__vptr).
2681+
This is an abstraction to perform the basic pointer arithmetic to get
2682+
the address of the virtual function pointer, which can then be loaded and
2683+
called.
2684+
2685+
The `vptr` operand must be a `!cir.ptr<!cir.vptr>` value, which would
2686+
have been returned by a previous call to `cir.vatble.get_vptr`. The
2687+
`index` operand is an index of the virtual function in the vtable.
2688+
2689+
The return type is a pointer-to-pointer to the function type.
2690+
2691+
Example:
2692+
```mlir
2693+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
2694+
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
2695+
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
2696+
%5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr
2697+
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
2698+
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
2699+
-> !s32i>>>,
2700+
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
2701+
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
2702+
!cir.ptr<!rec_C>) -> !s32i
2703+
```
2704+
}];
2705+
2706+
let arguments = (ins
2707+
Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
2708+
I64Attr:$index);
2709+
2710+
let results = (outs CIR_PointerType:$result);
2711+
2712+
let assemblyFormat = [{
2713+
$vptr `[` $index `]` attr-dict
2714+
`:` qualified(type($vptr)) `->` qualified(type($result))
2715+
}];
26392716
}
26402717

26412718
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,21 @@ def CIR_PtrToExceptionInfoType
263263
def CIR_AnyDataMemberType : CIR_TypeBase<"::cir::DataMemberType",
264264
"data member type">;
265265

266+
//===----------------------------------------------------------------------===//
267+
// VPtr type predicates
268+
//===----------------------------------------------------------------------===//
269+
270+
def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;
271+
272+
def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;
273+
266274
//===----------------------------------------------------------------------===//
267275
// Scalar Type predicates
268276
//===----------------------------------------------------------------------===//
269277

270278
defvar CIR_ScalarTypes = [
271279
CIR_AnyBoolType, CIR_AnyIntType, CIR_AnyFloatType, CIR_AnyPtrType,
272-
CIR_AnyDataMemberType
280+
CIR_AnyDataMemberType, CIR_AnyVPtrType
273281
];
274282

275283
def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,37 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
343343
}];
344344
}
345345

346+
//===----------------------------------------------------------------------===//
347+
// CIR_VPtrType
348+
//===----------------------------------------------------------------------===//
349+
350+
def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
351+
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>
352+
]> {
353+
354+
let summary = "CIR type that is used for the vptr member of C++ objects";
355+
let description = [{
356+
`cir.vptr` is a special type used as the type for the vptr member of a C++
357+
object. This avoids using arbitrary pointer types to declare vptr values
358+
and allows stronger type-based checking for operations that use or provide
359+
access to the vptr.
360+
361+
This type will be the element type of the 'vptr' member of structures that
362+
require a vtable pointer. A pointer to this type is returned by the
363+
`cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
364+
pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
365+
get the address of a virtual function pointer.
366+
367+
The pointer may also be cast to other pointer types in order to perform
368+
pointer arithmetic based on information encoded in the AST layout to get
369+
the offset from a pointer to a dynamic object to the base object pointer,
370+
the base object offset value from the vtable, or the type information
371+
entry for an object.
372+
TODO: We should have special operations to do that too.
373+
}];
374+
}
375+
376+
346377
//===----------------------------------------------------------------------===//
347378
// BoolType
348379
//===----------------------------------------------------------------------===//
@@ -751,7 +782,8 @@ def CIRRecordType : Type<
751782
def CIR_AnyType : AnyTypeOf<[
752783
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
753784
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
754-
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType
785+
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType,
786+
CIR_VPtrType
755787
]>;
756788

757789
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
425425
llvm_unreachable("unsupported long double format");
426426
}
427427

428-
mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
429-
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
430-
// type so it's a bit more clear and C++ idiomatic.
431-
auto fnTy = cir::FuncType::get({}, getUInt32Ty(), isVarArg);
432-
assert(!cir::MissingFeatures::isVarArg());
433-
return getPointerTo(getPointerTo(fnTy));
428+
mlir::Type getPtrToVPtrType() {
429+
return getPointerTo(cir::VPtrType::get(getContext()));
434430
}
435431

436432
cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy,

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,10 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
806806
//
807807
// vtable field is derived from `this` pointer, therefore they should be in
808808
// the same addr space.
809+
// TODO(cir): We should be using cir.get_vptr rather than a bitcast to get
810+
// the vptr field, but the call to ApplyNonVirtualAndVirtualOffset
811+
// will also need to be adjusted. That should probably be using
812+
// cir.base_class_addr.
809813
assert(!cir::MissingFeatures::addressSpace());
810814
VTableField = builder.createElementBitCast(loc, VTableField,
811815
VTableAddressPoint.getType());
@@ -1704,10 +1708,12 @@ void CIRGenFunction::emitTypeMetadataCodeForVCall(const CXXRecordDecl *RD,
17041708
}
17051709

17061710
mlir::Value CIRGenFunction::getVTablePtr(mlir::Location Loc, Address This,
1707-
mlir::Type VTableTy,
17081711
const CXXRecordDecl *RD) {
1709-
Address VTablePtrSrc = builder.createElementBitCast(Loc, This, VTableTy);
1710-
auto VTable = builder.createLoad(Loc, VTablePtrSrc);
1712+
auto VTablePtr = builder.create<cir::VTableGetVPtrOp>(
1713+
Loc, builder.getPtrToVPtrType(), This.getPointer());
1714+
Address VTablePtrAddr = Address(VTablePtr, This.getAlignment());
1715+
1716+
auto VTable = builder.createLoad(Loc, VTablePtrAddr);
17111717
assert(!cir::MissingFeatures::tbaa());
17121718

17131719
if (CGM.getCodeGenOpts().OptimizationLevel > 0 &&

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,6 @@ class CIRGenFunction : public CIRGenTypeCache {
957957
VisitedVirtualBasesSetTy &VBases, VPtrsVector &vptrs);
958958
/// Return the Value of the vtable pointer member pointed to by This.
959959
mlir::Value getVTablePtr(mlir::Location Loc, Address This,
960-
mlir::Type VTableTy,
961960
const CXXRecordDecl *VTableClass);
962961

963962
/// Returns whether we should perform a type checked load when loading a

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -935,11 +935,11 @@ cir::GlobalOp CIRGenItaniumCXXABI::getAddrOfVTable(const CXXRecordDecl *RD,
935935
CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
936936
CIRGenFunction &CGF, GlobalDecl GD, Address This, mlir::Type Ty,
937937
SourceLocation Loc) {
938+
auto &builder = CGM.getBuilder();
938939
auto loc = CGF.getLoc(Loc);
939-
auto TyPtr = CGF.getBuilder().getPointerTo(Ty);
940+
auto TyPtr = builder.getPointerTo(Ty);
940941
auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl());
941-
auto VTable = CGF.getVTablePtr(
942-
loc, This, CGF.getBuilder().getPointerTo(TyPtr), MethodDecl->getParent());
942+
auto VTable = CGF.getVTablePtr(loc, This, MethodDecl->getParent());
943943

944944
uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD);
945945
mlir::Value VFunc{};
@@ -952,15 +952,10 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
952952
if (CGM.getItaniumVTableContext().isRelativeLayout()) {
953953
llvm_unreachable("NYI");
954954
} else {
955-
VTable = CGF.getBuilder().createBitcast(
956-
loc, VTable, CGF.getBuilder().getPointerTo(TyPtr));
957-
auto VTableSlotPtr = CGF.getBuilder().create<cir::VTableAddrPointOp>(
958-
loc, CGF.getBuilder().getPointerTo(TyPtr),
959-
::mlir::FlatSymbolRefAttr{}, VTable,
960-
cir::AddressPointAttr::get(CGF.getBuilder().getContext(), 0,
961-
VTableIndex));
962-
VFuncLoad = CGF.getBuilder().createAlignedLoad(loc, TyPtr, VTableSlotPtr,
963-
CGF.getPointerAlign());
955+
auto VTableSlotPtr = builder.create<cir::VTableGetVirtualFnAddrOp>(
956+
loc, builder.getPointerTo(TyPtr), VTable, VTableIndex);
957+
VFuncLoad = builder.createAlignedLoad(loc, TyPtr, VTableSlotPtr,
958+
CGF.getPointerAlign());
964959
}
965960

966961
// Add !invariant.load md to virtual function load to indicate that
@@ -1014,11 +1009,11 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
10141009
.getAddressPoint(Base);
10151010

10161011
auto &builder = CGM.getBuilder();
1017-
auto vtablePtrTy = builder.getVirtualFnPtrType(/*isVarArg=*/false);
1012+
auto vtablePtrTy = cir::VPtrType::get(builder.getContext());
10181013

10191014
return builder.create<cir::VTableAddrPointOp>(
10201015
CGM.getLoc(VTableClass->getSourceRange()), vtablePtrTy,
1021-
mlir::FlatSymbolRefAttr::get(vtable.getSymNameAttr()), mlir::Value{},
1016+
mlir::FlatSymbolRefAttr::get(vtable.getSymNameAttr()),
10221017
cir::AddressPointAttr::get(CGM.getBuilder().getContext(),
10231018
AddressPoint.VTableIndex,
10241019
AddressPoint.AddressPointIndex));
@@ -2411,14 +2406,16 @@ void CIRGenItaniumCXXABI::emitThrow(CIRGenFunction &CGF,
24112406
mlir::Value CIRGenItaniumCXXABI::getVirtualBaseClassOffset(
24122407
mlir::Location loc, CIRGenFunction &CGF, Address This,
24132408
const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl) {
2414-
auto VTablePtr = CGF.getVTablePtr(loc, This, CGM.UInt8PtrTy, ClassDecl);
2409+
auto VTablePtr = CGF.getVTablePtr(loc, This, ClassDecl);
2410+
auto VTableBytePtr =
2411+
CGF.getBuilder().createBitcast(VTablePtr, CGM.UInt8PtrTy);
24152412
CharUnits VBaseOffsetOffset =
24162413
CGM.getItaniumVTableContext().getVirtualBaseOffsetOffset(ClassDecl,
24172414
BaseClassDecl);
24182415
mlir::Value OffsetVal =
24192416
CGF.getBuilder().getSInt64(VBaseOffsetOffset.getQuantity(), loc);
24202417
auto VBaseOffsetPtr = CGF.getBuilder().create<cir::PtrStrideOp>(
2421-
loc, VTablePtr.getType(), VTablePtr,
2418+
loc, CGM.UInt8PtrTy, VTableBytePtr,
24222419
OffsetVal); // vbase.offset.ptr
24232420

24242421
mlir::Value VBaseOffset;

clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,7 @@ void CIRRecordLowering::accumulateVPtrs() {
488488
}
489489

490490
mlir::Type CIRRecordLowering::getVFPtrType() {
491-
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
492-
// type so it's a bit more clear and C++ idiomatic.
493-
return builder.getVirtualFnPtrType();
491+
return cir::VPtrType::get(builder.getContext());
494492
}
495493

496494
void CIRRecordLowering::fillOutputFields() {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,12 @@ LogicalResult cir::CastOp::verify() {
561561
return success();
562562
}
563563

564+
// Allow casting cir.vptr to pointer types.
565+
// TODO: Add operations to get object offset and type info and remove this.
566+
if (mlir::isa<cir::VPtrType>(srcType) &&
567+
mlir::dyn_cast<cir::PointerType>(resType))
568+
return success();
569+
564570
// Handle the data member pointer types.
565571
if (mlir::isa<cir::DataMemberType>(srcType) &&
566572
mlir::isa<cir::DataMemberType>(resType))
@@ -2389,10 +2395,7 @@ cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
23892395

23902396
LogicalResult
23912397
cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2392-
// vtable ptr is not coming from a symbol.
2393-
if (!getName())
2394-
return success();
2395-
auto name = *getName();
2398+
StringRef name = getName();
23962399

23972400
// Verify that the result type underlying pointer type matches the type of
23982401
// the referenced cir.global or cir.func op.
@@ -2410,27 +2413,6 @@ cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
24102413
return success();
24112414
}
24122415

2413-
LogicalResult cir::VTableAddrPointOp::verify() {
2414-
// The operation uses either a symbol or a value to operate, but not both
2415-
if (getName() && getSymAddr())
2416-
return emitOpError("should use either a symbol or value, but not both");
2417-
2418-
// If not a symbol, stick with the concrete type used for getSymAddr.
2419-
if (getSymAddr())
2420-
return success();
2421-
2422-
auto resultType = getAddr().getType();
2423-
auto intTy = cir::IntType::get(getContext(), 32, /*isSigned=*/false);
2424-
auto fnTy = cir::FuncType::get({}, intTy);
2425-
2426-
auto resTy = cir::PointerType::get(cir::PointerType::get(fnTy));
2427-
2428-
if (resultType != resTy)
2429-
return emitOpError("result type must be '")
2430-
<< resTy << "', but provided result type is '" << resultType << "'";
2431-
return success();
2432-
}
2433-
24342416
//===----------------------------------------------------------------------===//
24352417
// VTTAddrPointOp
24362418
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)