Skip to content

Commit 2ab0704

Browse files
authored
[CIR] Base-to-derived and derived-to-base casts on pointers to member functions (#1424)
This PR adds CIRGen and LLVM lowering support for base-to-derived and derived-to-base cast operations on pointers to member functions. This PR includes a new operation `cir.update_member` to help the LLVM lowering procedure of such cast operations. Resolve #973 .
1 parent 0bedc28 commit 2ab0704

File tree

9 files changed

+394
-10
lines changed

9 files changed

+394
-10
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,6 +2930,67 @@ def ExtractMemberOp : CIR_Op<"extract_member", [Pure]> {
29302930
let hasVerifier = 1;
29312931
}
29322932

2933+
//===----------------------------------------------------------------------===//
2934+
// InsertMemberOp
2935+
//===----------------------------------------------------------------------===//
2936+
2937+
def InsertMemberOp : CIR_Op<"insert_member",
2938+
[Pure, AllTypesMatch<["record", "result"]>]> {
2939+
let summary = "Overwrite the value of a member of a struct value";
2940+
let description = [{
2941+
The `cir.insert_member` operation overwrites the value of a particular
2942+
member in the input struct value, and returns the modified struct value. The
2943+
result of this operation is equal to the input struct value, except for the
2944+
member specified by `index_attr` whose value is equal to the given value.
2945+
2946+
This operation is named after the LLVM instruction `insertvalue`.
2947+
2948+
Currently `cir.insert_member` does not work on unions.
2949+
2950+
Example:
2951+
2952+
```mlir
2953+
// Suppose we have a struct with multiple members.
2954+
!s32i = !cir.int<s, 32>
2955+
!s8i = !cir.int<s, 32>
2956+
!struct_ty = !cir.struct<"struct.Bar" {!s32i, !s8i}>
2957+
2958+
// And suppose we have a value of the struct type.
2959+
%0 = cir.const #cir.const_struct<{#cir.int<1> : !s32i, #cir.int<2> : !s8i}> : !struct_ty
2960+
// %0 is {1, 2}
2961+
2962+
// Overwrite the second member of the struct value.
2963+
%1 = cir.const #cir.int<3> : !s8i
2964+
%2 = cir.insert_member %0[1], %1 : !struct_ty, !s8i
2965+
// %2 is {1, 3}
2966+
```
2967+
}];
2968+
2969+
let arguments = (ins CIR_StructType:$record, IndexAttr:$index_attr,
2970+
CIR_AnyType:$value);
2971+
let results = (outs CIR_StructType:$result);
2972+
2973+
let builders = [
2974+
OpBuilder<(ins "mlir::Value":$record, "uint64_t":$index,
2975+
"mlir::Value":$value), [{
2976+
mlir::APInt fieldIdx(64, index);
2977+
build($_builder, $_state, record, fieldIdx, value);
2978+
}]>
2979+
];
2980+
2981+
let extraClassDeclaration = [{
2982+
/// Get the index of the struct member being accessed.
2983+
uint64_t getIndex() { return getIndexAttr().getZExtValue(); }
2984+
}];
2985+
2986+
let assemblyFormat = [{
2987+
$record `[` $index_attr `]` `,` $value attr-dict
2988+
`:` qualified(type($record)) `,` qualified(type($value))
2989+
}];
2990+
2991+
let hasVerifier = 1;
2992+
}
2993+
29332994
//===----------------------------------------------------------------------===//
29342995
// GetRuntimeMemberOp
29352996
//===----------------------------------------------------------------------===//
@@ -3421,6 +3482,74 @@ def DerivedDataMemberOp : CIR_Op<"derived_data_member", [Pure]> {
34213482
let hasVerifier = 1;
34223483
}
34233484

3485+
//===----------------------------------------------------------------------===//
3486+
// BaseMethodOp & DerivedMethodOp
3487+
//===----------------------------------------------------------------------===//
3488+
3489+
def BaseMethodOp : CIR_Op<"base_method", [Pure]> {
3490+
let summary = [{
3491+
Cast a derived class pointer-to-member-function to a base class
3492+
pointer-to-member-function
3493+
}];
3494+
let description = [{
3495+
The `cir.base_method` operation casts a pointer-to-member-function of type
3496+
`Ret (Derived::*)(Args)` to a pointer-to-member-function of type
3497+
`Ret (Base::*)(Args)`, where `Base` is a non-virtual base class of
3498+
`Derived`.
3499+
3500+
The `offset` parameter gives the offset in bytes of the `Base` base class
3501+
subobject within a `Derived` object.
3502+
3503+
Example:
3504+
3505+
```mlir
3506+
%1 = cir.base_method(%0 : !cir.method<!cir.func<(!s32i)> in !ty_Derived>) [16] -> !cir.method<!cir.func<(!s32i)> in !ty_Base>
3507+
```
3508+
}];
3509+
3510+
let arguments = (ins CIR_MethodType:$src, IndexAttr:$offset);
3511+
let results = (outs CIR_MethodType:$result);
3512+
3513+
let assemblyFormat = [{
3514+
`(` $src `:` qualified(type($src)) `)`
3515+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3516+
}];
3517+
3518+
let hasVerifier = 1;
3519+
}
3520+
3521+
def DerivedMethodOp : CIR_Op<"derived_method", [Pure]> {
3522+
let summary = [{
3523+
Cast a base class pointer-to-member-function to a derived class
3524+
pointer-to-member-function
3525+
}];
3526+
let description = [{
3527+
The `cir.derived_method` operation casts a pointer-to-member-function of
3528+
type `Ret (Base::*)(Args)` to a pointer-to-member-function of type
3529+
`Ret (Derived::*)(Args)`, where `Base` is a non-virtual base class of
3530+
`Derived`.
3531+
3532+
The `offset` parameter gives the offset in bytes of the `Base` base class
3533+
subobject within a `Derived` object.
3534+
3535+
Example:
3536+
3537+
```mlir
3538+
%1 = cir.derived_method(%0 : !cir.method<!cir.func<(!s32i)> in !ty_Base>) [16] -> !cir.method<!cir.func<(!s32i)> in !ty_Derived>
3539+
```
3540+
}];
3541+
3542+
let arguments = (ins CIR_MethodType:$src, IndexAttr:$offset);
3543+
let results = (outs CIR_MethodType:$result);
3544+
3545+
let assemblyFormat = [{
3546+
`(` $src `:` qualified(type($src)) `)`
3547+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3548+
}];
3549+
3550+
let hasVerifier = 1;
3551+
}
3552+
34243553
//===----------------------------------------------------------------------===//
34253554
// FuncOp
34263555
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct MissingFeatures {
7070
static bool tbaaPointer() { return false; }
7171
static bool emitNullabilityCheck() { return false; }
7272
static bool ptrAuth() { return false; }
73+
static bool memberFuncPtrAuthInfo() { return false; }
7374
static bool emitCFICheck() { return false; }
7475
static bool emitVFEInfo() { return false; }
7576
static bool emitWPDInfo() { return false; }

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,9 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17551755
case CK_DerivedToBaseMemberPointer: {
17561756
mlir::Value src = Visit(E);
17571757

1758+
if (E->getType()->isMemberFunctionPointerType())
1759+
assert(!cir::MissingFeatures::memberFuncPtrAuthInfo());
1760+
17581761
QualType derivedTy =
17591762
Kind == CK_DerivedToBaseMemberPointer ? E->getType() : CE->getType();
17601763
const CXXRecordDecl *derivedClass = derivedTy->castAs<MemberPointerType>()
@@ -1763,13 +1766,17 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17631766
CharUnits offset = CGF.CGM.computeNonVirtualBaseClassOffset(
17641767
derivedClass, CE->path_begin(), CE->path_end());
17651768

1766-
if (E->getType()->isMemberFunctionPointerType())
1767-
llvm_unreachable("NYI");
1768-
17691769
mlir::Location loc = CGF.getLoc(E->getExprLoc());
17701770
mlir::Type resultTy = CGF.convertType(DestTy);
17711771
mlir::IntegerAttr offsetAttr = Builder.getIndexAttr(offset.getQuantity());
17721772

1773+
if (E->getType()->isMemberFunctionPointerType()) {
1774+
if (Kind == CK_BaseToDerivedMemberPointer)
1775+
return Builder.create<cir::DerivedMethodOp>(loc, resultTy, src,
1776+
offsetAttr);
1777+
return Builder.create<cir::BaseMethodOp>(loc, resultTy, src, offsetAttr);
1778+
}
1779+
17731780
if (Kind == CK_BaseToDerivedMemberPointer)
17741781
return Builder.create<cir::DerivedDataMemberOp>(loc, resultTy, src,
17751782
offsetAttr);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -841,13 +841,21 @@ LogicalResult cir::DynamicCastOp::verify() {
841841
// BaseDataMemberOp & DerivedDataMemberOp
842842
//===----------------------------------------------------------------------===//
843843

844-
static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
845-
mlir::Type resultTy) {
844+
static LogicalResult verifyMemberPtrCast(Operation *op, mlir::Value src,
845+
mlir::Type resultTy) {
846846
// Let the operand type be T1 C1::*, let the result type be T2 C2::*.
847847
// Verify that T1 and T2 are the same type.
848-
auto inputMemberTy =
849-
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
850-
auto resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
848+
mlir::Type inputMemberTy;
849+
mlir::Type resultMemberTy;
850+
if (mlir::isa<cir::DataMemberType>(src.getType())) {
851+
inputMemberTy =
852+
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
853+
resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
854+
} else {
855+
inputMemberTy =
856+
mlir::cast<cir::MethodType>(src.getType()).getMemberFuncTy();
857+
resultMemberTy = mlir::cast<cir::MethodType>(resultTy).getMemberFuncTy();
858+
}
851859
if (inputMemberTy != resultMemberTy)
852860
return op->emitOpError()
853861
<< "member types of the operand and the result do not match";
@@ -856,11 +864,23 @@ static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
856864
}
857865

858866
LogicalResult cir::BaseDataMemberOp::verify() {
859-
return verifyDataMemberCast(getOperation(), getSrc(), getType());
867+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
860868
}
861869

862870
LogicalResult cir::DerivedDataMemberOp::verify() {
863-
return verifyDataMemberCast(getOperation(), getSrc(), getType());
871+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
872+
}
873+
874+
//===----------------------------------------------------------------------===//
875+
// BaseMethodOp & DerivedMethodOp
876+
//===----------------------------------------------------------------------===//
877+
878+
LogicalResult cir::BaseMethodOp::verify() {
879+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
880+
}
881+
882+
LogicalResult cir::DerivedMethodOp::verify() {
883+
return verifyMemberPtrCast(getOperation(), getSrc(), getType());
864884
}
865885

866886
//===----------------------------------------------------------------------===//
@@ -3580,6 +3600,22 @@ LogicalResult cir::ExtractMemberOp::verify() {
35803600
return mlir::success();
35813601
}
35823602

3603+
//===----------------------------------------------------------------------===//
3604+
// InsertMemberOp Definitions
3605+
//===----------------------------------------------------------------------===//
3606+
3607+
LogicalResult cir::InsertMemberOp::verify() {
3608+
auto recordTy = mlir::cast<cir::StructType>(getRecord().getType());
3609+
if (recordTy.getKind() == cir::StructType::Union)
3610+
return emitError() << "cir.update_member currently does not work on unions";
3611+
if (recordTy.getMembers().size() <= getIndex())
3612+
return emitError() << "member index out of range";
3613+
if (recordTy.getMembers()[getIndex()] != getValue().getType())
3614+
return emitError() << "member type mismatch";
3615+
// The op trait already checks that the types of $result and $record match.
3616+
return mlir::success();
3617+
}
3618+
35833619
//===----------------------------------------------------------------------===//
35843620
// GetRuntimeMemberOp Definitions
35853621
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,18 @@ class CIRCXXABI {
118118
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
119119
mlir::OpBuilder &builder) const = 0;
120120

121+
/// Lower the given cir.base_method op to a sequence of more "primitive" CIR
122+
/// operations that act on the ABI types.
123+
virtual mlir::Value lowerBaseMethod(cir::BaseMethodOp op,
124+
mlir::Value loweredSrc,
125+
mlir::OpBuilder &builder) const = 0;
126+
127+
/// Lower the given cir.derived_method op to a sequence of more "primitive"
128+
/// CIR operations that act on the ABI types.
129+
virtual mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
130+
mlir::Value loweredSrc,
131+
mlir::OpBuilder &builder) const = 0;
132+
121133
virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
122134
mlir::Value loweredRhs,
123135
mlir::OpBuilder &builder) const = 0;

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ class ItaniumCXXABI : public CIRCXXABI {
9999
mlir::Value loweredSrc,
100100
mlir::OpBuilder &builder) const override;
101101

102+
mlir::Value lowerBaseMethod(cir::BaseMethodOp op, mlir::Value loweredSrc,
103+
mlir::OpBuilder &builder) const override;
104+
105+
mlir::Value lowerDerivedMethod(cir::DerivedMethodOp op,
106+
mlir::Value loweredSrc,
107+
mlir::OpBuilder &builder) const override;
108+
102109
mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
103110
mlir::Value loweredRhs,
104111
mlir::OpBuilder &builder) const override;
@@ -466,6 +473,27 @@ static mlir::Value lowerDataMemberCast(mlir::Operation *op,
466473
isNull, nullValue, adjustedPtr);
467474
}
468475

