1919import lombok .Getter ;
2020import lombok .NonNull ;
2121import lombok .RequiredArgsConstructor ;
22+ import org .jspecify .annotations .Nullable ;
2223import org .openrewrite .Cursor ;
2324import org .openrewrite .ExecutionContext ;
2425import org .openrewrite .analysis .dataflow .Dataflow ;
2829import org .openrewrite .java .tree .J ;
2930import org .openrewrite .java .tree .J .VariableDeclarations .NamedVariable ;
3031import org .openrewrite .java .tree .JavaType ;
32+ import org .openrewrite .java .tree .MethodCall ;
3133
3234import java .util .*;
3335import java .util .concurrent .atomic .AtomicBoolean ;
3436
3537import static org .openrewrite .java .migrate .joda .templates .TimeClassNames .JODA_CLASS_PATTERN ;
3638
37- public class JodaTimeScanner extends ScopeAwareVisitor {
39+ class JodaTimeScanner extends ScopeAwareVisitor {
3840
3941 @ Getter
40- private final Set < NamedVariable > unsafeVars ;
42+ private final JodaTimeRecipe . Accumulator acc ;
4143
4244 private final Map <NamedVariable , Set <NamedVariable >> varDependencies = new HashMap <>();
45+ private final Map <JavaType , Set <String >> unsafeVarsByType = new HashMap <>();
4346
44- public JodaTimeScanner (Set <NamedVariable > unsafeVars , LinkedList <VariablesInScope > scopes ) {
45- super (scopes );
46- this .unsafeVars = unsafeVars ;
47- }
48-
49- public JodaTimeScanner (Set <NamedVariable > unsafeVars ) {
50- this (unsafeVars , new LinkedList <>());
47+ public JodaTimeScanner (JodaTimeRecipe .Accumulator acc ) {
48+ super (new LinkedList <>());
49+ this .acc = acc ;
5150 }
5251
5352 @ Override
5453 public J visitCompilationUnit (J .CompilationUnit cu , ExecutionContext ctx ) {
5554 super .visitCompilationUnit (cu , ctx );
5655 Set <NamedVariable > allReachable = new HashSet <>();
57- for (NamedVariable var : unsafeVars ) {
56+ for (NamedVariable var : acc . getUnsafeVars () ) {
5857 dfs (var , allReachable );
5958 }
60- unsafeVars .addAll (allReachable );
59+ acc . getUnsafeVars () .addAll (allReachable );
6160 return cu ;
6261 }
6362
@@ -66,21 +65,32 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
6665 if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN )) {
6766 return variable ;
6867 }
69- // TODO: handle class variables && method parameters
70- if (! isLocalVar (variable )) {
71- unsafeVars .add (variable );
68+ // TODO: handle class variables
69+ if (isClassVar (variable )) {
70+ acc . getUnsafeVars () .add (variable );
7271 return variable ;
7372 }
7473 variable = (NamedVariable ) super .visitVariable (variable , ctx );
7574
76- if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN ) || variable .getInitializer () == null ) {
75+ if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN )) {
76+ return variable ;
77+ }
78+ boolean isMethodParam = getCursor ().getParentTreeCursor () // VariableDeclaration
79+ .getParentTreeCursor () // MethodDeclaration
80+ .getValue () instanceof J .MethodDeclaration ;
81+ Cursor cursor = null ;
82+ if (isMethodParam ) {
83+ cursor = getCursor ();
84+ } else if (variable .getInitializer () != null ) {
85+ cursor = new Cursor (getCursor (), variable .getInitializer ());
86+ }
87+ if (cursor == null ) {
7788 return variable ;
7889 }
79- List <Expression > sinks = findSinks (variable . getInitializer () );
90+ List <Expression > sinks = findSinks (cursor );
8091
8192 Cursor currentScope = getCurrentScope ();
82- J .Block block = currentScope .getValue ();
83- new AddSafeCheckMarker (sinks ).visit (block , ctx , currentScope .getParent ());
93+ new AddSafeCheckMarker (sinks ).visit (currentScope .getValue (), ctx , currentScope .getParentOrThrow ());
8494 processMarkersOnExpression (sinks , variable );
8595 return variable ;
8696 }
@@ -99,12 +109,24 @@ public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ct
99109 }
100110 NamedVariable variable = mayBeVar .get ();
101111 Cursor varScope = findScope (variable );
102- List <Expression > sinks = findSinks (assignment .getAssignment ());
103- new AddSafeCheckMarker (sinks ).visit (varScope .getValue (), ctx , varScope .getParent ());
112+ List <Expression > sinks = findSinks (new Cursor ( getCursor (), assignment .getAssignment () ));
113+ new AddSafeCheckMarker (sinks ).visit (varScope .getValue (), ctx , varScope .getParentOrThrow ());
104114 processMarkersOnExpression (sinks , variable );
105115 return assignment ;
106116 }
107117
118+ @ Override
119+ public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
120+ acc .getVarTable ().addVars (method );
121+ unsafeVarsByType .getOrDefault (method .getMethodType (), Collections .emptySet ()).forEach (varName -> {
122+ NamedVariable var = acc .getVarTable ().getVarByName (method .getMethodType (), varName );
123+ if (var != null ) { // var can only be null if method is not correctly type attributed
124+ acc .getUnsafeVars ().add (var );
125+ }
126+ });
127+ return (J .MethodDeclaration ) super .visitMethodDeclaration (method , ctx );
128+ }
129+
108130 private void processMarkersOnExpression (List <Expression > expressions , NamedVariable var ) {
109131 for (Expression expr : expressions ) {
110132 Optional <SafeCheckMarker > mayBeMarker = expr .getMarkers ().findFirst (SafeCheckMarker .class );
@@ -113,7 +135,7 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
113135 }
114136 SafeCheckMarker marker = mayBeMarker .get ();
115137 if (!marker .isSafe ()) {
116- unsafeVars .add (var );
138+ acc . getUnsafeVars () .add (var );
117139 }
118140 if (!marker .getReferences ().isEmpty ()) {
119141 varDependencies .compute (var , (k , v ) -> v == null ? new HashSet <>() : v ).addAll (marker .getReferences ());
@@ -128,21 +150,16 @@ private boolean isJodaExpr(Expression expression) {
128150 return expression .getType () != null && expression .getType ().isAssignableFrom (JODA_CLASS_PATTERN );
129151 }
130152
131- private List <Expression > findSinks (Expression expr ) {
132- Cursor cursor = new Cursor (getCursor (), expr );
153+ private List <Expression > findSinks (Cursor cursor ) {
133154 Option <SinkFlowSummary > mayBeSinks = Dataflow .startingAt (cursor ).findSinks (new JodaTimeFlowSpec ());
134155 if (mayBeSinks .isNone ()) {
135156 return Collections .emptyList ();
136157 }
137158 return mayBeSinks .some ().getExpressionSinks ();
138159 }
139160
140- private boolean isLocalVar (NamedVariable variable ) {
141- if (!(variable .getVariableType ().getOwner () instanceof JavaType .Method )) {
142- return false ;
143- }
144- J j = getCursor ().dropParentUntil (t -> t instanceof J .Block || t instanceof J .MethodDeclaration ).getValue ();
145- return j instanceof J .Block ;
161+ private boolean isClassVar (NamedVariable variable ) {
162+ return variable .getVariableType ().getOwner () instanceof JavaType .Class ;
146163 }
147164
148165 private void dfs (NamedVariable root , Set <NamedVariable > visited ) {
@@ -167,7 +184,17 @@ public Expression visitExpression(Expression expression, ExecutionContext ctx) {
167184 if (index == -1 ) {
168185 return super .visitExpression (expression , ctx );
169186 }
170- Expression withMarker = expression .withMarkers (expression .getMarkers ().addIfAbsent (getMarker (expression , ctx )));
187+ SafeCheckMarker marker = getMarker (expression , ctx );
188+ if (!marker .isSafe ()) {
189+ Optional <Cursor > mayBeArgCursor = findArgumentExprCursor ();
190+ if (mayBeArgCursor .isPresent ()) {
191+ MethodCall parentMethod = mayBeArgCursor .get ().getParentTreeCursor ().getValue ();
192+ int argPos = parentMethod .getArguments ().indexOf (mayBeArgCursor .get ().getValue ());
193+ String paramName = parentMethod .getMethodType ().getParameterNames ().get (argPos );
194+ unsafeVarsByType .computeIfAbsent (parentMethod .getMethodType (), k -> new HashSet <>()).add (paramName );
195+ }
196+ }
197+ Expression withMarker = expression .withMarkers (expression .getMarkers ().addIfAbsent (marker ));
171198 expressions .set (index , withMarker );
172199 return withMarker ;
173200 }
@@ -185,8 +212,9 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
185212 isSafe = false ;
186213 }
187214 Expression boundaryExpr = boundary .getValue ();
188- J j = new JodaTimeVisitor (new HashSet <>(), scopes ).visit (boundaryExpr , ctx , boundary .getParentTreeCursor ());
189- Set <NamedVariable > referencedVars = new HashSet <>();
215+ J j = new JodaTimeVisitor (new JodaTimeRecipe .Accumulator (), false , scopes )
216+ .visit (boundaryExpr , ctx , boundary .getParentTreeCursor ());
217+ Set <@ Nullable NamedVariable > referencedVars = new HashSet <>();
190218 new FindVarReferences ().visit (expr , referencedVars , getCursor ().getParentTreeCursor ());
191219 AtomicBoolean hasJodaType = new AtomicBoolean ();
192220 new HasJodaType ().visit (j , hasJodaType );
@@ -211,12 +239,25 @@ private Cursor findBoundaryCursorForJodaExpr() {
211239 }
212240 return cursor ;
213241 }
242+
243+ private Optional <Cursor > findArgumentExprCursor () {
244+ Cursor cursor = getCursor ();
245+ while (cursor .getValue () instanceof Expression && isJodaExpr (cursor .getValue ())) {
246+ Cursor parentCursor = cursor .getParentTreeCursor ();
247+ if (parentCursor .getValue () instanceof MethodCall &&
248+ ((MethodCall ) parentCursor .getValue ()).getArguments ().contains (cursor .getValue ())) {
249+ return Optional .of (cursor );
250+ }
251+ cursor = parentCursor ;
252+ }
253+ return Optional .empty ();
254+ }
214255 }
215256
216- private class FindVarReferences extends JavaIsoVisitor <Set <NamedVariable >> {
257+ private class FindVarReferences extends JavaIsoVisitor <Set <@ Nullable NamedVariable >> {
217258
218259 @ Override
219- public J .Identifier visitIdentifier (J .Identifier ident , Set <NamedVariable > vars ) {
260+ public J .Identifier visitIdentifier (J .Identifier ident , Set <@ Nullable NamedVariable > vars ) {
220261 if (!isJodaExpr (ident ) || ident .getFieldType () == null ) {
221262 return ident ;
222263 }
0 commit comments