Skip to content

Commit 2c654e8

Browse files
authored
Joda-Time to Java time: Add support for Method Return Statement Migration (#626)
* Joda-Time to Java time: Add support for Method Return Expression Migration * Add few tests * remove foo test * formatting
1 parent ccb2010 commit 2c654e8

File tree

5 files changed

+467
-59
lines changed

5 files changed

+467
-59
lines changed

src/main/java/org/openrewrite/java/migrate/joda/JodaTimeRecipe.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public JodaTimeVisitor getVisitor(Accumulator acc) {
5454
@Getter
5555
public static class Accumulator {
5656
private final Set<NamedVariable> unsafeVars = new HashSet<>();
57+
private final Map<JavaType.Method, Boolean> safeMethodMap = new HashMap<>();
5758
private final VarTable varTable = new VarTable();
5859
}
5960

src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java

Lines changed: 139 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import lombok.Getter;
2020
import lombok.NonNull;
2121
import lombok.RequiredArgsConstructor;
22+
import lombok.Value;
2223
import org.jspecify.annotations.Nullable;
2324
import org.openrewrite.Cursor;
2425
import org.openrewrite.ExecutionContext;
@@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor {
4344

4445
private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
4546
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();
47+
private final Map<JavaType.Method, Set<NamedVariable>> methodReferencedVars = new HashMap<>();
48+
private final Map<JavaType.Method, Set<UnresolvedVar>> methodUnresolvedReferencedVars = new HashMap<>();
4649

4750
public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
4851
super(new LinkedList<>());
@@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
5760
dfs(var, allReachable);
5861
}
5962
acc.getUnsafeVars().addAll(allReachable);
63+
64+
Set<JavaType.Method> unsafeMethods = new HashSet<>();
65+
acc.getSafeMethodMap().forEach((method, isSafe) -> {
66+
if (!isSafe) {
67+
unsafeMethods.add(method);
68+
return;
69+
}
70+
Set<NamedVariable> intersection = new HashSet<>(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
71+
intersection.retainAll(acc.getUnsafeVars());
72+
if (!intersection.isEmpty()) {
73+
unsafeMethods.add(method);
74+
}
75+
});
76+
for (JavaType.Method method : unsafeMethods) {
77+
acc.getSafeMethodMap().put(method, false);
78+
acc.getUnsafeVars().addAll(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
79+
}
6080
return cu;
6181
}
6282

6383
@Override
64-
public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) {
84+
public J visitVariable(NamedVariable variable, ExecutionContext ctx) {
6585
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
66-
return (NamedVariable) super.visitVariable(variable, ctx);
86+
return super.visitVariable(variable, ctx);
6787
}
6888
// TODO: handle class variables
6989
if (isClassVar(variable)) {
@@ -96,35 +116,112 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
96116
}
97117

98118
@Override
99-
public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
119+
public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
100120
Expression var = assignment.getVariable();
101121
// not joda expr or not local variable
102122
if (!isJodaExpr(var) || !(var instanceof J.Identifier)) {
103-
return assignment;
123+
return super.visitAssignment(assignment, ctx);
104124
}
105125
J.Identifier ident = (J.Identifier) var;
106126
Optional<NamedVariable> mayBeVar = findVarInScope(ident.getSimpleName());
107127
if (!mayBeVar.isPresent()) {
108-
return assignment;
128+
return super.visitAssignment(assignment, ctx);
109129
}
110130
NamedVariable variable = mayBeVar.get();
111131
Cursor varScope = findScope(variable);
112132
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
113133
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow());
114134
processMarkersOnExpression(sinks, variable);
115-
return assignment;
135+
return super.visitAssignment(assignment, ctx);
116136
}
117137

118138
@Override
119-
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
139+
public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
120140
acc.getVarTable().addVars(method);
121141
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
122142
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
123143
if (var != null) { // var can only be null if method is not correctly type attributed
124144
acc.getUnsafeVars().add(var);
125145
}
126146
});
127-
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
147+
Set<UnresolvedVar> unresolvedVars = methodUnresolvedReferencedVars.remove(method.getMethodType());
148+
if (unresolvedVars != null) {
149+
unresolvedVars.forEach(var -> {
150+
NamedVariable namedVar = acc.getVarTable().getVarByName(var.getDeclaringType(), var.getVarName());
151+
if (namedVar != null) {
152+
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(namedVar);
153+
}
154+
});
155+
}
156+
return super.visitMethodDeclaration(method, ctx);
157+
}
158+
159+
@Override
160+
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
161+
if (!isJodaExpr(method) || method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN)) {
162+
return super.visitMethodInvocation(method, ctx);
163+
}
164+
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
165+
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
166+
.visit(boundary.getValue(), ctx, boundary.getParentTreeCursor());
167+
168+
boolean isSafe = j != boundary.getValue();
169+
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
170+
J parent = boundary.getParentTreeCursor().getValue();
171+
if (parent instanceof NamedVariable) {
172+
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
173+
.add((NamedVariable) parent);
174+
}
175+
if (parent instanceof J.Assignment) {
176+
J.Assignment assignment = (J.Assignment) parent;
177+
if (assignment.getVariable() instanceof J.Identifier) {
178+
J.Identifier ident = (J.Identifier) assignment.getVariable();
179+
findVarInScope(ident.getSimpleName())
180+
.map(var -> methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var));
181+
}
182+
}
183+
if (parent instanceof MethodCall) {
184+
MethodCall parentMethod = (MethodCall) parent;
185+
int argPos = parentMethod.getArguments().indexOf(boundary.getValue());
186+
if (argPos == -1) {
187+
return method;
188+
}
189+
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
190+
NamedVariable var = acc.getVarTable().getVarByName(parentMethod.getMethodType(), paramName);
191+
if (var != null) {
192+
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var);
193+
} else {
194+
methodUnresolvedReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
195+
.add(new UnresolvedVar(parentMethod.getMethodType(), paramName));
196+
}
197+
}
198+
return method;
199+
}
200+
201+
@Override
202+
public J.Return visitReturn(J.Return _return, ExecutionContext ctx) {
203+
if (_return.getExpression() == null) {
204+
return _return;
205+
}
206+
Expression expr = _return.getExpression();
207+
if (!isJodaExpr(expr)) {
208+
return _return;
209+
}
210+
J methodOrLambda = getCursor().dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda).getValue();
211+
if (methodOrLambda instanceof J.Lambda) {
212+
return _return;
213+
}
214+
J.MethodDeclaration method = (J.MethodDeclaration) methodOrLambda;
215+
Expression updatedExpr = (Expression) new JodaTimeVisitor(acc, true, scopes)
216+
.visit(expr, ctx, getCursor().getParentTreeCursor());
217+
boolean isSafe = !isJodaExpr(updatedExpr);
218+
219+
addReferencedVars(expr, method.getMethodType());
220+
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
221+
if (!isSafe) {
222+
acc.getUnsafeVars().addAll(methodReferencedVars.get(method.getMethodType()));
223+
}
224+
return _return;
128225
}
129226

