Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,39 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
}];
}

//===----------------------------------------------------------------------===//
// VTableGetVPtr
//===----------------------------------------------------------------------===//

def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
let summary = "Get a the address of the vtable pointer for an object";
let description = [{
The `vtable.get_vptr` operation retrieves the address of the vptr for a
C++ object. This operation requires that the object pointer points to
the start of a complete object. (TODO: Describe how we get that).
The vptr will always be at offset zero in the object, but this operation
is more explicit about what is being retrieved than a direct bitcast.

The return type is always `!cir.ptr<!cir.vptr>`.

Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
```
}];

let arguments = (ins
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
);

let results = (outs CIR_PtrToVPtr:$result);

let assemblyFormat = [{
$src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// SetBitfieldOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ def CIR_AnyFloatOrVecOfFloatType
let cppFunctionName = "isFPOrVectorOfFPType";
}

//===----------------------------------------------------------------------===//
// VPtr type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;

def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;

//===----------------------------------------------------------------------===//
// Scalar Type predicates
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
access to the vptr.

This type will be the element type of the 'vptr' member of structures that
require a vtable pointer. A pointer to this type is returned by the
`cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
get the address of a virtual function pointer.
require a vtable pointer. The `cir.vtable.address_point` operation returns
this type. The `cir.vtable.get_vptr` operations returns a pointer to this
type. This pointer may be passed to the `cir.vtable.get_virtual_fn_addr`
operation to get the address of a virtual function pointer.

The pointer may also be cast to other pointer types in order to perform
pointer arithmetic based on information encoded in the AST layout to get
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
llvm_unreachable("Unsupported format for long double");
}

mlir::Type getPtrToVPtrType() {
return getPointerTo(cir::VPtrType::get(getContext()));
}

/// Get a CIR record kind from a AST declaration tag.
cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) {
switch (kind) {
Expand Down
25 changes: 21 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
}

// Apply the offsets.
Address vtableField = loadCXXThisAddress();
Address classAddr = loadCXXThisAddress();
if (!nonVirtualOffset.isZero() || virtualOffset) {
cgm.errorNYI(loc,
"initializeVTablePointer: non-virtual and virtual offset");
Expand All @@ -300,9 +300,9 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
// vtable field is derived from `this` pointer, therefore they should be in
// the same addr space.
assert(!cir::MissingFeatures::addressSpace());
// TODO(cir): This should be cir.vtable.get_vptr.
vtableField = builder.createElementBitCast(loc, vtableField,
vtableAddressPoint.getType());
auto vtablePtr = cir::VTableGetVPtrOp::create(
builder, loc, builder.getPtrToVPtrType(), classAddr.getPointer());
Address vtableField = Address(vtablePtr, classAddr.getAlignment());
builder.createStore(loc, vtableAddressPoint, vtableField);
assert(!cir::MissingFeatures::opTBAA());
assert(!cir::MissingFeatures::createInvariantGroup());
Expand Down Expand Up @@ -657,6 +657,23 @@ Address CIRGenFunction::getAddressOfBaseClass(
return value;
}

mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
const CXXRecordDecl *rd) {
auto vtablePtr = cir::VTableGetVPtrOp::create(
builder, loc, builder.getPtrToVPtrType(), thisAddr.getPointer());
Address vtablePtrAddr = Address(vtablePtr, thisAddr.getAlignment());

auto vtable = builder.createLoad(loc, vtablePtrAddr);
assert(!cir::MissingFeatures::opTBAA());

if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
cgm.getCodeGenOpts().StrictVTablePointers) {
assert(!cir::MissingFeatures::createInvariantGroup());
}

return vtable;
}

void CIRGenFunction::emitCXXConstructorCall(const clang::CXXConstructorDecl *d,
clang::CXXCtorType type,
bool forVirtualBase,
Expand Down
15 changes: 14 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2327,7 +2327,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecSplatOpLowering,
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMVTableAddrPointOpLowering
CIRToLLVMVTableAddrPointOpLowering,
CIRToLLVMVTableGetVPtrOpLowering
// clang-format on
>(converter, patterns.getContext());

Expand Down Expand Up @@ -2451,6 +2452,18 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
cir::VTableGetVPtrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// cir.vtable.get_vptr is equivalent to a bitcast from the source object
// pointer to the vptr type. Since the LLVM dialect uses opaque pointers
// we can just replace uses of this operation with the original pointer.
mlir::Value srcVal = adaptor.getSrc();
rewriter.replaceAllUsesWith(op, srcVal);
rewriter.eraseOp(op);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
cir::StackSaveOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,16 @@ class CIRToLLVMVTableAddrPointOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMVTableGetVPtrOpLowering
: public mlir::OpConversionPattern<cir::VTableGetVPtrOp> {
public:
using mlir::OpConversionPattern<cir::VTableGetVPtrOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::VTableGetVPtrOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMStackSaveOpLowering
: public mlir::OpConversionPattern<cir::StackSaveOp> {
public:
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/virtual-function-calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ A::A() {}
// CIR: cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
// CIR: %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
// CIR: %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr
// CIR: %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr>
// CIR: cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
// CIR: %[[THIS_VPTR_PTR:.*]] = cir.vtable.get_vptr %[[THIS]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr>
// CIR: cir.store{{.*}} align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
// CIR: cir.return

// LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]])
Expand Down