Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public JodaTimeVisitor getVisitor(Accumulator acc) {
@Getter
public static class Accumulator {
private final Set<NamedVariable> unsafeVars = new HashSet<>();
private final Map<JavaType.Method, Boolean> safeMethodMap = new HashMap<>();
private final VarTable varTable = new VarTable();
}

Expand Down
168 changes: 139 additions & 29 deletions src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
Expand All @@ -43,6 +44,8 @@ class JodaTimeScanner extends ScopeAwareVisitor {

private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();
private final Map<JavaType.Method, Set<NamedVariable>> methodReferencedVars = new HashMap<>();
private final Map<JavaType.Method, Set<UnresolvedVar>> methodUnresolvedReferencedVars = new HashMap<>();

public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
super(new LinkedList<>());
Expand All @@ -57,13 +60,30 @@ public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
dfs(var, allReachable);
}
acc.getUnsafeVars().addAll(allReachable);

Set<JavaType.Method> unsafeMethods = new HashSet<>();
acc.getSafeMethodMap().forEach((method, isSafe) -> {
if (!isSafe) {
unsafeMethods.add(method);
return;
}
Set<NamedVariable> intersection = new HashSet<>(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
intersection.retainAll(acc.getUnsafeVars());
if (!intersection.isEmpty()) {
unsafeMethods.add(method);
}
});
for (JavaType.Method method : unsafeMethods) {
acc.getSafeMethodMap().put(method, false);
acc.getUnsafeVars().addAll(methodReferencedVars.getOrDefault(method, Collections.emptySet()));
}
return cu;
}

@Override
public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx) {
public J visitVariable(NamedVariable variable, ExecutionContext ctx) {
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return (NamedVariable) super.visitVariable(variable, ctx);
return super.visitVariable(variable, ctx);
}
// TODO: handle class variables
if (isClassVar(variable)) {
Expand Down Expand Up @@ -96,35 +116,112 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
}

@Override
public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
public J visitAssignment(J.Assignment assignment, ExecutionContext ctx) {
Expression var = assignment.getVariable();
// not joda expr or not local variable
if (!isJodaExpr(var) || !(var instanceof J.Identifier)) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
J.Identifier ident = (J.Identifier) var;
Optional<NamedVariable> mayBeVar = findVarInScope(ident.getSimpleName());
if (!mayBeVar.isPresent()) {
return assignment;
return super.visitAssignment(assignment, ctx);
}
NamedVariable variable = mayBeVar.get();
Cursor varScope = findScope(variable);
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow());
processMarkersOnExpression(sinks, variable);
return assignment;
return super.visitAssignment(assignment, ctx);
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
public J visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
acc.getVarTable().addVars(method);
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
if (var != null) { // var can only be null if method is not correctly type attributed
acc.getUnsafeVars().add(var);
}
});
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
Set<UnresolvedVar> unresolvedVars = methodUnresolvedReferencedVars.remove(method.getMethodType());
if (unresolvedVars != null) {
unresolvedVars.forEach(var -> {
NamedVariable namedVar = acc.getVarTable().getVarByName(var.getDeclaringType(), var.getVarName());
if (namedVar != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(namedVar);
}
});
}
return super.visitMethodDeclaration(method, ctx);
}

@Override
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
if (!isJodaExpr(method) || method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return super.visitMethodInvocation(method, ctx);
}
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
.visit(boundary.getValue(), ctx, boundary.getParentTreeCursor());

boolean isSafe = j != boundary.getValue();
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
J parent = boundary.getParentTreeCursor().getValue();
if (parent instanceof NamedVariable) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add((NamedVariable) parent);
}
if (parent instanceof J.Assignment) {
J.Assignment assignment = (J.Assignment) parent;
if (assignment.getVariable() instanceof J.Identifier) {
J.Identifier ident = (J.Identifier) assignment.getVariable();
findVarInScope(ident.getSimpleName())
.map(var -> methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var));
}
}
if (parent instanceof MethodCall) {
MethodCall parentMethod = (MethodCall) parent;
int argPos = parentMethod.getArguments().indexOf(boundary.getValue());
if (argPos == -1) {
return method;
}
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
NamedVariable var = acc.getVarTable().getVarByName(parentMethod.getMethodType(), paramName);
if (var != null) {
methodReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>()).add(var);
} else {
methodUnresolvedReferencedVars.computeIfAbsent(method.getMethodType(), k -> new HashSet<>())
.add(new UnresolvedVar(parentMethod.getMethodType(), paramName));
}
}
return method;
}

@Override
public J.Return visitReturn(J.Return _return, ExecutionContext ctx) {
if (_return.getExpression() == null) {
return _return;
}
Expression expr = _return.getExpression();
if (!isJodaExpr(expr)) {
return _return;
}
J methodOrLambda = getCursor().dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda).getValue();
if (methodOrLambda instanceof J.Lambda) {
return _return;
}
J.MethodDeclaration method = (J.MethodDeclaration) methodOrLambda;
Expression updatedExpr = (Expression) new JodaTimeVisitor(acc, true, scopes)
.visit(expr, ctx, getCursor().getParentTreeCursor());
boolean isSafe = !isJodaExpr(updatedExpr);

addReferencedVars(expr, method.getMethodType());
acc.getSafeMethodMap().compute(method.getMethodType(), (k, v) -> v == null ? isSafe : v && isSafe);
if (!isSafe) {
acc.getUnsafeVars().addAll(methodReferencedVars.get(method.getMethodType()));
}
return _return;
}

