|
5 | 5 | import static edu.cuny.hunter.streamrefactoring.core.analysis.Util.getLineNumberFromAST;
|
6 | 6 | import static edu.cuny.hunter.streamrefactoring.core.analysis.Util.getLineNumberFromIR;
|
7 | 7 | import static edu.cuny.hunter.streamrefactoring.core.analysis.Util.getPossibleTypesInterprocedurally;
|
| 8 | +import static edu.cuny.hunter.streamrefactoring.core.analysis.Util.implementsBaseStream; |
8 | 9 | import static edu.cuny.hunter.streamrefactoring.core.analysis.Util.matches;
|
9 | 10 | import static edu.cuny.hunter.streamrefactoring.core.safe.Util.instanceKeyCorrespondsWithInstantiationInstruction;
|
10 | 11 |
|
|
27 | 28 | import org.eclipse.jdt.core.IMethod;
|
28 | 29 | import org.eclipse.jdt.core.IType;
|
29 | 30 | import org.eclipse.jdt.core.JavaModelException;
|
| 31 | +import org.eclipse.jdt.core.dom.AST; |
30 | 32 | import org.eclipse.jdt.core.dom.ASTNode;
|
31 | 33 | import org.eclipse.jdt.core.dom.CompilationUnit;
|
| 34 | +import org.eclipse.jdt.core.dom.Expression; |
32 | 35 | import org.eclipse.jdt.core.dom.IMethodBinding;
|
33 | 36 | import org.eclipse.jdt.core.dom.ITypeBinding;
|
34 | 37 | import org.eclipse.jdt.core.dom.MethodDeclaration;
|
@@ -315,8 +318,64 @@ protected void convertToParallel(CompilationUnitRewrite rewrite) {
|
315 | 318 | LOGGER.info("Converting to parallel.");
|
316 | 319 | MethodInvocation creation = this.getCreation();
|
317 | 320 | ASTRewrite astRewrite = rewrite.getASTRewrite();
|
318 |
| - SimpleName newMethodName = creation.getAST().newSimpleName("parallelStream"); |
319 |
| - astRewrite.replace(creation.getName(), newMethodName, null); |
| 321 | + |
| 322 | + MethodInvocation termOp = findTerminalOperation(creation); |
| 323 | + |
| 324 | + // get the terminal operation's expression. |
| 325 | + Expression expression = termOp.getExpression(); |
| 326 | + |
| 327 | + boolean done = false; |
| 328 | + |
| 329 | + while (expression != null && !done) { |
| 330 | + if (expression.getNodeType() == ASTNode.METHOD_INVOCATION) { |
| 331 | + MethodInvocation inv = (MethodInvocation) expression; |
| 332 | + AST ast = creation.getAST(); |
| 333 | + |
| 334 | + switch (inv.getName().getIdentifier()) { |
| 335 | + case "sequential": |
| 336 | + // remove it. |
| 337 | + astRewrite.replace(inv, inv.getExpression(), null); |
| 338 | + break; |
| 339 | + case "parallel": |
| 340 | + done = true; |
| 341 | + break; |
| 342 | + case "stream": { |
| 343 | + // Replace with parallelStream(). |
| 344 | + SimpleName newMethodName = ast.newSimpleName("parallelStream"); |
| 345 | + astRewrite.replace(creation.getName(), newMethodName, null); |
| 346 | + break; |
| 347 | + } |
| 348 | + case "parallelStream": |
| 349 | + done = true; |
| 350 | + break; |
| 351 | + default: { |
| 352 | + // if we're at the end. |
| 353 | + if (inv.getExpression().getNodeType() != ASTNode.METHOD_INVOCATION |
| 354 | + || inv.getExpression().getNodeType() == ASTNode.METHOD_INVOCATION |
| 355 | + && !implementsBaseStream(inv.getExpression().resolveTypeBinding())) { |
| 356 | + MethodInvocation newMethodInvocation = ast.newMethodInvocation(); |
| 357 | + newMethodInvocation.setName(ast.newSimpleName("parallel")); |
| 358 | + MethodInvocation invCopy = (MethodInvocation) ASTNode.copySubtree(ast, inv); |
| 359 | + newMethodInvocation.setExpression(invCopy); |
| 360 | + astRewrite.replace(inv, newMethodInvocation, null); |
| 361 | + } |
| 362 | + } |
| 363 | + } |
| 364 | + expression = inv.getExpression(); |
| 365 | + } else |
| 366 | + done = true; |
| 367 | + } |
| 368 | + } |
| 369 | + |
| 370 | + private static MethodInvocation findTerminalOperation(ASTNode astNode) { |
| 371 | + if (astNode == null) |
| 372 | + return null; |
| 373 | + else if (astNode.getNodeType() != ASTNode.METHOD_INVOCATION) |
| 374 | + throw new IllegalArgumentException(astNode + " must be a method invocation."); |
| 375 | + if (astNode.getParent().getNodeType() != ASTNode.METHOD_INVOCATION) |
| 376 | + return (MethodInvocation) astNode; |
| 377 | + else |
| 378 | + return findTerminalOperation(astNode.getParent()); |
320 | 379 | }
|
321 | 380 |
|
322 | 381 | protected void convertToSequential(CompilationUnitRewrite rewrite) {
|
|
0 commit comments