Skip to content

Conversation

andykaylor
Copy link
Contributor

This adds support for the cir.vtable.get_vptr operation and uses it to initialize the vptr member during constructors of dynamic classes.

This adds support for the cir.vtable.get_vptr operation and uses it
to initialize the vptr member during constructors of dynamic classes.
@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels Aug 14, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-clang

Author: Andy Kaylor (andykaylor)

Changes

This adds support for the cir.vtable.get_vptr operation and uses it to initialize the vptr member during constructors of dynamic classes.


Full diff: https://github.com/llvm/llvm-project/pull/153630.diff

8 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+33)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td (+8)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRTypes.td (+4-4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenBuilder.h (+4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenClass.cpp (+21-4)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+14-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/virtual-function-calls.cpp (+2-2)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index b64fd2734a63c..0e06996266316 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1749,6 +1749,39 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VTableGetVPtr
+//===----------------------------------------------------------------------===//
+
+def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
+  let summary = "Get a the address of the vtable pointer for an object";
+  let description = [{
+    The `vtable.get_vptr` operation retrieves the address of the vptr for a
+    C++ object. This operation requires that the object pointer points to
+    the start of a complete object. (TODO: Describe how we get that).
+    The vptr will always be at offset zero in the object, but this operation
+    is more explicit about what is being retrieved than a direct bitcast.
+
+    The return type is always `!cir.ptr<!cir.vptr>`.
+
+    Example:
+    ```mlir
+    %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
+    %3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
+  );
+
+  let results = (outs CIR_PtrToVPtr:$result);
+
+  let assemblyFormat = [{
+      $src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SetBitfieldOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index d7d55dfbc0654..82f6e1d33043e 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -289,6 +289,14 @@ def CIR_AnyFloatOrVecOfFloatType
     let cppFunctionName = "isFPOrVectorOfFPType";
 }
 
+//===----------------------------------------------------------------------===//
+// VPtr type predicates
+//===----------------------------------------------------------------------===//
+
+def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;
+
+def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;
+
 //===----------------------------------------------------------------------===//
 // Scalar Type predicates
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index a258df79a6184..312d0a9422673 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -296,10 +296,10 @@ def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
     access to the vptr.
 
     This type will be the element type of the 'vptr' member of structures that
-    require a vtable pointer. A pointer to this type is returned by the
-    `cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
-    pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
-    get the address of a virtual function pointer.
+    require a vtable pointer. The `cir.vtable.address_point` operation returns
+    this type. The `cir.vtable.get_vptr` operations returns a pointer to this
+    type. This pointer may be passed to the `cir.vtable.get_virtual_fn_addr`
+    operation to get the address of a virtual function pointer.
 
     The pointer may also be cast to other pointer types in order to perform
     pointer arithmetic based on information encoded in the AST layout to get
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index 8b2538c941f47..e4676bc044c6c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -83,6 +83,10 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
     llvm_unreachable("Unsupported format for long double");
   }
 
+  mlir::Type getPtrToVPtrType() {
+    return getPointerTo(cir::VPtrType::get(getContext()));
+  }
+
   /// Get a CIR record kind from a AST declaration tag.
   cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) {
     switch (kind) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index 31c93cd00d083..a3947047de079 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -289,7 +289,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
   }
 
   // Apply the offsets.
-  Address vtableField = loadCXXThisAddress();
+  Address classAddr = loadCXXThisAddress();
   if (!nonVirtualOffset.isZero() || virtualOffset) {
     cgm.errorNYI(loc,
                  "initializeVTablePointer: non-virtual and virtual offset");
@@ -300,9 +300,9 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
   // vtable field is derived from `this` pointer, therefore they should be in
   // the same addr space.
   assert(!cir::MissingFeatures::addressSpace());
-  // TODO(cir): This should be cir.vtable.get_vptr.
-  vtableField = builder.createElementBitCast(loc, vtableField,
-                                             vtableAddressPoint.getType());
+  auto vtablePtr = cir::VTableGetVPtrOp::create(
+      builder, loc, builder.getPtrToVPtrType(), classAddr.getPointer());
+  Address vtableField = Address(vtablePtr, classAddr.getAlignment());
   builder.createStore(loc, vtableAddressPoint, vtableField);
   assert(!cir::MissingFeatures::opTBAA());
   assert(!cir::MissingFeatures::createInvariantGroup());
@@ -657,6 +657,23 @@ Address CIRGenFunction::getAddressOfBaseClass(
   return value;
 }
 
+mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
+                                         const CXXRecordDecl *rd) {
+  auto vtablePtr = cir::VTableGetVPtrOp::create(
+      builder, loc, builder.getPtrToVPtrType(), thisAddr.getPointer());
+  Address vtablePtrAddr = Address(vtablePtr, thisAddr.getAlignment());
+
+  auto vtable = builder.createLoad(loc, vtablePtrAddr);
+  assert(!cir::MissingFeatures::opTBAA());
+
+  if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
+      cgm.getCodeGenOpts().StrictVTablePointers) {
+    assert(!cir::MissingFeatures::createInvariantGroup());
+  }
+
+  return vtable;
+}
+
 void CIRGenFunction::emitCXXConstructorCall(const clang::CXXConstructorDecl *d,
                                             clang::CXXCtorType type,
                                             bool forVirtualBase,
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index c3715c28f6890..afbf562792d16 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2327,7 +2327,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMVecShuffleOpLowering,
                CIRToLLVMVecSplatOpLowering,
                CIRToLLVMVecTernaryOpLowering,
-               CIRToLLVMVTableAddrPointOpLowering
+               CIRToLLVMVTableAddrPointOpLowering,
+               CIRToLLVMVTableGetVPtrOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -2451,6 +2452,18 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
+    cir::VTableGetVPtrOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // cir.vtable.get_vptr is equivalent to a bitcast from the source object
+  // pointer to the vptr type. Since the LLVM dialect uses opaque pointers
+  // we can just replace uses of this operation with the original pointer.
+  mlir::Value srcVal = adaptor.getSrc();
+  rewriter.replaceAllUsesWith(op, srcVal);
+  rewriter.eraseOp(op);
+  return mlir::success();
+}
+
 mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
     cir::StackSaveOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6fbe0079b90d0..2ffc6279c286b 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -467,6 +467,16 @@ class CIRToLLVMVTableAddrPointOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVTableGetVPtrOpLowering
+    : public mlir::OpConversionPattern<cir::VTableGetVPtrOp> {
+public:
+  using mlir::OpConversionPattern<cir::VTableGetVPtrOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VTableGetVPtrOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 class CIRToLLVMStackSaveOpLowering
     : public mlir::OpConversionPattern<cir::StackSaveOp> {
 public:
diff --git a/clang/test/CIR/CodeGen/virtual-function-calls.cpp b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
index 004b6dab30563..4787d78aa0e35 100644
--- a/clang/test/CIR/CodeGen/virtual-function-calls.cpp
+++ b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
@@ -27,8 +27,8 @@ A::A() {}
 // CIR:    cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
 // CIR:    %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
 // CIR:    %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr
-// CIR:    %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr>
-// CIR:    cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
+// CIR:    %[[THIS_VPTR_PTR:.*]] = cir.vtable.get_vptr %[[THIS]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr>
+// CIR:    cir.store{{.*}} align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
 // CIR:    cir.return
 
 // LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]])

@llvmbot
Copy link
Member

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-clangir

Author: Andy Kaylor (andykaylor)

Changes

This adds support for the cir.vtable.get_vptr operation and uses it to initialize the vptr member during constructors of dynamic classes.


Full diff: https://github.com/llvm/llvm-project/pull/153630.diff

8 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+33)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td (+8)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRTypes.td (+4-4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenBuilder.h (+4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenClass.cpp (+21-4)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+14-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/virtual-function-calls.cpp (+2-2)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index b64fd2734a63c..0e06996266316 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1749,6 +1749,39 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VTableGetVPtr
+//===----------------------------------------------------------------------===//
+
+def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
+  let summary = "Get a the address of the vtable pointer for an object";
+  let description = [{
+    The `vtable.get_vptr` operation retrieves the address of the vptr for a
+    C++ object. This operation requires that the object pointer points to
+    the start of a complete object. (TODO: Describe how we get that).
+    The vptr will always be at offset zero in the object, but this operation
+    is more explicit about what is being retrieved than a direct bitcast.
+
+    The return type is always `!cir.ptr<!cir.vptr>`.
+
+    Example:
+    ```mlir
+    %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
+    %3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
+  );
+
+  let results = (outs CIR_PtrToVPtr:$result);
+
+  let assemblyFormat = [{
+      $src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SetBitfieldOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index d7d55dfbc0654..82f6e1d33043e 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -289,6 +289,14 @@ def CIR_AnyFloatOrVecOfFloatType
     let cppFunctionName = "isFPOrVectorOfFPType";
 }
 
+//===----------------------------------------------------------------------===//
+// VPtr type predicates
+//===----------------------------------------------------------------------===//
+
+def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;
+
+def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;
+
 //===----------------------------------------------------------------------===//
 // Scalar Type predicates
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index a258df79a6184..312d0a9422673 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -296,10 +296,10 @@ def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
     access to the vptr.
 
     This type will be the element type of the 'vptr' member of structures that
-    require a vtable pointer. A pointer to this type is returned by the
-    `cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
-    pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
-    get the address of a virtual function pointer.
+    require a vtable pointer. The `cir.vtable.address_point` operation returns
+    this type. The `cir.vtable.get_vptr` operations returns a pointer to this
+    type. This pointer may be passed to the `cir.vtable.get_virtual_fn_addr`
+    operation to get the address of a virtual function pointer.
 
     The pointer may also be cast to other pointer types in order to perform
     pointer arithmetic based on information encoded in the AST layout to get
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index 8b2538c941f47..e4676bc044c6c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -83,6 +83,10 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
     llvm_unreachable("Unsupported format for long double");
   }
 
+  mlir::Type getPtrToVPtrType() {
+    return getPointerTo(cir::VPtrType::get(getContext()));
+  }
+
   /// Get a CIR record kind from a AST declaration tag.
   cir::RecordType::RecordKind getRecordKind(const clang::TagTypeKind kind) {
     switch (kind) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index 31c93cd00d083..a3947047de079 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -289,7 +289,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
   }
 
   // Apply the offsets.
-  Address vtableField = loadCXXThisAddress();
+  Address classAddr = loadCXXThisAddress();
   if (!nonVirtualOffset.isZero() || virtualOffset) {
     cgm.errorNYI(loc,
                  "initializeVTablePointer: non-virtual and virtual offset");
@@ -300,9 +300,9 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
   // vtable field is derived from `this` pointer, therefore they should be in
   // the same addr space.
   assert(!cir::MissingFeatures::addressSpace());
-  // TODO(cir): This should be cir.vtable.get_vptr.
-  vtableField = builder.createElementBitCast(loc, vtableField,
-                                             vtableAddressPoint.getType());
+  auto vtablePtr = cir::VTableGetVPtrOp::create(
+      builder, loc, builder.getPtrToVPtrType(), classAddr.getPointer());
+  Address vtableField = Address(vtablePtr, classAddr.getAlignment());
   builder.createStore(loc, vtableAddressPoint, vtableField);
   assert(!cir::MissingFeatures::opTBAA());
   assert(!cir::MissingFeatures::createInvariantGroup());
@@ -657,6 +657,23 @@ Address CIRGenFunction::getAddressOfBaseClass(
   return value;
 }
 
+mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
+                                         const CXXRecordDecl *rd) {
+  auto vtablePtr = cir::VTableGetVPtrOp::create(
+      builder, loc, builder.getPtrToVPtrType(), thisAddr.getPointer());
+  Address vtablePtrAddr = Address(vtablePtr, thisAddr.getAlignment());
+
+  auto vtable = builder.createLoad(loc, vtablePtrAddr);
+  assert(!cir::MissingFeatures::opTBAA());
+
+  if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
+      cgm.getCodeGenOpts().StrictVTablePointers) {
+    assert(!cir::MissingFeatures::createInvariantGroup());
+  }
+
+  return vtable;
+}
+
 void CIRGenFunction::emitCXXConstructorCall(const clang::CXXConstructorDecl *d,
                                             clang::CXXCtorType type,
                                             bool forVirtualBase,
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index c3715c28f6890..afbf562792d16 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2327,7 +2327,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMVecShuffleOpLowering,
                CIRToLLVMVecSplatOpLowering,
                CIRToLLVMVecTernaryOpLowering,
-               CIRToLLVMVTableAddrPointOpLowering
+               CIRToLLVMVTableAddrPointOpLowering,
+               CIRToLLVMVTableGetVPtrOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -2451,6 +2452,18 @@ mlir::LogicalResult CIRToLLVMVTableAddrPointOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
+    cir::VTableGetVPtrOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // cir.vtable.get_vptr is equivalent to a bitcast from the source object
+  // pointer to the vptr type. Since the LLVM dialect uses opaque pointers
+  // we can just replace uses of this operation with the original pointer.
+  mlir::Value srcVal = adaptor.getSrc();
+  rewriter.replaceAllUsesWith(op, srcVal);
+  rewriter.eraseOp(op);
+  return mlir::success();
+}
+
 mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
     cir::StackSaveOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6fbe0079b90d0..2ffc6279c286b 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -467,6 +467,16 @@ class CIRToLLVMVTableAddrPointOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVTableGetVPtrOpLowering
+    : public mlir::OpConversionPattern<cir::VTableGetVPtrOp> {
+public:
+  using mlir::OpConversionPattern<cir::VTableGetVPtrOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VTableGetVPtrOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 class CIRToLLVMStackSaveOpLowering
     : public mlir::OpConversionPattern<cir::StackSaveOp> {
 public:
diff --git a/clang/test/CIR/CodeGen/virtual-function-calls.cpp b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
index 004b6dab30563..4787d78aa0e35 100644
--- a/clang/test/CIR/CodeGen/virtual-function-calls.cpp
+++ b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
@@ -27,8 +27,8 @@ A::A() {}
 // CIR:    cir.store %arg0, %[[THIS_ADDR]] : !cir.ptr<!rec_A>, !cir.ptr<!cir.ptr<!rec_A>>
 // CIR:    %[[THIS:.*]] = cir.load %[[THIS_ADDR]] : !cir.ptr<!cir.ptr<!rec_A>>, !cir.ptr<!rec_A>
 // CIR:    %[[VPTR:.*]] = cir.vtable.address_point(@_ZTV1A, address_point = <index = 0, offset = 2>) : !cir.vptr
-// CIR:    %[[THIS_VPTR_PTR:.*]] = cir.cast(bitcast, %[[THIS]] : !cir.ptr<!rec_A>), !cir.ptr<!cir.vptr>
-// CIR:    cir.store align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
+// CIR:    %[[THIS_VPTR_PTR:.*]] = cir.vtable.get_vptr %[[THIS]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr>
+// CIR:    cir.store{{.*}} align(8) %[[VPTR]], %[[THIS_VPTR_PTR]] : !cir.vptr, !cir.ptr<!cir.vptr>
 // CIR:    cir.return
 
 // LLVM: define{{.*}} void @_ZN1AC2Ev(ptr %[[ARG0:.*]])

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@andykaylor andykaylor merged commit 0cd35e7 into llvm:main Aug 15, 2025
12 checks passed
@andykaylor andykaylor deleted the cir-get-vptr-op branch August 15, 2025 22:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants