21
21
#include " llvm/ADT/StringMap.h"
22
22
#include " llvm/ADT/StringRef.h"
23
23
#include " llvm/ADT/Twine.h"
24
+ #include " llvm/ADT/TypeSwitch.h"
24
25
#include " llvm/Support/ErrorHandling.h"
25
26
#include " llvm/Support/Path.h"
26
27
@@ -56,16 +57,29 @@ template <typename TargetOp> class StdRecognizer {
56
57
return dyn_cast<FuncOp>(global);
57
58
}
58
59
59
- static bool isStdVector (const clang::CXXRecordDecl *RD) {
60
- if (!RD || !RD->getDeclContext ()->isStdNamespace ())
61
- return false ;
60
+ static std::optional<StringRef>
61
+ getRecordName (const clang::CXXRecordDecl *rd) {
62
+ if (!rd || !rd->getDeclContext ()->isStdNamespace ())
63
+ return std::nullopt ;
62
64
63
- if (RD->getDeclName ().isIdentifier ()) {
64
- StringRef Name = RD->getName ();
65
- return Name == " vector" ;
66
- }
65
+ if (rd->getDeclName ().isIdentifier ())
66
+ return rd->getName ();
67
67
68
- return false ;
68
+ return std::nullopt ;
69
+ }
70
+
71
+ static std::optional<std::string>
72
+ resolveSpecialMember (mlir::Attribute specialMember) {
73
+ return TypeSwitch<Attribute, std::optional<std::string>>(specialMember)
74
+ .Case <CXXCtorAttr, CXXDtorAttr>(
75
+ [](auto attr) -> std::optional<std::string> {
76
+ if (!attr.getRecordDecl ())
77
+ return std::nullopt ;
78
+ if (auto recordName = getRecordName (*attr.getRecordDecl ()))
79
+ return recordName->str () + " _" + attr.getMnemonic ().str ();
80
+ return std::nullopt ;
81
+ })
82
+ .Default ([](Attribute) { return std::nullopt ; });
69
83
}
70
84
71
85
static bool raise (mlir::ModuleOp theModule, CallOp call,
@@ -74,21 +88,12 @@ template <typename TargetOp> class StdRecognizer {
74
88
if (call.getNumOperands () != numArgs)
75
89
return false ;
76
90
77
- llvm::StringRef name = *call.getCallee ();
78
- auto calleeFunc = getCalleeFromSymbol (theModule, name);
79
-
80
91
llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
92
+ auto calleeFunc = getCalleeFromSymbol (theModule, *call.getCallee ());
81
93
82
94
if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr ()) {
83
- auto matches =
84
- (stdFuncName == " vector_ctor" && isa<CXXCtorAttr>(specialMember)) ||
85
- (stdFuncName == " vector_dtor" && isa<CXXDtorAttr>(specialMember));
86
- if (!matches)
87
- return false ;
88
-
89
- auto recordDeclAttr = call.getAstRecordAttr ();
90
- if (!recordDeclAttr ||
91
- !isStdVector (cast<clang::CXXRecordDecl>(recordDeclAttr.getRawDecl ())))
95
+ auto resolved = resolveSpecialMember (specialMember);
96
+ if (!resolved || *resolved != stdFuncName.str ())
92
97
return false ;
93
98
} else {
94
99
auto callExprAttr = call.getAstAttr ();
@@ -97,12 +102,12 @@ template <typename TargetOp> class StdRecognizer {
97
102
98
103
if (!checkArguments (call.getArgOperands ()))
99
104
return false ;
100
-
101
- if (remark)
102
- mlir::emitRemark (call.getLoc ())
103
- << " found call to std::" << stdFuncName << " ()" ;
104
105
}
105
106
107
+ if (remark)
108
+ mlir::emitRemark (call.getLoc ())
109
+ << " found call to std::" << stdFuncName << " ()" ;
110
+
106
111
CIRBaseBuilderTy builder (context);
107
112
builder.setInsertionPointAfter (call.getOperation ());
108
113
TargetOp op = buildCall (builder, call, std::make_index_sequence<numArgs>());
0 commit comments