Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ public boolean isSource(@NonNull DataFlowNode srcNode) {
if (value instanceof J.VariableDeclarations.NamedVariable) {
return isJodaType(((J.VariableDeclarations.NamedVariable) value).getType());
}

if (value instanceof J.VariableDeclarations) {
if (srcNode.getCursor().getParentTreeCursor().getParentTreeCursor().getValue() instanceof J.MethodDeclaration) {
return isJodaType(((J.VariableDeclarations) value).getType());
}
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
*/
package org.openrewrite.java.migrate.joda;

import lombok.Getter;
import org.openrewrite.ExecutionContext;
import org.openrewrite.ScanningRecipe;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.VariableDeclarations.NamedVariable;
import org.openrewrite.java.tree.JavaType;

import java.util.HashSet;
import java.util.Set;
import java.util.*;

public class JodaTimeRecipe extends ScanningRecipe<Set<NamedVariable>> {
public class JodaTimeRecipe extends ScanningRecipe<JodaTimeRecipe.Accumulator> {
@Override
public String getDisplayName() {
return "Migrate Joda Time to Java Time";
return "Migrate joda time to java time";
}

@Override
Expand All @@ -34,17 +36,46 @@ public String getDescription() {
}

@Override
public Set<NamedVariable> getInitialValue(ExecutionContext ctx) {
return new HashSet<>();
public Accumulator getInitialValue(ExecutionContext ctx) {
return new Accumulator();
}

@Override
public JodaTimeScanner getScanner(Set<NamedVariable> acc) {
public JodaTimeScanner getScanner(Accumulator acc) {
return new JodaTimeScanner(acc);
}

@Override
public JodaTimeVisitor getVisitor(Set<NamedVariable> acc) {
public JodaTimeVisitor getVisitor(Accumulator acc) {
return new JodaTimeVisitor(acc);
}

@Getter
public static class Accumulator {
private final Set<NamedVariable> unsafeVars = new HashSet<>();
private final VarTable varTable = new VarTable();
}

static class VarTable {
private final Map<JavaType, List<NamedVariable>> vars = new HashMap<>();

public void addVars(J.MethodDeclaration methodDeclaration) {
JavaType type = methodDeclaration.getMethodType();

methodDeclaration.getParameters().forEach(p -> {
if (!(p instanceof J.VariableDeclarations) ) {
return;
}
J.VariableDeclarations.NamedVariable namedVariable = ((J.VariableDeclarations) p).getVariables().get(0);
vars.computeIfAbsent(type, k -> new ArrayList<>()).add(namedVariable);
});
}

public NamedVariable getVarByName(JavaType declaringType, String varName) {
return vars.getOrDefault(declaringType, Collections.emptyList()).stream()
.filter(v -> v.getSimpleName().equals(varName))
.findFirst() // there should be only one variable with the same name
.orElse(null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.VariableDeclarations.NamedVariable;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;

import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -37,27 +38,28 @@
public class JodaTimeScanner extends ScopeAwareVisitor {

@Getter
private final Set<NamedVariable> unsafeVars;
private final JodaTimeRecipe.Accumulator acc;

private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();

public JodaTimeScanner(Set<NamedVariable> unsafeVars, LinkedList<VariablesInScope> scopes) {
public JodaTimeScanner(JodaTimeRecipe.Accumulator acc, LinkedList<VariablesInScope> scopes) {
super(scopes);
this.unsafeVars = unsafeVars;
this.acc = acc;
}

public JodaTimeScanner(Set<NamedVariable> unsafeVars) {
this(unsafeVars, new LinkedList<>());
public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
this(acc, new LinkedList<>());
}

@Override
public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
super.visitCompilationUnit(cu, ctx);
Set<NamedVariable> allReachable = new HashSet<>();
for (NamedVariable var : unsafeVars) {
for (NamedVariable var : acc.getUnsafeVars()) {
dfs(var, allReachable);
}
unsafeVars.addAll(allReachable);
acc.getUnsafeVars().addAll(allReachable);
return cu;
}

Expand All @@ -67,20 +69,31 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
return variable;
}
// TODO: handle class variables && method parameters
if (!isLocalVar(variable)) {
unsafeVars.add(variable);
if (isClassVar(variable)) {
acc.getUnsafeVars().add(variable);
return variable;
}
variable = (NamedVariable) super.visitVariable(variable, ctx);

if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN) || variable.getInitializer() == null) {
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return variable;
}
List<Expression> sinks = findSinks(variable.getInitializer());
boolean isMethodParam = getCursor().getParentTreeCursor() // VariableDeclaration
.getParentTreeCursor() // MethodDeclaration
.getValue() instanceof J.MethodDeclaration;
Cursor cursor = null;
if (isMethodParam) {
cursor = getCursor();
} else if (variable.getInitializer() != null) {
cursor = new Cursor(getCursor(), variable.getInitializer());
}
if (cursor == null) {
return variable;
}
List<Expression> sinks = findSinks(cursor);

Cursor currentScope = getCurrentScope();
J.Block block = currentScope.getValue();
new AddSafeCheckMarker(sinks).visit(block, ctx, currentScope.getParent());
new AddSafeCheckMarker(sinks).visit(currentScope.getValue(), ctx, currentScope.getParent());
processMarkersOnExpression(sinks, variable);
return variable;
}
Expand All @@ -99,12 +112,24 @@ public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ct
}
NamedVariable variable = mayBeVar.get();
Cursor varScope = findScope(variable);
List<Expression> sinks = findSinks(assignment.getAssignment());
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParent());
processMarkersOnExpression(sinks, variable);
return assignment;
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
acc.getVarTable().addVars(method);
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
if (var != null) { // var can only be null if method is not correctly type attributed
acc.getUnsafeVars().add(var);
}
});
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
}

private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
for (Expression expr : expressions) {
Optional<SafeCheckMarker> mayBeMarker = expr.getMarkers().findFirst(SafeCheckMarker.class);
Expand All @@ -113,7 +138,7 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
}
SafeCheckMarker marker = mayBeMarker.get();
if (!marker.isSafe()) {
unsafeVars.add(var);
acc.getUnsafeVars().add(var);
}
if (!marker.getReferences().isEmpty()) {
varDependencies.compute(var, (k, v) -> v == null ? new HashSet<>() : v).addAll(marker.getReferences());
Expand All @@ -128,21 +153,16 @@ private boolean isJodaExpr(Expression expression) {
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
}

private List<Expression> findSinks(Expression expr) {
Cursor cursor = new Cursor(getCursor(), expr);
private List<Expression> findSinks(Cursor cursor) {
Option<SinkFlowSummary> mayBeSinks = Dataflow.startingAt(cursor).findSinks(new JodaTimeFlowSpec());
if (mayBeSinks.isNone()) {
return Collections.emptyList();
}
return mayBeSinks.some().getExpressionSinks();
}

private boolean isLocalVar(NamedVariable variable) {
if (!(variable.getVariableType().getOwner() instanceof JavaType.Method)) {
return false;
}
J j = getCursor().dropParentUntil(t -> t instanceof J.Block || t instanceof J.MethodDeclaration).getValue();
return j instanceof J.Block;
private boolean isClassVar(NamedVariable variable) {
return variable.getVariableType().getOwner() instanceof JavaType.Class;
}

private void dfs(NamedVariable root, Set<NamedVariable> visited) {
Expand All @@ -167,7 +187,17 @@ public Expression visitExpression(Expression expression, ExecutionContext ctx) {
if (index == -1) {
return super.visitExpression(expression, ctx);
}
Expression withMarker = expression.withMarkers(expression.getMarkers().addIfAbsent(getMarker(expression, ctx)));
SafeCheckMarker marker = getMarker(expression, ctx);
if (!marker.isSafe()) {
Optional<Cursor> mayBeArgCursor = findArgumentExprCursor();
if (mayBeArgCursor.isPresent()) {
MethodCall parentMethod = mayBeArgCursor.get().getParentTreeCursor().getValue();
int argPos = parentMethod.getArguments().indexOf(mayBeArgCursor.get().getValue());
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
unsafeVarsByType.computeIfAbsent(parentMethod.getMethodType(), k -> new HashSet<>()).add(paramName);
}
}
Expression withMarker = expression.withMarkers(expression.getMarkers().addIfAbsent(marker));
expressions.set(index, withMarker);
return withMarker;
}
Expand All @@ -185,7 +215,8 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
isSafe = false;
}
Expression boundaryExpr = boundary.getValue();
J j = new JodaTimeVisitor(new HashSet<>(), scopes).visit(boundaryExpr, ctx, boundary.getParentTreeCursor());
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
.visit(boundaryExpr, ctx, boundary.getParentTreeCursor());
Set<NamedVariable> referencedVars = new HashSet<>();
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
AtomicBoolean hasJodaType = new AtomicBoolean();
Expand All @@ -211,6 +242,19 @@ private Cursor findBoundaryCursorForJodaExpr() {
}
return cursor;
}

private Optional<Cursor> findArgumentExprCursor() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parentCursor = cursor.getParentTreeCursor();
if (parentCursor.getValue() instanceof MethodCall &&
((MethodCall) parentCursor.getValue()).getArguments().contains(cursor.getValue())) {
return Optional.of(cursor);
}
cursor = parentCursor;
}
return Optional.empty();
}
}

private class FindVarReferences extends JavaIsoVisitor<Set<NamedVariable>> {
Expand Down
Loading