Skip to content

Commit aff6245

Browse files
authored
[CIR] Add support for comparisons between pointers to member functions (#1390)
The CIRGen support is already there. This PR adds LLVM lowering support for comparisons between pointers to member functions. Note that pointers to member functions could only be compared for equality.
1 parent dc932de commit aff6245

File tree

4 files changed

+113
-3
lines changed

4 files changed

+113
-3
lines changed

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ class CIRCXXABI {
122122
mlir::Value loweredRhs,
123123
mlir::OpBuilder &builder) const = 0;
124124

125+
virtual mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs,
126+
mlir::Value loweredRhs,
127+
mlir::OpBuilder &builder) const = 0;
128+
125129
virtual mlir::Value
126130
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
127131
mlir::Value loweredSrc,

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class ItaniumCXXABI : public CIRCXXABI {
103103
mlir::Value loweredRhs,
104104
mlir::OpBuilder &builder) const override;
105105

106+
mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs,
107+
mlir::Value loweredRhs,
108+
mlir::OpBuilder &builder) const override;
109+
106110
mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
107111
mlir::Value loweredSrc,
108112
mlir::OpBuilder &builder) const override;
@@ -478,6 +482,61 @@ mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
478482
loweredRhs);
479483
}
480484

485+
mlir::Value ItaniumCXXABI::lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs,
486+
mlir::Value loweredRhs,
487+
mlir::OpBuilder &builder) const {
488+
assert(op.getKind() == cir::CmpOpKind::eq ||
489+
op.getKind() == cir::CmpOpKind::ne);
490+
491+
cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(LM);
492+
mlir::Value ptrdiffZero = builder.create<cir::ConstantOp>(
493+
op.getLoc(), ptrdiffCIRTy, cir::IntAttr::get(ptrdiffCIRTy, 0));
494+
495+
mlir::Value lhsPtrField = builder.create<cir::ExtractMemberOp>(
496+
op.getLoc(), ptrdiffCIRTy, loweredLhs, 0);
497+
mlir::Value rhsPtrField = builder.create<cir::ExtractMemberOp>(
498+
op.getLoc(), ptrdiffCIRTy, loweredRhs, 0);
499+
mlir::Value ptrCmp = builder.create<cir::CmpOp>(op.getLoc(), op.getKind(),
500+
lhsPtrField, rhsPtrField);
501+
mlir::Value ptrCmpToNull = builder.create<cir::CmpOp>(
502+
op.getLoc(), op.getKind(), lhsPtrField, ptrdiffZero);
503+
504+
mlir::Value lhsAdjField = builder.create<cir::ExtractMemberOp>(
505+
op.getLoc(), ptrdiffCIRTy, loweredLhs, 1);
506+
mlir::Value rhsAdjField = builder.create<cir::ExtractMemberOp>(
507+
op.getLoc(), ptrdiffCIRTy, loweredRhs, 1);
508+
mlir::Value adjCmp = builder.create<cir::CmpOp>(op.getLoc(), op.getKind(),
509+
lhsAdjField, rhsAdjField);
510+
511+
// We use cir.select to represent "||" and "&&" operations below:
512+
// - cir.select if %a then %b else false => %a && %b
513+
// - cir.select if %a then true else %b => %a || %b
514+
// TODO: Do we need to invent dedicated "cir.logical_or" and "cir.logical_and"
515+
// operations for this?
516+
auto boolTy = cir::BoolType::get(op.getContext());
517+
mlir::Value trueValue = builder.create<cir::ConstantOp>(
518+
op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, true));
519+
mlir::Value falseValue = builder.create<cir::ConstantOp>(
520+
op.getLoc(), boolTy, cir::BoolAttr::get(op.getContext(), boolTy, false));
521+
auto create_and = [&](mlir::Value lhs, mlir::Value rhs) {
522+
return builder.create<cir::SelectOp>(op.getLoc(), lhs, rhs, falseValue);
523+
};
524+
auto create_or = [&](mlir::Value lhs, mlir::Value rhs) {
525+
return builder.create<cir::SelectOp>(op.getLoc(), lhs, trueValue, rhs);
526+
};
527+
528+
mlir::Value result;
529+
if (op.getKind() == cir::CmpOpKind::eq) {
530+
// (lhs.ptr == null || lhs.adj == rhs.adj) && lhs.ptr == rhs.ptr
531+
result = create_and(create_or(ptrCmpToNull, adjCmp), ptrCmp);
532+
} else {
533+
// (lhs.ptr != null && lhs.adj != rhs.adj) || lhs.ptr != rhs.ptr
534+
result = create_or(create_and(ptrCmpToNull, adjCmp), ptrCmp);
535+
}
536+
537+
return result;
538+
}
539+
481540
mlir::Value
482541
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
483542
mlir::Value loweredSrc,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2893,10 +2893,17 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
28932893
mlir::ConversionPatternRewriter &rewriter) const {
28942894
auto type = cmpOp.getLhs().getType();
28952895

2896-
if (mlir::isa<cir::DataMemberType>(type)) {
2896+
if (mlir::isa<cir::DataMemberType, cir::MethodType>(type)) {
28972897
assert(lowerMod && "lowering module is not available");
2898-
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
2899-
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
2898+
2899+
mlir::Value loweredResult;
2900+
if (mlir::isa<cir::DataMemberType>(type))
2901+
loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
2902+
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
2903+
else
2904+
loweredResult = lowerMod->getCXXABI().lowerMethodCmp(
2905+
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
2906+
29002907
rewriter.replaceOp(cmpOp, loweredResult);
29012908
return mlir::success();
29022909
}

clang/test/CIR/CodeGen/pointer-to-member-func.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,43 @@ void call(Foo *obj, void (Foo::*func)(int), int arg) {
7878
// LLVM-NEXT: %[[#arg:]] = load i32, ptr %{{.+}}
7979
// LLVM-NEXT: call void %[[#callee_ptr]](ptr %[[#adjusted_this]], i32 %[[#arg]])
8080
// LLVM: }
81+
82+
bool cmp_eq(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) {
83+
return lhs == rhs;
84+
}
85+
86+
// CHECK-LABEL: @_Z6cmp_eqM3FooFviES1_
87+
// CHECK: %{{.+}} = cir.cmp(eq, %{{.+}}, %{{.+}}) : !cir.method<!cir.func<(!s32i)> in !ty_Foo>, !cir.bool
88+
89+
// LLVM-LABEL: @_Z6cmp_eqM3FooFviES1_
90+
// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}}
91+
// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}}
92+
// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0
93+
// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0
94+
// LLVM-NEXT: %[[#ptr_cmp:]] = icmp eq i64 %[[#lhs_ptr]], %[[#rhs_ptr]]
95+
// LLVM-NEXT: %[[#ptr_null:]] = icmp eq i64 %[[#lhs_ptr]], 0
96+
// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1
97+
// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1
98+
// LLVM-NEXT: %[[#adj_cmp:]] = icmp eq i64 %[[#lhs_adj]], %[[#rhs_adj]]
99+
// LLVM-NEXT: %[[#tmp:]] = or i1 %[[#ptr_null]], %[[#adj_cmp]]
100+
// LLVM-NEXT: %{{.+}} = and i1 %[[#tmp]], %[[#ptr_cmp]]
101+
102+
bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) {
103+
return lhs != rhs;
104+
}
105+
106+
// CHECK-LABEL: @_Z6cmp_neM3FooFviES1_
107+
// CHECK: %{{.+}} = cir.cmp(ne, %{{.+}}, %{{.+}}) : !cir.method<!cir.func<(!s32i)> in !ty_Foo>, !cir.bool
108+
109+
// LLVM-LABEL: @_Z6cmp_neM3FooFviES1_
110+
// LLVM: %[[#lhs:]] = load { i64, i64 }, ptr %{{.+}}
111+
// LLVM-NEXT: %[[#rhs:]] = load { i64, i64 }, ptr %{{.+}}
112+
// LLVM-NEXT: %[[#lhs_ptr:]] = extractvalue { i64, i64 } %[[#lhs]], 0
113+
// LLVM-NEXT: %[[#rhs_ptr:]] = extractvalue { i64, i64 } %[[#rhs]], 0
114+
// LLVM-NEXT: %[[#ptr_cmp:]] = icmp ne i64 %[[#lhs_ptr]], %[[#rhs_ptr]]
115+
// LLVM-NEXT: %[[#ptr_null:]] = icmp ne i64 %[[#lhs_ptr]], 0
116+
// LLVM-NEXT: %[[#lhs_adj:]] = extractvalue { i64, i64 } %[[#lhs]], 1
117+
// LLVM-NEXT: %[[#rhs_adj:]] = extractvalue { i64, i64 } %[[#rhs]], 1
118+
// LLVM-NEXT: %[[#adj_cmp:]] = icmp ne i64 %[[#lhs_adj]], %[[#rhs_adj]]
119+
// LLVM-NEXT: %[[#tmp:]] = and i1 %[[#ptr_null]], %[[#adj_cmp]]
120+
// LLVM-NEXT: %{{.+}} = or i1 %[[#tmp]], %[[#ptr_cmp]]

0 commit comments

Comments
 (0)