Skip to content

Commit 7dba5ca

Browse files
committed
WIP provenance works
1 parent 8d3128c commit 7dba5ca

File tree

1 file changed

+96
-25
lines changed

1 file changed

+96
-25
lines changed

bosk-junit/src/main/java/works/bosk/junit/ParameterInjectionContextProvider.java

Lines changed: 96 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
import java.util.ArrayList;
77
import java.util.Arrays;
88
import java.util.Collection;
9+
import java.util.HashSet;
910
import java.util.LinkedHashMap;
11+
import java.util.LinkedHashSet;
1012
import java.util.List;
1113
import java.util.Map;
1214
import java.util.Objects;
15+
import java.util.Set;
1316
import java.util.stream.Stream;
1417
import org.junit.jupiter.api.extension.Extension;
1518
import org.junit.jupiter.api.extension.ExtensionContext;
@@ -35,31 +38,53 @@ public boolean supportsTestTemplate(ExtensionContext context) {
3538

3639
@Override
3740
public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContexts(ExtensionContext context) {
38-
// TODO: This loop is not quite right.
39-
// We should never call the same test method twice with the same parameter values,
40-
// but if we expand branches for all injectors, we'll end up building up a cartesian product
41-
// of parameter values for injectors that aren't even used by the test method.
42-
// We should start with the test method parameters as the "root set",
43-
// find their injectors, then find the parameters of those injectors,
44-
// and so on, rather than just blindly expanding for all injectors.
45-
//
46-
// For the time being, we sometimes end up calling the same test method
47-
// multiple times with the same parameter values.
48-
List<Branch> branches = List.of(Branch.empty());
49-
for (var injectorClass : getAllInjectorClasses(context)) {
50-
branches = expandedBranches(branches, injectorClass);
41+
List<Parameter> requiredParameters = new ArrayList<>();
42+
requiredParameters.addAll(asList(context.getRequiredTestMethod().getParameters()));
43+
requiredParameters.addAll(asList(context.getRequiredTestClass().getDeclaredConstructors()[0].getParameters()));
44+
45+
List<Branch> allPossibleBranches = List.of(Branch.empty());
46+
var allInjectorClasses = getAllInjectorClasses(context);
47+
for (var injectorClass : allInjectorClasses) {
48+
allPossibleBranches = expandedBranches(allPossibleBranches, injectorClass);
49+
}
50+
51+
if (allPossibleBranches.isEmpty()) {
52+
return Stream.empty();
5153
}
5254

53-
List<Parameter> allParameters = new ArrayList<>();
54-
allParameters.addAll(asList(context.getRequiredTestMethod().getParameters()));
55-
allParameters.addAll(asList(context.getRequiredTestClass().getDeclaredConstructors()[0].getParameters()));
55+
// At this stage, we have a list of branches that have instantiated
56+
// all the injectors we could possibly have needed for any parameter,
57+
// but some injectors might be for parameters that aren't used
58+
// by the test method or the class constructor.
59+
//
60+
// If we don't prune out the unneeded injectors,
61+
// we will end up calling the test method with the same parameters
62+
// multiple times, varying only the values of parameters
63+
// that aren't even used.
64+
//
65+
// Let's determine which injector classes we actually needed.
66+
// We can do this by picking any Branch (they all have the same types of injectors)
67+
// and seeing which injectors are needed
68+
// to provide values for the requiredParameters.
69+
70+
var neededInjectorClasses = new HashSet<Class<? extends ParameterInjector>>();
71+
Branch someBranch = allPossibleBranches.getFirst();
72+
requiredParameters.forEach(p -> getNeededInjectorClasses(p, someBranch, neededInjectorClasses));
5673

57-
return branches.stream().flatMap(branch -> {
74+
// And now we can recalculate the branch list, expanding only the required injector classes
75+
List<Branch> neededBranches = List.of(Branch.empty());
76+
for (var injectorClass : allInjectorClasses) { // The order matters here
77+
if (neededInjectorClasses.contains(injectorClass)) {
78+
neededBranches = expandedBranches(neededBranches, injectorClass);
79+
}
80+
}
81+
82+
return neededBranches.stream().flatMap(branch -> {
5883
var valuesByInjector = new LinkedHashMap<ParameterInjector, List<Object>>();
59-
allParameters.forEach(p -> {
84+
requiredParameters.forEach(p -> {
6085
ParameterInjector injector = branch.injectorFor(p);
6186
if (injector != null) {
62-
valuesByInjector.computeIfAbsent(injector, branch.toInject::get);
87+
valuesByInjector.computeIfAbsent(injector, key -> branch.toInject.get(key).values());
6388
}
6489
});
6590

@@ -69,7 +94,7 @@ public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContex
6994
return combinations.stream().map(combo -> {
7095
// Swizzle the combo into a useful map from parameter to value
7196
var paramValueMap = new LinkedHashMap<Parameter, Object>();
72-
allParameters.forEach(parameter -> {
97+
requiredParameters.forEach(parameter -> {
7398
// TODO: There's essentially a copy of this in Branch.withInjectors
7499
ParameterInjector pi = branch.injectorFor(parameter);
75100
if (pi != null) {
@@ -105,6 +130,9 @@ public Object resolveParameter(ParameterContext pc, ExtensionContext ec) throws
105130
});
106131
}
107132

133+
/**
134+
* @return the injector classes in the order they should be instantiated
135+
*/
108136
private static List<Class<? extends ParameterInjector>> getAllInjectorClasses(ExtensionContext context) {
109137
List<Class<? >> bottomUp = new ArrayList<>();
110138
for (var c = context.getRequiredTestClass(); c != Object.class; c = c.getSuperclass()) {
@@ -118,6 +146,22 @@ private static List<Class<? extends ParameterInjector>> getAllInjectorClasses(Ex
118146
}
119147
return allInjectors;
120148
}
149+
150+
private void getNeededInjectorClasses(
151+
Parameter param,
152+
Branch branch,
153+
Set<Class<? extends ParameterInjector>> needed
154+
) {
155+
ParameterInjector pi = branch.injectorFor(param);
156+
if (pi != null && needed.add(pi.getClass())) {
157+
var provenance = branch.toInject.get(pi).provenance();
158+
var list = provenance.stream()
159+
.map(ParameterInjector::getClass)
160+
.toList();
161+
needed.addAll(list);
162+
}
163+
}
164+
121165
/**
122166
* Expand each branch by instantiating all injectors of the given type.
123167
*/
@@ -147,6 +191,19 @@ private static List<List<Object>> cartesianProduct(Collection<List<Object>> inpu
147191
return result;
148192
}
149193

194+
/**
195+
* @param values the subset of {@link ParameterInjector#values()} to be injected in this scenario
196+
* @param provenance the set of injectors required, directly or indirectly, to produce these values, with no guarantees on the order
197+
*/
198+
record Superposition(
199+
List<Object> values,
200+
Set<ParameterInjector> provenance
201+
){
202+
Superposition collapsed(Object singleValue) {
203+
return new Superposition(List.of(singleValue), provenance);
204+
}
205+
}
206+
150207
/**
151208
* A list of injectors that have been instantiated so far.
152209
* Represents one "scenario" for cartesian product expansion of parameter values.
@@ -161,7 +218,7 @@ private static List<List<Object>> cartesianProduct(Collection<List<Object>> inpu
161218
* this map will contain just the one value used to construct that injector on this branch.
162219
*/
163220
record Branch(
164-
Map<ParameterInjector, List<Object>> toInject
221+
Map<ParameterInjector, Superposition> toInject
165222
) {
166223
static Branch empty() {
167224
return new Branch(Map.of());
@@ -186,31 +243,45 @@ List<Branch> withInjectors(Class<? extends ParameterInjector> injectorType) {
186243
// and will not further expand the cartesian product.)
187244
List<List<Object>> valueLists = injectorsToUse.stream()
188245
.map(toInject::get)
246+
.map(Superposition::values)
189247
.toList();
190248

191249
List<Branch> result = new ArrayList<>();
192250
for (List<Object> combos : cartesianProduct(valueLists)) {
193251
try {
194252
List<Object> args = new ArrayList<>();
253+
var injectorsUsed = new LinkedHashSet<ParameterInjector>();
195254
for (var p: ctor.getParameters()) {
196255
var pi = injectorFor(p);
256+
injectorsUsed.add(pi);
197257
var index = injectorsToUse.indexOf(pi);
258+
assert index >= 0: "Internal error: injector not found for parameter " + p + " of constructor " + ctor;
198259
args.add(combos.get(index));
199260
}
200261

201-
// Instantiate the injector and add it to the map
262+
// Instantiate the injector
202263
var injector = (ParameterInjector) ctor.newInstance(args.toArray());
264+
265+
// Compute the provenance
266+
var provenance = new HashSet<ParameterInjector>();
267+
injectorsUsed.forEach(pi -> {
268+
provenance.add(pi);
269+
provenance.addAll(toInject.get(pi).provenance());
270+
});
271+
provenance.add(injector);
272+
273+
// Compute the new toInject map
203274
var map = new LinkedHashMap<>(this.toInject);
204-
map.put(injector, injector.values());
275+
map.put(injector, new Superposition(injector.values(), provenance));
205276

206277
// Collapse the wavefunction for injectors that provided a value
207278
int i = 0;
208279
for (var pi: injectorsToUse) {
209-
var possibleValues = toInject.get(pi);
280+
var possibleValues = toInject.get(pi).values();
210281
if (possibleValues.size() >= 2) {
211282
// pi has provided a value on this branch.
212283
// Record that fact for future uses of the same injector.
213-
map.put(pi, List.of(args.get(i)));
284+
map.put(pi, map.get(pi).collapsed(args.get(i)));
214285
}
215286
++i;
216287
}

0 commit comments

Comments
 (0)