1919import lombok .Getter ;
2020import lombok .NonNull ;
2121import lombok .RequiredArgsConstructor ;
22+ import lombok .Value ;
2223import org .jspecify .annotations .Nullable ;
2324import org .openrewrite .Cursor ;
2425import org .openrewrite .ExecutionContext ;
@@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor {
4344
4445 private final Map <NamedVariable , Set <NamedVariable >> varDependencies = new HashMap <>();
4546 private final Map <JavaType , Set <String >> unsafeVarsByType = new HashMap <>();
47+ private final Map <JavaType .Method , Set <NamedVariable >> methodReferencedVars = new HashMap <>();
48+ private final Map <JavaType .Method , Set <UnresolvedVar >> methodUnresolvedReferencedVars = new HashMap <>();
4649
4750 public JodaTimeScanner (JodaTimeRecipe .Accumulator acc ) {
4851 super (new LinkedList <>());
@@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
5760 dfs (var , allReachable );
5861 }
5962 acc .getUnsafeVars ().addAll (allReachable );
63+
64+ Set <JavaType .Method > unsafeMethods = new HashSet <>();
65+ acc .getSafeMethodMap ().forEach ((method , isSafe ) -> {
66+ if (!isSafe ) {
67+ unsafeMethods .add (method );
68+ return ;
69+ }
70+ Set <NamedVariable > intersection = new HashSet <>(methodReferencedVars .getOrDefault (method , Collections .emptySet ()));
71+ intersection .retainAll (acc .getUnsafeVars ());
72+ if (!intersection .isEmpty ()) {
73+ unsafeMethods .add (method );
74+ }
75+ });
76+ for (JavaType .Method method : unsafeMethods ) {
77+ acc .getSafeMethodMap ().put (method , false );
78+ acc .getUnsafeVars ().addAll (methodReferencedVars .getOrDefault (method , Collections .emptySet ()));
79+ }
6080 return cu ;
6181 }
6282
6383 @ Override
64- public NamedVariable visitVariable (NamedVariable variable , ExecutionContext ctx ) {
84+ public J visitVariable (NamedVariable variable , ExecutionContext ctx ) {
6585 if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN )) {
66- return ( NamedVariable ) super .visitVariable (variable , ctx );
86+ return super .visitVariable (variable , ctx );
6787 }
6888 // TODO: handle class variables
6989 if (isClassVar (variable )) {
@@ -96,35 +116,112 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
96116 }
97117
98118 @ Override
99- public J . Assignment visitAssignment (J .Assignment assignment , ExecutionContext ctx ) {
119+ public J visitAssignment (J .Assignment assignment , ExecutionContext ctx ) {
100120 Expression var = assignment .getVariable ();
101121 // not joda expr or not local variable
102122 if (!isJodaExpr (var ) || !(var instanceof J .Identifier )) {
103- return assignment ;
123+ return super . visitAssignment ( assignment , ctx ) ;
104124 }
105125 J .Identifier ident = (J .Identifier ) var ;
106126 Optional <NamedVariable > mayBeVar = findVarInScope (ident .getSimpleName ());
107127 if (!mayBeVar .isPresent ()) {
108- return assignment ;
128+ return super . visitAssignment ( assignment , ctx ) ;
109129 }
110130 NamedVariable variable = mayBeVar .get ();
111131 Cursor varScope = findScope (variable );
112132 List <Expression > sinks = findSinks (new Cursor (getCursor (), assignment .getAssignment ()));
113133 new AddSafeCheckMarker (sinks ).visit (varScope .getValue (), ctx , varScope .getParentOrThrow ());
114134 processMarkersOnExpression (sinks , variable );
115- return assignment ;
135+ return super . visitAssignment ( assignment , ctx ) ;
116136 }
117137
118138 @ Override
119- public J . MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
139+ public J visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
120140 acc .getVarTable ().addVars (method );
121141 unsafeVarsByType .getOrDefault (method .getMethodType (), Collections .emptySet ()).forEach (varName -> {
122142 NamedVariable var = acc .getVarTable ().getVarByName (method .getMethodType (), varName );
123143 if (var != null ) { // var can only be null if method is not correctly type attributed
124144 acc .getUnsafeVars ().add (var );
125145 }
126146 });
127- return (J .MethodDeclaration ) super .visitMethodDeclaration (method , ctx );
147+ Set <UnresolvedVar > unresolvedVars = methodUnresolvedReferencedVars .remove (method .getMethodType ());
148+ if (unresolvedVars != null ) {
149+ unresolvedVars .forEach (var -> {
150+ NamedVariable namedVar = acc .getVarTable ().getVarByName (var .getDeclaringType (), var .getVarName ());
151+ if (namedVar != null ) {
152+ methodReferencedVars .computeIfAbsent (method .getMethodType (), k -> new HashSet <>()).add (namedVar );
153+ }
154+ });
155+ }
156+ return super .visitMethodDeclaration (method , ctx );
157+ }
158+
159+ @ Override
160+ public J visitMethodInvocation (J .MethodInvocation method , ExecutionContext ctx ) {
161+ if (!isJodaExpr (method ) || method .getMethodType ().getDeclaringType ().isAssignableFrom (JODA_CLASS_PATTERN )) {
162+ return super .visitMethodInvocation (method , ctx );
163+ }
164+ Cursor boundary = findBoundaryCursorForJodaExpr (getCursor ());
165+ J j = new JodaTimeVisitor (new JodaTimeRecipe .Accumulator (), false , scopes )
166+ .visit (boundary .getValue (), ctx , boundary .getParentTreeCursor ());
167+
168+ boolean isSafe = j != boundary .getValue ();
169+ acc .getSafeMethodMap ().compute (method .getMethodType (), (k , v ) -> v == null ? isSafe : v && isSafe );
170+ J parent = boundary .getParentTreeCursor ().getValue ();
171+ if (parent instanceof NamedVariable ) {
172+ methodReferencedVars .computeIfAbsent (method .getMethodType (), k -> new HashSet <>())
173+ .add ((NamedVariable ) parent );
174+ }
175+ if (parent instanceof J .Assignment ) {
176+ J .Assignment assignment = (J .Assignment ) parent ;
177+ if (assignment .getVariable () instanceof J .Identifier ) {
178+ J .Identifier ident = (J .Identifier ) assignment .getVariable ();
179+ findVarInScope (ident .getSimpleName ())
180+ .map (var -> methodReferencedVars .computeIfAbsent (method .getMethodType (), k -> new HashSet <>()).add (var ));
181+ }
182+ }
183+ if (parent instanceof MethodCall ) {
184+ MethodCall parentMethod = (MethodCall ) parent ;
185+ int argPos = parentMethod .getArguments ().indexOf (boundary .getValue ());
186+ if (argPos == -1 ) {
187+ return method ;
188+ }
189+ String paramName = parentMethod .getMethodType ().getParameterNames ().get (argPos );
190+ NamedVariable var = acc .getVarTable ().getVarByName (parentMethod .getMethodType (), paramName );
191+ if (var != null ) {
192+ methodReferencedVars .computeIfAbsent (method .getMethodType (), k -> new HashSet <>()).add (var );
193+ } else {
194+ methodUnresolvedReferencedVars .computeIfAbsent (method .getMethodType (), k -> new HashSet <>())
195+ .add (new UnresolvedVar (parentMethod .getMethodType (), paramName ));
196+ }
197+ }
198+ return method ;
199+ }
200+
201+ @ Override
202+ public J .Return visitReturn (J .Return _return , ExecutionContext ctx ) {
203+ if (_return .getExpression () == null ) {
204+ return _return ;
205+ }
206+ Expression expr = _return .getExpression ();
207+ if (!isJodaExpr (expr )) {
208+ return _return ;
209+ }
210+ J methodOrLambda = getCursor ().dropParentUntil (j -> j instanceof J .MethodDeclaration || j instanceof J .Lambda ).getValue ();
211+ if (methodOrLambda instanceof J .Lambda ) {
212+ return _return ;
213+ }
214+ J .MethodDeclaration method = (J .MethodDeclaration ) methodOrLambda ;
215+ Expression updatedExpr = (Expression ) new JodaTimeVisitor (acc , true , scopes )
216+ .visit (expr , ctx , getCursor ().getParentTreeCursor ());
217+ boolean isSafe = !isJodaExpr (updatedExpr );
218+
219+ addReferencedVars (expr , method .getMethodType ());
220+ acc .getSafeMethodMap ().compute (method .getMethodType (), (k , v ) -> v == null ? isSafe : v && isSafe );
221+ if (!isSafe ) {
222+ acc .getUnsafeVars ().addAll (methodReferencedVars .get (method .getMethodType ()));
223+ }
224+ return _return ;
128225 }
129226
130227 private void processMarkersOnExpression (List <Expression > expressions , NamedVariable var ) {
@@ -146,7 +243,23 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
146243 }
147244 }
148245
149- private boolean isJodaExpr (Expression expression ) {
246+ /**
247+ * Traverses the cursor to find the first non-Joda expression in the path.
248+ * If no non-Joda expression is found, it returns the cursor pointing
249+ * to the last Joda expression whose parent is not an Expression.
250+ */
251+ private static Cursor findBoundaryCursorForJodaExpr (Cursor cursor ) {
252+ while (cursor .getValue () instanceof Expression && isJodaExpr (cursor .getValue ())) {
253+ Cursor parent = cursor .getParentTreeCursor ();
254+ if (parent .getValue () instanceof J && !(parent .getValue () instanceof Expression )) {
255+ return cursor ;
256+ }
257+ cursor = parent ;
258+ }
259+ return cursor ;
260+ }
261+
262+ private static boolean isJodaExpr (Expression expression ) {
150263 return expression .getType () != null && expression .getType ().isAssignableFrom (JODA_CLASS_PATTERN );
151264 }
152265
@@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set<NamedVariable> visited) {
172285 }
173286 }
174287
288+ private void addReferencedVars (Expression expr , JavaType .Method method ) {
289+ Set <@ Nullable NamedVariable > referencedVars = new HashSet <>();
290+ new FindVarReferences ().visit (expr , referencedVars , getCursor ().getParentTreeCursor ());
291+ referencedVars .remove (null );
292+ methodReferencedVars .computeIfAbsent (method , k -> new HashSet <>()).addAll (referencedVars );
293+ }
294+
175295 @ RequiredArgsConstructor
176296 private class AddSafeCheckMarker extends JavaIsoVisitor <ExecutionContext > {
177297
@@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
205325 return mayBeMarker .get ();
206326 }
207327
208- Cursor boundary = findBoundaryCursorForJodaExpr ();
328+ Cursor boundary = findBoundaryCursorForJodaExpr (getCursor () );
209329 boolean isSafe = true ;
210- // TODO: handle return statement
211330 if (boundary .getParentTreeCursor ().getValue () instanceof J .Return ) {
212- isSafe = false ;
331+ // TODO: handle return statement in lambda
332+ isSafe = boundary .dropParentUntil (j -> j instanceof J .MethodDeclaration || j instanceof J .Lambda )
333+ .getValue () instanceof J .MethodDeclaration ;
213334 }
214335 Expression boundaryExpr = boundary .getValue ();
215336 J j = new JodaTimeVisitor (new JodaTimeRecipe .Accumulator (), false , scopes )
@@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
223344 return new SafeCheckMarker (UUID .randomUUID (), isSafe , referencedVars );
224345 }
225346
226- /**
227- * Traverses the cursor to find the first non-Joda expression in the path.
228- * If no non-Joda expression is found, it returns the cursor pointing
229- * to the last Joda expression whose parent is not an Expression.
230- */
231- private Cursor findBoundaryCursorForJodaExpr () {
232- Cursor cursor = getCursor ();
233- while (cursor .getValue () instanceof Expression && isJodaExpr (cursor .getValue ())) {
234- Cursor parent = cursor .getParentTreeCursor ();
235- if (parent .getValue () instanceof J && !(parent .getValue () instanceof Expression )) {
236- return cursor ;
237- }
238- cursor = parent ;
239- }
240- return cursor ;
241- }
242-
243347 private Optional <Cursor > findArgumentExprCursor () {
244348 Cursor cursor = getCursor ();
245349 while (cursor .getValue () instanceof Expression && isJodaExpr (cursor .getValue ())) {
@@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy
283387 return super .visitExpression (expression , hasJodaType );
284388 }
285389 }
390+
391+ @ Value
392+ private static class UnresolvedVar {
393+ JavaType declaringType ;
394+ String varName ;
395+ }
286396}
0 commit comments