21
21
#include " llvm/ADT/StringMap.h"
22
22
#include " llvm/ADT/StringRef.h"
23
23
#include " llvm/ADT/Twine.h"
24
+ #include " llvm/ADT/TypeSwitch.h"
24
25
#include " llvm/Support/ErrorHandling.h"
25
26
#include " llvm/Support/Path.h"
26
27
@@ -42,24 +43,66 @@ template <typename TargetOp> class StdRecognizer {
42
43
template <size_t ... Indices>
43
44
static TargetOp buildCall (CIRBaseBuilderTy &builder, CallOp call,
44
45
std::index_sequence<Indices...>) {
45
- return builder.create <TargetOp>(call.getLoc (), call.getResult ().getType (),
46
- call.getCalleeAttr (),
47
- call.getOperand (Indices)...);
46
+ return builder.create <TargetOp>(
47
+ call.getLoc (),
48
+ (call.getResult () ? call.getResult ().getType () : mlir::TypeRange{}),
49
+ call.getCalleeAttr (), call.getOperand (Indices)...);
48
50
}
49
51
50
52
public:
51
- static bool raise (CallOp call, mlir::MLIRContext &context, bool remark) {
53
+ static FuncOp getCalleeFromSymbol (mlir::ModuleOp theModule,
54
+ llvm::StringRef name) {
55
+ auto global = mlir::SymbolTable::lookupSymbolIn (theModule, name);
56
+ assert (global && " expected to find symbol for function" );
57
+ return dyn_cast<FuncOp>(global);
58
+ }
59
+
60
+ static std::optional<StringRef>
61
+ getRecordName (const clang::CXXRecordDecl *rd) {
62
+ if (!rd || !rd->getDeclContext ()->isStdNamespace ())
63
+ return std::nullopt ;
64
+
65
+ if (rd->getDeclName ().isIdentifier ())
66
+ return rd->getName ();
67
+
68
+ return std::nullopt ;
69
+ }
70
+
71
+ static std::optional<std::string>
72
+ resolveSpecialMember (mlir::Attribute specialMember) {
73
+ return TypeSwitch<Attribute, std::optional<std::string>>(specialMember)
74
+ .Case <CXXCtorAttr, CXXDtorAttr>(
75
+ [](auto attr) -> std::optional<std::string> {
76
+ if (!attr.getRecordDecl ())
77
+ return std::nullopt ;
78
+ if (auto recordName = getRecordName (*attr.getRecordDecl ()))
79
+ return recordName->str () + " _" + attr.getMnemonic ().str ();
80
+ return std::nullopt ;
81
+ })
82
+ .Default ([](Attribute) { return std::nullopt ; });
83
+ }
84
+
85
+ static bool raise (mlir::ModuleOp theModule, CallOp call,
86
+ mlir::MLIRContext &context, bool remark) {
52
87
constexpr int numArgs = TargetOp::getNumArgs ();
53
88
if (call.getNumOperands () != numArgs)
54
89
return false ;
55
90
56
- auto callExprAttr = call.getAstAttr ();
57
91
llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
58
- if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
59
- return false ;
60
-
61
- if (!checkArguments (call.getArgOperands ()))
62
- return false ;
92
+ auto calleeFunc = getCalleeFromSymbol (theModule, *call.getCallee ());
93
+
94
+ if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr ()) {
95
+ auto resolved = resolveSpecialMember (specialMember);
96
+ if (!resolved || *resolved != stdFuncName.str ())
97
+ return false ;
98
+ } else {
99
+ auto callExprAttr = call.getAstAttr ();
100
+ if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
101
+ return false ;
102
+
103
+ if (!checkArguments (call.getArgOperands ()))
104
+ return false ;
105
+ }
63
106
64
107
if (remark)
65
108
mlir::emitRemark (call.getLoc ())
@@ -194,12 +237,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
194
237
195
238
bool remark = opts.emitRemarkFoundCalls ();
196
239
197
- using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;
240
+ using StdFunctionsRecognizer =
241
+ std::tuple<StdRecognizer<StdFindOp>, StdRecognizer<StdVectorCtorOp>,
242
+ StdRecognizer<StdVectorDtorOp>>;
198
243
199
244
// MSVC requires explicitly capturing these variables.
200
245
std::apply (
201
246
[&, call, remark, this ](auto ... recognizers) {
202
- (decltype (recognizers)::raise (call, this ->getContext (), remark) || ...);
247
+ (decltype (recognizers)::raise (theModule, call, this ->getContext (),
248
+ remark) ||
249
+ ...);
203
250
},
204
251
StdFunctionsRecognizer ());
205
252
}
0 commit comments