@@ -42,28 +42,66 @@ template <typename TargetOp> class StdRecognizer {
42
42
template <size_t ... Indices>
43
43
static TargetOp buildCall (CIRBaseBuilderTy &builder, CallOp call,
44
44
std::index_sequence<Indices...>) {
45
- return builder.create <TargetOp>(call.getLoc (), call.getResult ().getType (),
46
- call.getCalleeAttr (),
47
- call.getOperand (Indices)...);
45
+ return builder.create <TargetOp>(
46
+ call.getLoc (),
47
+ (call.getResult () ? call.getResult ().getType () : mlir::TypeRange{}),
48
+ call.getCalleeAttr (), call.getOperand (Indices)...);
48
49
}
49
50
50
51
public:
51
- static bool raise (CallOp call, mlir::MLIRContext &context, bool remark) {
52
+ static FuncOp getCalleeFromSymbol (mlir::ModuleOp theModule,
53
+ llvm::StringRef name) {
54
+ auto global = mlir::SymbolTable::lookupSymbolIn (theModule, name);
55
+ assert (global && " expected to find symbol for function" );
56
+ return dyn_cast<FuncOp>(global);
57
+ }
58
+
59
+ static bool isStdVector (const clang::CXXRecordDecl *RD) {
60
+ if (!RD || !RD->getDeclContext ()->isStdNamespace ())
61
+ return false ;
62
+
63
+ if (RD->getDeclName ().isIdentifier ()) {
64
+ StringRef Name = RD->getName ();
65
+ return Name == " vector" ;
66
+ }
67
+
68
+ return false ;
69
+ }
70
+
71
+ static bool raise (mlir::ModuleOp theModule, CallOp call,
72
+ mlir::MLIRContext &context, bool remark) {
52
73
constexpr int numArgs = TargetOp::getNumArgs ();
53
74
if (call.getNumOperands () != numArgs)
54
75
return false ;
55
76
56
- auto callExprAttr = call.getAstAttr ();
57
- llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
58
- if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
59
- return false ;
77
+ llvm::StringRef name = *call.getCallee ();
78
+ auto calleeFunc = getCalleeFromSymbol (theModule, name);
60
79
61
- if (!checkArguments (call.getArgOperands ()))
62
- return false ;
80
+ llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
63
81
64
- if (remark)
65
- mlir::emitRemark (call.getLoc ())
66
- << " found call to std::" << stdFuncName << " ()" ;
82
+ if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr ()) {
83
+ auto matches =
84
+ (stdFuncName == " vector_ctor" && isa<CXXCtorAttr>(specialMember)) ||
85
+ (stdFuncName == " vector_dtor" && isa<CXXDtorAttr>(specialMember));
86
+ if (!matches)
87
+ return false ;
88
+
89
+ auto recordDeclAttr = call.getAstRecordAttr ();
90
+ if (!recordDeclAttr ||
91
+ !isStdVector (cast<clang::CXXRecordDecl>(recordDeclAttr.getRawDecl ())))
92
+ return false ;
93
+ } else {
94
+ auto callExprAttr = call.getAstAttr ();
95
+ if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
96
+ return false ;
97
+
98
+ if (!checkArguments (call.getArgOperands ()))
99
+ return false ;
100
+
101
+ if (remark)
102
+ mlir::emitRemark (call.getLoc ())
103
+ << " found call to std::" << stdFuncName << " ()" ;
104
+ }
67
105
68
106
CIRBaseBuilderTy builder (context);
69
107
builder.setInsertionPointAfter (call.getOperation ());
@@ -194,12 +232,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
194
232
195
233
bool remark = opts.emitRemarkFoundCalls ();
196
234
197
- using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;
235
+ using StdFunctionsRecognizer =
236
+ std::tuple<StdRecognizer<StdFindOp>, StdRecognizer<StdVectorCtorOp>,
237
+ StdRecognizer<StdVectorDtorOp>>;
198
238
199
239
// MSVC requires explicitly capturing these variables.
200
240
std::apply (
201
241
[&, call, remark, this ](auto ... recognizers) {
202
- (decltype (recognizers)::raise (call, this ->getContext (), remark) || ...);
242
+ (decltype (recognizers)::raise (theModule, call, this ->getContext (),
243
+ remark) ||
244
+ ...);
203
245
},
204
246
StdFunctionsRecognizer ());
205
247
}
0 commit comments