2828import org .openrewrite .java .tree .J ;
2929import org .openrewrite .java .tree .J .VariableDeclarations .NamedVariable ;
3030import org .openrewrite .java .tree .JavaType ;
31+ import org .openrewrite .java .tree .MethodCall ;
3132
3233import java .util .*;
3334import java .util .concurrent .atomic .AtomicBoolean ;
3738public class JodaTimeScanner extends ScopeAwareVisitor {
3839
3940 @ Getter
40- private final Set < NamedVariable > unsafeVars ;
41+ private final JodaTimeRecipe . Accumulator acc ;
4142
4243 private final Map <NamedVariable , Set <NamedVariable >> varDependencies = new HashMap <>();
44+ private final Map <JavaType , Set <String >> unsafeVarsByType = new HashMap <>();
4345
44- public JodaTimeScanner (Set < NamedVariable > unsafeVars , LinkedList <VariablesInScope > scopes ) {
46+ public JodaTimeScanner (JodaTimeRecipe . Accumulator acc , LinkedList <VariablesInScope > scopes ) {
4547 super (scopes );
46- this .unsafeVars = unsafeVars ;
48+ this .acc = acc ;
4749 }
4850
49- public JodaTimeScanner (Set < NamedVariable > unsafeVars ) {
50- this (unsafeVars , new LinkedList <>());
51+ public JodaTimeScanner (JodaTimeRecipe . Accumulator acc ) {
52+ this (acc , new LinkedList <>());
5153 }
5254
5355 @ Override
5456 public J visitCompilationUnit (J .CompilationUnit cu , ExecutionContext ctx ) {
5557 super .visitCompilationUnit (cu , ctx );
5658 Set <NamedVariable > allReachable = new HashSet <>();
57- for (NamedVariable var : unsafeVars ) {
59+ for (NamedVariable var : acc . getUnsafeVars () ) {
5860 dfs (var , allReachable );
5961 }
60- unsafeVars .addAll (allReachable );
62+ acc . getUnsafeVars () .addAll (allReachable );
6163 return cu ;
6264 }
6365
@@ -67,20 +69,31 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
6769 return variable ;
6870 }
6971 // TODO: handle class variables && method parameters
70- if (! isLocalVar (variable )) {
71- unsafeVars .add (variable );
72+ if (isClassVar (variable )) {
73+ acc . getUnsafeVars () .add (variable );
7274 return variable ;
7375 }
7476 variable = (NamedVariable ) super .visitVariable (variable , ctx );
7577
76- if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN ) || variable . getInitializer () == null ) {
78+ if (!variable .getType ().isAssignableFrom (JODA_CLASS_PATTERN )) {
7779 return variable ;
7880 }
79- List <Expression > sinks = findSinks (variable .getInitializer ());
81+ boolean isMethodParam = getCursor ().getParentTreeCursor () // VariableDeclaration
82+ .getParentTreeCursor () // MethodDeclaration
83+ .getValue () instanceof J .MethodDeclaration ;
84+ Cursor cursor = null ;
85+ if (isMethodParam ) {
86+ cursor = getCursor ();
87+ } else if (variable .getInitializer () != null ) {
88+ cursor = new Cursor (getCursor (), variable .getInitializer ());
89+ }
90+ if (cursor == null ) {
91+ return variable ;
92+ }
93+ List <Expression > sinks = findSinks (cursor );
8094
8195 Cursor currentScope = getCurrentScope ();
82- J .Block block = currentScope .getValue ();
83- new AddSafeCheckMarker (sinks ).visit (block , ctx , currentScope .getParent ());
96+ new AddSafeCheckMarker (sinks ).visit (currentScope .getValue (), ctx , currentScope .getParent ());
8497 processMarkersOnExpression (sinks , variable );
8598 return variable ;
8699 }
@@ -99,12 +112,24 @@ public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ct
99112 }
100113 NamedVariable variable = mayBeVar .get ();
101114 Cursor varScope = findScope (variable );
102- List <Expression > sinks = findSinks (assignment .getAssignment ());
115+ List <Expression > sinks = findSinks (new Cursor ( getCursor (), assignment .getAssignment () ));
103116 new AddSafeCheckMarker (sinks ).visit (varScope .getValue (), ctx , varScope .getParent ());
104117 processMarkersOnExpression (sinks , variable );
105118 return assignment ;
106119 }
107120
121+ @ Override
122+ public J .MethodDeclaration visitMethodDeclaration (J .MethodDeclaration method , ExecutionContext ctx ) {
123+ acc .getVarTable ().addVars (method );
124+ unsafeVarsByType .getOrDefault (method .getMethodType (), Collections .emptySet ()).forEach (varName -> {
125+ NamedVariable var = acc .getVarTable ().getVarByName (method .getMethodType (), varName );
126+ if (var != null ) { // var can only be null if method is not correctly type attributed
127+ acc .getUnsafeVars ().add (var );
128+ }
129+ });
130+ return (J .MethodDeclaration ) super .visitMethodDeclaration (method , ctx );
131+ }
132+
108133 private void processMarkersOnExpression (List <Expression > expressions , NamedVariable var ) {
109134 for (Expression expr : expressions ) {
110135 Optional <SafeCheckMarker > mayBeMarker = expr .getMarkers ().findFirst (SafeCheckMarker .class );
@@ -113,7 +138,7 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
113138 }
114139 SafeCheckMarker marker = mayBeMarker .get ();
115140 if (!marker .isSafe ()) {
116- unsafeVars .add (var );
141+ acc . getUnsafeVars () .add (var );
117142 }
118143 if (!marker .getReferences ().isEmpty ()) {
119144 varDependencies .compute (var , (k , v ) -> v == null ? new HashSet <>() : v ).addAll (marker .getReferences ());
@@ -128,21 +153,16 @@ private boolean isJodaExpr(Expression expression) {
128153 return expression .getType () != null && expression .getType ().isAssignableFrom (JODA_CLASS_PATTERN );
129154 }
130155
131- private List <Expression > findSinks (Expression expr ) {
132- Cursor cursor = new Cursor (getCursor (), expr );
156+ private List <Expression > findSinks (Cursor cursor ) {
133157 Option <SinkFlowSummary > mayBeSinks = Dataflow .startingAt (cursor ).findSinks (new JodaTimeFlowSpec ());
134158 if (mayBeSinks .isNone ()) {
135159 return Collections .emptyList ();
136160 }
137161 return mayBeSinks .some ().getExpressionSinks ();
138162 }
139163
140- private boolean isLocalVar (NamedVariable variable ) {
141- if (!(variable .getVariableType ().getOwner () instanceof JavaType .Method )) {
142- return false ;
143- }
144- J j = getCursor ().dropParentUntil (t -> t instanceof J .Block || t instanceof J .MethodDeclaration ).getValue ();
145- return j instanceof J .Block ;
164+ private boolean isClassVar (NamedVariable variable ) {
165+ return variable .getVariableType ().getOwner () instanceof JavaType .Class ;
146166 }
147167
148168 private void dfs (NamedVariable root , Set <NamedVariable > visited ) {
@@ -167,7 +187,17 @@ public Expression visitExpression(Expression expression, ExecutionContext ctx) {
167187 if (index == -1 ) {
168188 return super .visitExpression (expression , ctx );
169189 }
170- Expression withMarker = expression .withMarkers (expression .getMarkers ().addIfAbsent (getMarker (expression , ctx )));
190+ SafeCheckMarker marker = getMarker (expression , ctx );
191+ if (!marker .isSafe ()) {
192+ Optional <Cursor > mayBeArgCursor = findArgumentExprCursor ();
193+ if (mayBeArgCursor .isPresent ()) {
194+ MethodCall parentMethod = mayBeArgCursor .get ().getParentTreeCursor ().getValue ();
195+ int argPos = parentMethod .getArguments ().indexOf (mayBeArgCursor .get ().getValue ());
196+ String paramName = parentMethod .getMethodType ().getParameterNames ().get (argPos );
197+ unsafeVarsByType .computeIfAbsent (parentMethod .getMethodType (), k -> new HashSet <>()).add (paramName );
198+ }
199+ }
200+ Expression withMarker = expression .withMarkers (expression .getMarkers ().addIfAbsent (marker ));
171201 expressions .set (index , withMarker );
172202 return withMarker ;
173203 }
@@ -185,7 +215,8 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
185215 isSafe = false ;
186216 }
187217 Expression boundaryExpr = boundary .getValue ();
188- J j = new JodaTimeVisitor (new HashSet <>(), scopes ).visit (boundaryExpr , ctx , boundary .getParentTreeCursor ());
218+ J j = new JodaTimeVisitor (new JodaTimeRecipe .Accumulator (), false , scopes )
219+ .visit (boundaryExpr , ctx , boundary .getParentTreeCursor ());
189220 Set <NamedVariable > referencedVars = new HashSet <>();
190221 new FindVarReferences ().visit (expr , referencedVars , getCursor ().getParentTreeCursor ());
191222 AtomicBoolean hasJodaType = new AtomicBoolean ();
@@ -211,6 +242,19 @@ private Cursor findBoundaryCursorForJodaExpr() {
211242 }
212243 return cursor ;
213244 }
245+
246+ private Optional <Cursor > findArgumentExprCursor () {
247+ Cursor cursor = getCursor ();
248+ while (cursor .getValue () instanceof Expression && isJodaExpr (cursor .getValue ())) {
249+ Cursor parentCursor = cursor .getParentTreeCursor ();
250+ if (parentCursor .getValue () instanceof MethodCall
251+ && ((MethodCall ) parentCursor .getValue ()).getArguments ().contains (cursor .getValue ())) {
252+ return Optional .of (cursor );
253+ }
254+ cursor = parentCursor ;
255+ }
256+ return Optional .empty ();
257+ }
214258 }
215259
216260 private class FindVarReferences extends JavaIsoVisitor <Set <NamedVariable >> {
0 commit comments