476+
static mlir::Value lowerMethodCast(mlir::Operation *op, mlir::Value loweredSrc,
477+
std::int64_t offset, bool isDerivedToBase,
478+
LowerModule &lowerMod,
479+
mlir::OpBuilder &builder) {
480+
if (offset == 0)
481+
return loweredSrc;
482+
483+
cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lowerMod);
484+
auto adjField = builder.create<cir::ExtractMemberOp>(
485+
op->getLoc(), ptrdiffCIRTy, loweredSrc, 1);
486+
487+
auto offsetValue = builder.create<cir::ConstantOp>(
488+
op->getLoc(), cir::IntAttr::get(ptrdiffCIRTy, offset));
489+
auto binOpKind = isDerivedToBase ? cir::BinOpKind::Sub : cir::BinOpKind::Add;
490+
auto adjustedAdjField = builder.create<cir::BinOp>(
491+
op->getLoc(), ptrdiffCIRTy, binOpKind, adjField, offsetValue);
492+
493+
return builder.create<cir::InsertMemberOp>(op->getLoc(), loweredSrc, 1,
494+
adjustedAdjField);
495+
}
496+
469497
mlir::Value ItaniumCXXABI::lowerBaseDataMember(cir::BaseDataMemberOp op,
470498
mlir::Value loweredSrc,
471499
mlir::OpBuilder &builder) const {
@@ -481,6 +509,20 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
481509
/*isDerivedToBase=*/false, builder);
482510
}
483511

512+
mlir::Value ItaniumCXXABI::lowerBaseMethod(cir::BaseMethodOp op,
513+
mlir::Value loweredSrc,
514+
mlir::OpBuilder &builder) const {
515+
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
516+
/*isDerivedToBase=*/true, LM, builder);
517+
}
518+
519+
mlir::Value ItaniumCXXABI::lowerDerivedMethod(cir::DerivedMethodOp op,
520+
mlir::Value loweredSrc,
521+
mlir::OpBuilder &builder) const {
522+
return lowerMethodCast(op, loweredSrc, op.getOffset().getSExtValue(),
523+
/*isDerivedToBase=*/false, LM, builder);
524+
}
525+
484526
mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
485527
mlir::Value loweredLhs,
486528
mlir::Value loweredRhs,

0 commit comments

Comments
 (0)