private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
Expand All @@ -146,7 +243,23 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
}
}

private boolean isJodaExpr(Expression expression) {
/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private static Cursor findBoundaryCursorForJodaExpr(Cursor cursor) {
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private static boolean isJodaExpr(Expression expression) {
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
}

Expand All @@ -172,6 +285,13 @@ private void dfs(NamedVariable root, Set<NamedVariable> visited) {
}
}

private void addReferencedVars(Expression expr, JavaType.Method method) {
Set<@Nullable NamedVariable> referencedVars = new HashSet<>();
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
referencedVars.remove(null);
methodReferencedVars.computeIfAbsent(method, k -> new HashSet<>()).addAll(referencedVars);
}

@RequiredArgsConstructor
private class AddSafeCheckMarker extends JavaIsoVisitor<ExecutionContext> {

Expand Down Expand Up @@ -205,11 +325,12 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return mayBeMarker.get();
}

Cursor boundary = findBoundaryCursorForJodaExpr();
Cursor boundary = findBoundaryCursorForJodaExpr(getCursor());
boolean isSafe = true;
// TODO: handle return statement
if (boundary.getParentTreeCursor().getValue() instanceof J.Return) {
isSafe = false;
// TODO: handle return statement in lambda
isSafe = boundary.dropParentUntil(j -> j instanceof J.MethodDeclaration || j instanceof J.Lambda)
.getValue() instanceof J.MethodDeclaration;
}
Expression boundaryExpr = boundary.getValue();
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
Expand All @@ -223,23 +344,6 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
return new SafeCheckMarker(UUID.randomUUID(), isSafe, referencedVars);
}

/**
* Traverses the cursor to find the first non-Joda expression in the path.
* If no non-Joda expression is found, it returns the cursor pointing
* to the last Joda expression whose parent is not an Expression.
*/
private Cursor findBoundaryCursorForJodaExpr() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parent = cursor.getParentTreeCursor();
if (parent.getValue() instanceof J && !(parent.getValue() instanceof Expression)) {
return cursor;
}
cursor = parent;
}
return cursor;
}

private Optional<Cursor> findArgumentExprCursor() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Expand Down Expand Up @@ -283,4 +387,10 @@ public Expression visitExpression(Expression expression, AtomicBoolean hasJodaTy
return super.visitExpression(expression, hasJodaType);
}
}

@Value
private static class UnresolvedVar {
JavaType declaringType;
String varName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
return super.visitCompilationUnit(cu, ctx);
}

@Override
public @NonNull J visitMethodDeclaration(@NonNull J.MethodDeclaration method, @NonNull ExecutionContext ctx) {
J.MethodDeclaration m = (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
if (m.getReturnTypeExpression() == null || !m.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return m;
}
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(m.getMethodType(), false)) {
return m;
}

JavaType.Class returnType = TimeClassMap.getJavaTimeType(((JavaType.Class) m.getType()).getFullyQualifiedName());
J.Identifier returnExpr = TypeTree.build(returnType.getClassName()).withType(returnType).withPrefix(Space.format(" "));
return m.withReturnTypeExpression(returnExpr)
.withMethodType(m.getMethodType().withReturnType(returnType));
}

@Override
public @NonNull J visitVariableDeclarations(@NonNull J.VariableDeclarations multiVariable, @NonNull ExecutionContext ctx) {
if (multiVariable.getTypeExpression() == null || !multiVariable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
Expand Down Expand Up @@ -147,6 +163,13 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)
@Override
public @NonNull J visitMethodInvocation(@NonNull J.MethodInvocation method, @NonNull ExecutionContext ctx) {
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);

// internal method with Joda class as return type
if (!method.getMethodType().getDeclaringType().isAssignableFrom(JODA_CLASS_PATTERN) &&
method.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return migrateNonJodaMethod(method, m);
}

if (hasJodaType(m.getArguments()) || isJodaVarRef(m.getSelect())) {
return method;
}
Expand Down Expand Up @@ -179,7 +202,9 @@ public Javadoc visitReference(Javadoc.Reference reference, ExecutionContext ctx)

JavaType.FullyQualified jodaType = ((JavaType.Class) ident.getType());
JavaType.FullyQualified fqType = TimeClassMap.getJavaTimeType(jodaType.getFullyQualifiedName());

if (fqType == null) {
return ident;
}
return ident.withType(fqType)
.withFieldType(ident.getFieldType().withType(fqType));
}
Expand Down Expand Up @@ -218,6 +243,19 @@ private J migrateMethodCall(MethodCall original, MethodCall updated) {
return original;
}

private J.MethodInvocation migrateNonJodaMethod(J.MethodInvocation original, J.MethodInvocation updated) {
if (safeMigration && !acc.getSafeMethodMap().getOrDefault(updated.getMethodType(), false)) {
return original;
}
JavaType.Class returnType = (JavaType.Class) updated.getMethodType().getReturnType();
JavaType.Class updatedReturnType = TimeClassMap.getJavaTimeType(returnType.getFullyQualifiedName());
if (updatedReturnType == null) {
return original; // unhandled case
}
return updated.withMethodType(updated.getMethodType().withReturnType(updatedReturnType))
.withName(updated.getName().withType(updatedReturnType));
}

private boolean hasJodaType(List<Expression> exprs) {
for (Expression expr : exprs) {
JavaType exprType = expr.getType();
Expand Down
Loading
Loading