130227
private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
@@ -146,7 +243,23 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
146243
}
147244
}
148245

149-
private boolean isJodaExpr(Expression expression) {
246+
/**
247+
* Traverses the cursor to find the first non-Joda expression in the path.
248+
* If no non-Joda expression is found, it returns the cursor pointing
249+
* to the last Joda expression whose parent is not an Expression.
250+
*/
251+
private static Cursor findBoundaryCursorForJodaExpr(Cursor cursor) {
252+
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
253+
Cursor parent = cursor.getParentTreeCursor();
254+
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
255+
return cursor;
256+
}
257+
cursor = parent;
258+
}
259+
return cursor;
260+
}
261+
262+
private static boolean isJodaExpr(Expression expression) {
150263
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
151264
}
152265

@@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set<NamedVariable> visited) {
172285
}
173286
}
174287

288+
private void addReferencedVars(Expression expr, JavaType.Method method) {
289+
Set<@Nullable NamedVariable> referencedVars = new HashSet<>();
290+
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
291+
referencedVars.remove(null);
292+
methodReferencedVars.computeIfAbsent(method, k -> new HashSet<>()).addAll(referencedVars);
293+
}
294+
175295
@RequiredArgsConstructor
176296
private class AddSafeCheckMarker extends JavaIsoVisitor<ExecutionContext> {
177297

@@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
205325
return mayBeMarker.get();
206326
}
207327

208-
Cursor boundary = findBoundaryCursorForJodaExpr();
328+
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
209329
boolean isSafe = true;
210-
// TODO: handle return statement
211330
if (boundary.getParentTreeCursor().getValue() instanceof J.Return) {
212-
isSafe = false;
331+
// TODO: handle return statement in lambda
332+
isSafe = boundary.dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda)
333+
.getValue() instanceof J.MethodDeclaration;
213334
}
214335
Expression boundaryExpr = boundary.getValue();
215336
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
@@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
223344
return new SafeCheckMarker(UUID.randomUUID(), isSafe, referencedVars);
224345
}
225346

