diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index f237642700924..89cc97d168e10 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1906,6 +1906,70 @@ def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [ }]; } +//===----------------------------------------------------------------------===// +// VTTAddrPointOp +//===----------------------------------------------------------------------===// + +def CIR_VTTAddrPointOp : CIR_Op<"vtt.address_point", [ + Pure, DeclareOpInterfaceMethods +]> { + let summary = "Get the VTT address point"; + let description = [{ + The `vtt.address_point` operation retrieves an element from the virtual + table table (VTT), which is the address point of a C++ vtable. In virtual + inheritance, a set of internal `__vptr` members for an object are + initialized by this operation, which assigns an element from the VTT. The + initialization order is as follows: + + The complete object constructors and destructors find the VTT, + via the mangled name of the VTT global variable. They pass the address of + the subobject's sub-VTT entry in the VTT as a second parameter + when calling the base object constructors and destructors. + The base object constructors and destructors use the address passed to + initialize the primary virtual pointer and virtual pointers that point to + the classes which either have virtual bases or override virtual functions + with a virtual step. + + The first parameter is either the mangled name of VTT global variable + or the address of the subobject's sub-VTT entry in the VTT. + The second parameter `offset` provides a virtual step to adjust to + the actual address point of the vtable. + + The return type is always a `!cir.ptr>`. + + Example: + ```mlir + cir.global linkonce_odr @_ZTV1B = ... + ... + %3 = cir.base_class_addr(%1 : !cir.ptr nonnull) [0] + -> !cir.ptr + %4 = cir.vtt.address_point @_ZTT1D, offset = 1 + -> !cir.ptr> + cir.call @_ZN1BC2Ev(%3, %4) + ``` + Or: + ```mlir + %7 = cir.vtt.address_point %3 : !cir.ptr>, offset = 1 + -> !cir.ptr> + ``` + }]; + + let arguments = (ins OptionalAttr:$name, + Optional:$sym_addr, + I32Attr:$offset); + let results = (outs CIR_PointerType:$addr); + + let assemblyFormat = [{ + ($name^)? + ($sym_addr^ `:` type($sym_addr))? + `,` + `offset` `=` $offset + `->` qualified(type($addr)) attr-dict + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // SetBitfieldOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 8c53939e89d01..83fff09d4fab3 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1474,6 +1474,53 @@ cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// VTTAddrPointOp +//===----------------------------------------------------------------------===// + +LogicalResult +cir::VTTAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // VTT ptr is not coming from a symbol. + if (!getName()) + return success(); + StringRef name = *getName(); + + // Verify that the result type underlying pointer type matches the type of + // the referenced cir.global op. + auto op = + symbolTable.lookupNearestSymbolFrom(*this, getNameAttr()); + if (!op) + return emitOpError("'") + << name << "' does not reference a valid cir.global"; + std::optional init = op.getInitialValue(); + if (!init) + return success(); + if (!isa(*init)) + return emitOpError( + "Expected constant array in initializer for global VTT '") + << name << "'"; + return success(); +} + +LogicalResult cir::VTTAddrPointOp::verify() { + // The operation uses either a symbol or a value to operate, but not both + if (getName() && getSymAddr()) + return emitOpError("should use either a symbol or value, but not both"); + + // If not a symbol, stick with the concrete type used for getSymAddr. + if (getSymAddr()) + return success(); + + mlir::Type resultType = getAddr().getType(); + mlir::Type resTy = cir::PointerType::get( + cir::PointerType::get(cir::VoidType::get(getContext()))); + + if (resultType != resTy) + return emitOpError("result type must be ") + << resTy << ", but provided result type is " << resultType; + return success(); +} + //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9f0e4e6ecb8ce..03955dc737828 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -2451,7 +2451,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecTernaryOpLowering, CIRToLLVMVTableAddrPointOpLowering, CIRToLLVMVTableGetVPtrOpLowering, - CIRToLLVMVTableGetVirtualFnAddrOpLowering + CIRToLLVMVTableGetVirtualFnAddrOpLowering, + CIRToLLVMVTTAddrPointOpLowering // clang-format on >(converter, patterns.getContext()); @@ -2600,6 +2601,36 @@ mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVTTAddrPointOpLowering::matchAndRewrite( + cir::VTTAddrPointOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + const mlir::Type resultType = getTypeConverter()->convertType(op.getType()); + llvm::SmallVector offsets; + mlir::Type eltType; + mlir::Value llvmAddr = adaptor.getSymAddr(); + + if (op.getSymAddr()) { + if (op.getOffset() == 0) { + rewriter.replaceOp(op, {llvmAddr}); + return mlir::success(); + } + + offsets.push_back(adaptor.getOffset()); + eltType = mlir::IntegerType::get(resultType.getContext(), 8, + mlir::IntegerType::Signless); + } else { + llvmAddr = getValueForVTableSymbol(op, rewriter, getTypeConverter(), + op.getNameAttr(), eltType); + assert(eltType && "Shouldn't ever be missing an eltType here"); + offsets.push_back(0); + offsets.push_back(adaptor.getOffset()); + } + rewriter.replaceOpWithNewOp( + op, resultType, eltType, llvmAddr, offsets, + mlir::LLVM::GEPNoWrapFlags::inbounds); + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( cir::StackSaveOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 7b109c5cef9d3..513ad37839f1b 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -508,6 +508,16 @@ class CIRToLLVMVTableGetVirtualFnAddrOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVTTAddrPointOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VTTAddrPointOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMStackSaveOpLowering : public mlir::OpConversionPattern { public: diff --git a/clang/test/CIR/IR/invalid-vtable.cir b/clang/test/CIR/IR/invalid-vtable.cir index 41ddd4c2419be..2e880169abbfa 100644 --- a/clang/test/CIR/IR/invalid-vtable.cir +++ b/clang/test/CIR/IR/invalid-vtable.cir @@ -1,6 +1,5 @@ // RUN: cir-opt %s -verify-diagnostics -split-input-file -!s8i = !cir.int !u32i = !cir.int cir.func @reference_unknown_vtable() { // expected-error @below {{'cir.vtable.address_point' op 'some_vtable' does not reference a valid cir.global}} @@ -13,7 +12,7 @@ cir.func @reference_unknown_vtable() { !u8i = !cir.int !u32i = !cir.int cir.global linkonce_odr @_ZTT1D = #cir.const_array<[#cir.global_view<@_ZTV1D, [0 : i32, 3 : i32]> : !cir.ptr, #cir.global_view<@_ZTC1D0_1B, [0 : i32, 3 : i32]> : !cir.ptr]> : !cir.array x 2> -cir.func @reference_unknown_vtable() { +cir.func @reference_non_vtable() { // expected-error @below {{Expected #cir.vtable in initializer for global '_ZTT1D'}} %0 = cir.vtable.address_point(@_ZTT1D, address_point = ) : !cir.vptr cir.return @@ -82,3 +81,54 @@ module { cir.func private dso_local @_ZN1S6nonKeyEv(%arg0: !cir.ptr) cir.func private dso_local @_ZN2S23keyEv(%arg0: !cir.ptr) } + +// ----- + +!u32i = !cir.int +!void = !cir.void +cir.func @reference_unknown_vtt() { + // expected-error @below {{'cir.vtt.address_point' op 'some_vtt' does not reference a valid cir.global}} + %0 = cir.vtt.address_point @some_vtt, offset = 1 -> !cir.ptr> + cir.return +} + +// ----- + +!u8i = !cir.int +!u32i = !cir.int +!void = !cir.void +!rec_anon_struct = !cir.record x 4>}> +cir.global external @_ZTV1S = #cir.vtable<{#cir.const_array<[#cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr, #cir.global_view<@_ZN1S3keyEv> : !cir.ptr, #cir.global_view<@_ZN1S6nonKeyEv> : !cir.ptr]> : !cir.array x 4>}> : !rec_anon_struct {alignment = 8 : i64} +cir.func @reference_non_vtt() { + // expected-error @below {{'cir.vtt.address_point' op Expected constant array in initializer for global VTT '_ZTV1S'}} + %0 = cir.vtt.address_point @_ZTV1S, offset = 1 -> !cir.ptr> + cir.return +} + +// ----- + +!u8i = !cir.int +!u32i = !cir.int +!void = !cir.void +!rec_anon_struct = !cir.record x 4>}> +!rec_C = !cir.record +cir.global linkonce_odr @_ZTT1C = #cir.const_array<[#cir.global_view<@_ZTV1C, [0 : i32, 3 : i32]> : !cir.ptr, #cir.global_view<@_ZTC1C0_1B, [0 : i32, 3 : i32]> : !cir.ptr]> : !cir.array x 2> {alignment = 8 : i64} +cir.func @reference_name_and_value(%arg0: !cir.ptr, %arg1: !cir.ptr>) { + // expected-error @below {{'cir.vtt.address_point' op should use either a symbol or value, but not both}} + %0 = cir.vtt.address_point @_ZTT1C %arg1 : !cir.ptr>, offset = 1 -> !cir.ptr> + cir.return +} + +// ----- + +!u8i = !cir.int +!u32i = !cir.int +!void = !cir.void +!rec_anon_struct = !cir.record x 4>}> +!rec_C = !cir.record +cir.global linkonce_odr @_ZTT1C = #cir.const_array<[#cir.global_view<@_ZTV1C, [0 : i32, 3 : i32]> : !cir.ptr, #cir.global_view<@_ZTC1C0_1B, [0 : i32, 3 : i32]> : !cir.ptr]> : !cir.array x 2> {alignment = 8 : i64} +cir.func @bad_return_type_for_vtt_addrpoint() { + // expected-error @below {{result type must be '!cir.ptr>', but provided result type is '!cir.ptr>'}} + %0 = cir.vtt.address_point @_ZTT1C, offset = 1 -> !cir.ptr + cir.return +} diff --git a/clang/test/CIR/IR/vtt-addrpoint.cir b/clang/test/CIR/IR/vtt-addrpoint.cir new file mode 100644 index 0000000000000..f05bb782c6911 --- /dev/null +++ b/clang/test/CIR/IR/vtt-addrpoint.cir @@ -0,0 +1,55 @@ +// RUN: cir-opt %s | FileCheck %s + +// Test the parsing and printing of the two forms of vtt.address_point op, as +// they will appear in constructors. + +!u8i = !cir.int +!void = !cir.void +!rec_A = !cir.record +!rec_B = !cir.record +!rec_C = !cir.record +!rec_anon_struct = !cir.record x 3>}> +module { + cir.func private @_ZN1AC2Ev(!cir.ptr) + cir.func private @_ZN1BC2Ev(!cir.ptr, !cir.ptr>) + cir.func dso_local @_ZN1CC2Ev(%arg0: !cir.ptr, %arg1: !cir.ptr>) { + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["this", init] {alignment = 8 : i64} + %1 = cir.alloca !cir.ptr>, !cir.ptr>>, ["vtt", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.ptr, !cir.ptr> + cir.store %arg1, %1 : !cir.ptr>, !cir.ptr>> + %2 = cir.load %0 : !cir.ptr>, !cir.ptr + %3 = cir.load align(8) %1 : !cir.ptr>>, !cir.ptr> + %4 = cir.base_class_addr %2 : !cir.ptr nonnull [0] -> !cir.ptr + + %5 = cir.vtt.address_point %3 : !cir.ptr>, offset = 1 -> !cir.ptr> + // CHECK: cir.vtt.address_point %{{.*}} : !cir.ptr>, offset = 1 -> !cir.ptr> + + cir.call @_ZN1BC2Ev(%4, %5) : (!cir.ptr, !cir.ptr>) -> () + %6 = cir.vtt.address_point %3 : !cir.ptr>, offset = 0 -> !cir.ptr> + %7 = cir.cast(bitcast, %6 : !cir.ptr>), !cir.ptr + %8 = cir.load align(8) %7 : !cir.ptr, !cir.vptr + %9 = cir.vtable.get_vptr %2 : !cir.ptr -> !cir.ptr + cir.store align(8) %8, %9 : !cir.vptr, !cir.ptr + cir.return + } + cir.global linkonce_odr dso_local @_ZTV1C = #cir.vtable<{#cir.const_array<[#cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr]> : !cir.array x 3>}> : !rec_anon_struct {alignment = 8 : i64} + cir.global linkonce_odr @_ZTT1C = #cir.const_array<[#cir.global_view<@_ZTV1C, [0 : i32, 3 : i32]> : !cir.ptr, #cir.global_view<@_ZTC1C0_1B, [0 : i32, 3 : i32]> : !cir.ptr]> : !cir.array x 2> {alignment = 8 : i64} + cir.func dso_local @_ZN1CC1Ev(%arg0: !cir.ptr) { + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["this", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.ptr, !cir.ptr> + %1 = cir.load %0 : !cir.ptr>, !cir.ptr + %2 = cir.base_class_addr %1 : !cir.ptr nonnull [0] -> !cir.ptr + cir.call @_ZN1AC2Ev(%2) : (!cir.ptr) -> () + %3 = cir.base_class_addr %1 : !cir.ptr nonnull [0] -> !cir.ptr + + %4 = cir.vtt.address_point @_ZTT1C, offset = 1 -> !cir.ptr> + // CHECK: cir.vtt.address_point @_ZTT1C, offset = 1 -> !cir.ptr> + + cir.call @_ZN1BC2Ev(%3, %4) : (!cir.ptr, !cir.ptr>) -> () + %5 = cir.vtable.address_point(@_ZTV1C, address_point = ) : !cir.vptr + %6 = cir.vtable.get_vptr %1 : !cir.ptr -> !cir.ptr + cir.store align(8) %5, %6 : !cir.vptr, !cir.ptr + cir.return + } + cir.global linkonce_odr dso_local @_ZTC1C0_1B = #cir.const_record<{#cir.const_array<[#cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr]> : !cir.array x 3>}> : !rec_anon_struct {alignment = 8 : i64} +} diff --git a/clang/test/CIR/Lowering/vtt-addrpoint.cir b/clang/test/CIR/Lowering/vtt-addrpoint.cir new file mode 100644 index 0000000000000..a3e7271f7446e --- /dev/null +++ b/clang/test/CIR/Lowering/vtt-addrpoint.cir @@ -0,0 +1,59 @@ +// RUN: cir-translate %s -cir-to-llvmir --target x86_64-unknown-linux-gnu -o %t.ll +// RUN: FileCheck %s --input-file=%t.ll + +// Test the lowering of the two forms of vtt.address_point op, as they will +// appear in constructors. + +!u8i = !cir.int +!void = !cir.void +!rec_A = !cir.record +!rec_B = !cir.record +!rec_C = !cir.record +!rec_anon_struct = !cir.record x 3>}> +module { + cir.func private @_ZN1AC2Ev(!cir.ptr) + cir.func private @_ZN1BC2Ev(!cir.ptr, !cir.ptr>) + cir.func dso_local @_ZN1CC2Ev(%arg0: !cir.ptr, %arg1: !cir.ptr>) { + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["this", init] {alignment = 8 : i64} + %1 = cir.alloca !cir.ptr>, !cir.ptr>>, ["vtt", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.ptr, !cir.ptr> + cir.store %arg1, %1 : !cir.ptr>, !cir.ptr>> + %2 = cir.load %0 : !cir.ptr>, !cir.ptr + %3 = cir.load align(8) %1 : !cir.ptr>>, !cir.ptr> + %4 = cir.base_class_addr %2 : !cir.ptr nonnull [0] -> !cir.ptr + %5 = cir.vtt.address_point %3 : !cir.ptr>, offset = 1 -> !cir.ptr> + cir.call @_ZN1BC2Ev(%4, %5) : (!cir.ptr, !cir.ptr>) -> () + %6 = cir.vtt.address_point %3 : !cir.ptr>, offset = 0 -> !cir.ptr> + %7 = cir.cast(bitcast, %6 : !cir.ptr>), !cir.ptr + %8 = cir.load align(8) %7 : !cir.ptr, !cir.vptr + %9 = cir.vtable.get_vptr %2 : !cir.ptr -> !cir.ptr + cir.store align(8) %8, %9 : !cir.vptr, !cir.ptr + cir.return + } + +// CHECK: define{{.*}} void @_ZN1CC2Ev +// CHECK: %[[VTT:.*]] = getelementptr inbounds i8, ptr %{{.*}}, i32 1 +// CHECK: call void @_ZN1BC2Ev(ptr %{{.*}}, ptr %[[VTT]]) + + cir.global linkonce_odr dso_local @_ZTV1C = #cir.vtable<{#cir.const_array<[#cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr]> : !cir.array x 3>}> : !rec_anon_struct {alignment = 8 : i64} + cir.global linkonce_odr @_ZTT1C = #cir.const_array<[#cir.global_view<@_ZTV1C, [0 : i32, 3 : i32]> : !cir.ptr, #cir.global_view<@_ZTC1C0_1B, [0 : i32, 3 : i32]> : !cir.ptr]> : !cir.array x 2> {alignment = 8 : i64} + cir.func dso_local @_ZN1CC1Ev(%arg0: !cir.ptr) { + %0 = cir.alloca !cir.ptr, !cir.ptr>, ["this", init] {alignment = 8 : i64} + cir.store %arg0, %0 : !cir.ptr, !cir.ptr> + %1 = cir.load %0 : !cir.ptr>, !cir.ptr + %2 = cir.base_class_addr %1 : !cir.ptr nonnull [0] -> !cir.ptr + cir.call @_ZN1AC2Ev(%2) : (!cir.ptr) -> () + %3 = cir.base_class_addr %1 : !cir.ptr nonnull [0] -> !cir.ptr + %4 = cir.vtt.address_point @_ZTT1C, offset = 1 -> !cir.ptr> + cir.call @_ZN1BC2Ev(%3, %4) : (!cir.ptr, !cir.ptr>) -> () + %5 = cir.vtable.address_point(@_ZTV1C, address_point = ) : !cir.vptr + %6 = cir.vtable.get_vptr %1 : !cir.ptr -> !cir.ptr + cir.store align(8) %5, %6 : !cir.vptr, !cir.ptr + cir.return + } + +// CHECK: define{{.*}} void @_ZN1CC1Ev +// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @_ZTV1C, i64 24), ptr %{{.*}} + + cir.global linkonce_odr dso_local @_ZTC1C0_1B = #cir.const_record<{#cir.const_array<[#cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr, #cir.ptr : !cir.ptr]> : !cir.array x 3>}> : !rec_anon_struct {alignment = 8 : i64} +}