@@ -28,13 +28,37 @@ extern DpctOption<opt, bool> AsyncHandler;
2828namespace clang {
2929namespace dpct {
3030
31+ void GraphAnalysisRule::registerMatcher (MatchFinder &MF) {
32+ auto kernelNodeTypeName = [&]() {
33+ return hasAnyName (" cudaKernelNodeParams" );
34+ };
35+ MF.addMatcher (
36+ memberExpr (
37+ hasObjectExpression (hasType (type (hasUnqualifiedDesugaredType (
38+ recordType (hasDeclaration (recordDecl (kernelNodeTypeName ()))))))))
39+ .bind (" KernelNodeType" ),
40+ this );
41+ }
42+
43+ void GraphAnalysisRule::runRule (const MatchFinder::MatchResult &Result) {
44+ if (auto ME = getNodeAsType<MemberExpr>(Result, " KernelNodeType" )) {
45+ auto BaseTy = DpctGlobalInfo::getUnqualifiedTypeName (
46+ ME->getBase ()->getType ().getDesugaredType (*Result.Context ),
47+ *Result.Context );
48+ auto MemberName = ME->getMemberNameInfo ().getAsString ();
49+ if (BaseTy == " cudaKernelNodeParams" ) {
50+ DpctGlobalInfo::setUseWrapperRegisterFnPtr ();
51+ }
52+ }
53+ }
54+
3155void GraphRule::registerMatcher (MatchFinder &MF) {
3256 auto functionName = [&]() {
33- return hasAnyName (" cudaGraphInstantiate " , " cudaGraphLaunch " ,
34- " cudaGraphExecDestroy " , " cudaGraphAddEmptyNode " ,
35- " cudaGraphAddDependencies " , " cudaGraphExecUpdate " ,
36- " cudaGraphNodeGetType" , " cudaGraphGetNodes" ,
37- " cudaGraphGetRootNodes" , " cudaGraphDestroy" );
57+ return hasAnyName (
58+ " cudaGraphInstantiate " , " cudaGraphLaunch " , " cudaGraphExecDestroy " ,
59+ " cudaGraphAddEmptyNode " , " cudaGraphAddDependencies " ,
60+ " cudaGraphExecUpdate " , " cudaGraphNodeGetType" , " cudaGraphGetNodes" ,
61+ " cudaGraphGetRootNodes" , " cudaGraphDestroy" , " cudaGraphAddKernelNode " );
3862 };
3963 MF.addMatcher (
4064 callExpr (callee (functionDecl (functionName ()))).bind (" FunctionCall" ),
@@ -55,29 +79,67 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
5579 *Result.Context );
5680 auto MemberName = ME->getMemberNameInfo ().getAsString ();
5781 if (BaseTy == " cudaKernelNodeParams" ) {
58- std::cout <<" NODE PARAMS FOUND\n " ;
59- DpctGlobalInfo::setCVersionCUDALaunchUsed ();
6082 auto FieldName = KernelNodeParamNames[MemberName];
6183 if (FieldName.empty ()) {
6284 report (ME->getBeginLoc (), Diagnostics::API_NOT_MIGRATED, false ,
6385 DpctGlobalInfo::getOriginalTypeName (ME->getBase ()->getType ()) +
6486 " ::" + ME->getMemberDecl ()->getName ().str ());
6587 return ;
66-
6788 }
68- // if(FieldName == "func"){
69- // Check for the binary operator and fetch the RHS
70- // Strip the explicit typecast if it exists
71- // Check for VarDecl on the StrippedRHS
72- // If not a VarDecl, then insert user warning
73- // Check for VarDecl Type to be a FunctionDecl
74- // If FunctionDecl, then
75- // VarDecl, get var name, Get kernel_node_params variable name
76- // Create the expression, hardcoded strting
77- // Create new replace object and emplace transformation (nodeParams.set_func((void*)dpct::wrapper_register(&incrementKernel_wrapper).get());)
78- // If VarDecl and not a FunctionDecl and if type of VarDecl is function pointer
79- // Create a hardcoded string (nodeParams.set_func(a.get()));
80- // }
89+ if (FieldName == " func" ) {
90+ if (auto BO = dyn_cast<BinaryOperator>(
91+ getParentAsAssignedBO (ME, *Result.Context ))) {
92+ auto *LHS = BO->getLHS ()->IgnoreCasts ();
93+ if (auto *ME = dyn_cast<MemberExpr>(LHS)) {
94+ std::cout << " Member Expr\n " ;
95+ // Get the base expression of the MemberExpr
96+ auto *Base = ME->getBase ()->IgnoreImpCasts ();
97+
98+ // Check if the base is a DeclRefExpr
99+ if (auto *DRE = dyn_cast<DeclRefExpr>(Base)) {
100+ std::cout << " DeclRef Expr\n " ;
101+ // Get the variable declaration
102+ if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
103+ std::cout << " Base VarDecl Expr\n " ;
104+ // Get the variable name
105+ std::string varName = VD->getNameAsString ();
106+
107+ // Get the RHS of the assignment
108+ clang::Expr *RHS = BO->getRHS ()->IgnoreCasts ();
109+
110+ // Check if RHS is a DeclRefExpr referring to a function
111+ if (auto *RHS_DRE = dyn_cast<DeclRefExpr>(RHS)) {
112+ std::cout << " RHS DRE Expr\n " ;
113+ if (auto *FD = dyn_cast<FunctionDecl>(RHS_DRE->getDecl ())) {
114+ std::cout << " RHS FunctionDecl Expr\n " ;
115+ // Get the function name
116+ std::string funcName = FD->getNameAsString ();
117+ std::string wrapperName = funcName + " _wrapper" ;
118+
119+ // Construct the replacement expression
120+ std::string ReplacementExpr =
121+ varName + " .set_func((void*) dpct::wrapper_register(&" +
122+ wrapperName + " ).get());" ;
123+ std::cout << " Replacement String: " << ReplacementExpr
124+ << " \n " ;
125+ std::string rp = " (void*) dpct::wrapper_register(&" +
126+ wrapperName + " ).get()" ;
127+ StringRef ReplacedArg = rp;
128+ emplaceTransformation (ReplaceMemberAssignAsSetMethod (
129+ BO, ME, FieldName, ReplacedArg));
130+ // Replace the original assignment with the new expression
131+ // emplaceTransformation(
132+ // new ReplaceToken(ME->getBeginLoc(), ME->getEndLoc(),
133+ // std ::move(ReplacementExpr)));
134+ return ;
135+ }
136+ }
137+ }
138+ }
139+ }
140+ }
141+ }
142+ std::cout << " Coming here\n " ;
81143 if (auto BO = getParentAsAssignedBO (ME, *Result.Context )) {
82144 StringRef ReplacedArg = " " ;
83145 emplaceTransformation (
@@ -106,8 +168,8 @@ const Expr *GraphRule::getParentAsAssignedBO(const Expr *E,
106168 return nullptr ;
107169}
108170
109- // Return the binary operator if E is the lhs of an assign expression, otherwise
110- // nullptr.
171+ // Return the binary operator if E is the lhs of an assign expression,
172+ // otherwise nullptr.
111173const Expr *GraphRule::getAssignedBO (const Expr *E, ASTContext &Context) {
112174 if (dyn_cast<MemberExpr>(E)) {
113175 // Continue finding parents when E is MemberExpr.
0 commit comments