Skip to content

Commit 6e8c550

Browse files
Feat add access to acc ai rewrite comprehension (#671)
* Add access to accumulator so rewrite-comprehension can use results from FindUnitTests * Change UnitTest to be in the value of the map of unitTestAndItsMethods instead of key * license
1 parent ed11a6e commit 6e8c550

File tree

2 files changed

+97
-32
lines changed

2 files changed

+97
-32
lines changed

src/main/java/org/openrewrite/java/testing/search/FindUnitTests.java

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,71 @@
2424
import org.openrewrite.java.search.IsLikelyNotTest;
2525
import org.openrewrite.java.search.IsLikelyTest;
2626
import org.openrewrite.java.tree.J;
27+
import org.openrewrite.marker.SearchResult;
2728

2829
import java.util.HashMap;
2930
import java.util.HashSet;
3031
import java.util.Map;
3132
import java.util.Set;
3233

33-
import static java.util.Collections.singletonList;
3434

3535
public class FindUnitTests extends ScanningRecipe<FindUnitTests.Accumulator> {
3636

37+
38+
private transient Accumulator acc = new Accumulator();
39+
40+
transient FindUnitTestTable unitTestTable = new FindUnitTestTable(this);
41+
42+
public FindUnitTests() {
43+
}
44+
45+
public FindUnitTests(Accumulator acc) {
46+
this.acc = acc;
47+
}
48+
3749
@Override
3850
public String getDisplayName() {
3951
return "Find unit tests";
4052
}
4153

4254
@Override
4355
public String getDescription() {
44-
return "Produces a data table showing examples of how methods declared get used in unit tests.";
56+
return "Produces a data table showing how methods are used in unit tests.";
4557
}
4658

47-
transient FindUnitTestTable unitTestTable = new FindUnitTestTable(this);
48-
4959
public static class Accumulator {
50-
Map<UnitTest, Set<J.MethodInvocation>> unitTestAndTheirMethods = new HashMap<>();
60+
private final Map<String, AccumulatorValue> unitTestsByKey = new HashMap<>();
61+
62+
public Map<String, AccumulatorValue> getUnitTestAndTheirMethods(){
63+
return this.unitTestsByKey;
64+
}
65+
66+
public void addMethodInvocation(String clazz, String testName, String testBody, J.MethodInvocation invocation) {
67+
String key = clazz + "#" + testName;
68+
AccumulatorValue value = unitTestsByKey.get(key);
69+
if (value == null) {
70+
UnitTest unitTest = new UnitTest(clazz, testName, testBody);
71+
value = new AccumulatorValue(unitTest, new HashSet<>());
72+
unitTestsByKey.put(key, value);
73+
}
74+
value.getMethodInvocations().add(invocation);
75+
}
76+
77+
public Map<String, AccumulatorValue> getUnitTestsByKey() {
78+
return unitTestsByKey;
79+
}
80+
}
81+
82+
83+
@Value
84+
public static class AccumulatorValue {
85+
UnitTest unitTest;
86+
Set<J.MethodInvocation> methodInvocations;
5187
}
5288

5389
@Override
5490
public Accumulator getInitialValue(ExecutionContext ctx) {
91+
if (acc != null) return acc;
5592
return new Accumulator();
5693
}
5794

@@ -60,26 +97,25 @@ public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
6097
JavaVisitor<ExecutionContext> scanningVisitor = new JavaVisitor<ExecutionContext>() {
6198
@Override
6299
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
63-
// get the method declaration the method invocation is in
100+
// Identify the method declaration that encloses this invocation
64101
J.MethodDeclaration methodDeclaration = getCursor().firstEnclosing(J.MethodDeclaration.class);
65102
if (methodDeclaration != null &&
66103
methodDeclaration.getLeadingAnnotations().stream()
67-
.filter(o -> o.getAnnotationType() instanceof J.Identifier)
68-
.anyMatch(o -> "Test".equals(o.getSimpleName()))) {
69-
UnitTest unitTest = new UnitTest(
70-
getCursor().firstEnclosingOrThrow(J.ClassDeclaration.class).getType().getFullyQualifiedName(),
71-
methodDeclaration.getSimpleName(),
72-
methodDeclaration.printTrimmed(getCursor()));
73-
acc.unitTestAndTheirMethods.merge(unitTest,
74-
new HashSet<>(singletonList(method)),
75-
(a, b) -> {
76-
a.addAll(b);
77-
return a;
78-
});
104+
.filter(o -> o.getAnnotationType() instanceof J.Identifier)
105+
.anyMatch(o -> "Test".equals(o.getSimpleName()))) {
106+
String clazz = getCursor().firstEnclosingOrThrow(J.ClassDeclaration.class)
107+
.getType().getFullyQualifiedName();
108+
109+
String testName = methodDeclaration.getSimpleName();
110+
111+
String testBody = methodDeclaration.printTrimmed(getCursor());
112+
113+
acc.addMethodInvocation(clazz, testName, testBody, method);
79114
}
80115
return super.visitMethodInvocation(method, ctx);
81116
}
82117
};
118+
83119
return Preconditions.check(new IsLikelyTest().getVisitor(), scanningVisitor);
84120
}
85121

@@ -88,30 +124,28 @@ public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
88124
JavaVisitor<ExecutionContext> tableRowVisitor = new JavaVisitor<ExecutionContext>() {
89125
@Override
90126
public J visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) {
91-
for (Map.Entry<UnitTest, Set<J.MethodInvocation>> entry : acc.unitTestAndTheirMethods.entrySet()) {
92-
for (J.MethodInvocation method : entry.getValue()) {
93-
if (method.getSimpleName().equals(methodDeclaration.getSimpleName())) {
127+
// Iterate over each stored AccumulatorValue
128+
for (AccumulatorValue value : acc.getUnitTestsByKey().values()) {
129+
UnitTest unitTest = value.getUnitTest();
130+
for (J.MethodInvocation invocation : value.getMethodInvocations()) {
131+
// If the invoked method name matches the current methodDeclaration's name,
132+
// we assume we've found "usage" of that method inside the test
133+
if (invocation.getSimpleName().equals(methodDeclaration.getSimpleName())) {
94134
unitTestTable.insertRow(ctx, new FindUnitTestTable.Row(
95135
methodDeclaration.getName().toString(),
96136
methodDeclaration.getSimpleName(),
97-
method.printTrimmed(getCursor()),
98-
entry.getKey().getClazz(),
99-
entry.getKey().getUnitTestName()
137+
invocation.printTrimmed(getCursor()),
138+
unitTest.getClazz(),
139+
unitTest.getUnitTestName()
100140
));
101141
}
102142
}
103143
}
144+
SearchResult.found(methodDeclaration);
104145
return super.visitMethodDeclaration(methodDeclaration, ctx);
105146
}
106147
};
148+
107149
return Preconditions.check(new IsLikelyNotTest().getVisitor(), tableRowVisitor);
108150
}
109-
110-
}
111-
112-
@Value
113-
class UnitTest {
114-
String clazz;
115-
String unitTestName;
116-
String unitTest;
117151
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
* <p>
4+
* Licensed under the Moderne Source Available License (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://docs.moderne.io/licensing/moderne-source-available-license
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.openrewrite.java.testing.search;
18+
19+
import lombok.Value;
20+
21+
@Value
22+
public class UnitTest {
23+
String clazz;
24+
String unitTestName;
25+
String unitTest;
26+
27+
@Override
28+
public String toString() {
29+
return clazz + "." + unitTestName;
30+
}
31+
}

0 commit comments

Comments
 (0)