226-
/**
227-
* Traverses the cursor to find the first non-Joda expression in the path.
228-
* If no non-Joda expression is found, it returns the cursor pointing
229-
* to the last Joda expression whose parent is not an Expression.
230-
*/
231-
private Cursor findBoundaryCursorForJodaExpr() {
232-
Cursor cursor = getCursor();
233-
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
234-
Cursor parent = cursor.getParentTreeCursor();
235-
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
236-
return cursor;
237-
}
238-
cursor = parent;
239-
}
240-
return cursor;
241-
}
242-
243347
private Optional<Cursor> findArgumentExprCursor() {
244348
Cursor cursor = getCursor();
245349
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
@@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy
283387
return super.visitExpression(expression, hasJodaType);
284388
}
285389
}
390+
391+
@Value
392+
private static class UnresolvedVar {
393+
JavaType declaringType;
394+
String varName;
395+
}
286396
}

src/main/java/org/openrewrite/java/migrate/joda/JodaTimeVisitor.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
8484
return super.visitCompilationUnit(cu, ctx);
8585
}
8686

87+
@Override
88+
public @NonNull J visitMethodDeclaration(@NonNull J.MethodDeclaration method, @NonNull ExecutionContext ctx) {
89+
J.MethodDeclaration m = (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
90+
if (m.getReturnTypeExpression() == null || !m.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
91+
return m;
92+
}
93+
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(m.getMethodType(), false)) {
94+
return m;
95+
}
96+
97+
JavaType.Class returnType = TimeClassMap.getJavaTimeType(((JavaType.Class) m.getType()).getFullyQualifiedName());
98+
J.Identifier returnExpr = TypeTree.build(returnType.getClassName()).withType(returnType).withPrefix(Space.format(" "));
99+
return m.withReturnTypeExpression(returnExpr)
100+
.withMethodType(m.getMethodType().withReturnType(returnType));
101+
}
102+
87103
@Override
88104
public @NonNull J visitVariableDeclarations(@NonNull J.VariableDeclarations multiVariable, @NonNull ExecutionContext ctx) {
89105
if (multiVariable.getTypeExpression() == null || !multiVariable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
@@ -147,6 +163,13 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
147163
@Override
148164
public @NonNull J visitMethodInvocation(@NonNull J.MethodInvocation method, @NonNull ExecutionContext ctx) {
149165
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);
166+
167+
// internal method with Joda class as return type
168+
if (!method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN) &&
169+
method.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
170+
return migrateNonJodaMethod(method, m);
171+
}
172+
150173
if (hasJodaType(m.getArguments()) || isJodaVarRef(m.getSelect())) {
151174
return method;
152175
}
@@ -179,7 +202,9 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
179202

180203
JavaType.FullyQualified jodaType = ((JavaType.Class) ident.getType());
181204
JavaType.FullyQualified fqType = TimeClassMap.getJavaTimeType(jodaType.getFullyQualifiedName());
182-
205+
if (fqType == null) {
206+
return ident;
207+
}
183208
return ident.withType(fqType)
184209
.withFieldType(ident.getFieldType().withType(fqType));
185210
}
@@ -218,6 +243,19 @@ private J migrateMethodCall(MethodCall original, MethodCall updated) {
218243
return original;
219244
}
220245

246+
private J.MethodInvocation migrateNonJodaMethod(J.MethodInvocation original, J.MethodInvocation updated) {
247+
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(updated.getMethodType(), false)) {
248+
return original;
249+
}
250+
JavaType.Class returnType = (JavaType.Class) updated.getMethodType().getReturnType();
251+
JavaType.Class updatedReturnType = TimeClassMap.getJavaTimeType(returnType.getFullyQualifiedName());
252+
if (updatedReturnType == null) {
253+
return original; // unhandled case
254+
}
255+
return updated.withMethodType(updated.getMethodType().withReturnType(updatedReturnType))
256+
.withName(updated.getName().withType(updatedReturnType));
257+
}
258+
221259
private boolean hasJodaType(List<Expression> exprs) {
222260
for (Expression expr : exprs) {
223261
JavaType exprType = expr.getType();

0 commit comments

Comments
 (0)