66import java .util .ArrayList ;
77import java .util .Arrays ;
88import java .util .Collection ;
9+ import java .util .HashSet ;
910import java .util .LinkedHashMap ;
11+ import java .util .LinkedHashSet ;
1012import java .util .List ;
1113import java .util .Map ;
1214import java .util .Objects ;
15+ import java .util .Set ;
1316import java .util .stream .Stream ;
1417import org .junit .jupiter .api .extension .Extension ;
1518import org .junit .jupiter .api .extension .ExtensionContext ;
@@ -35,31 +38,53 @@ public boolean supportsTestTemplate(ExtensionContext context) {
3538
3639 @ Override
3740 public Stream <TestTemplateInvocationContext > provideTestTemplateInvocationContexts (ExtensionContext context ) {
38- // TODO: This loop is not quite right.
39- // We should never call the same test method twice with the same parameter values,
40- // but if we expand branches for all injectors, we'll end up building up a cartesian product
41- // of parameter values for injectors that aren't even used by the test method.
42- // We should start with the test method parameters as the "root set",
43- // find their injectors, then find the parameters of those injectors,
44- // and so on, rather than just blindly expanding for all injectors.
45- //
46- // For the time being, we sometimes end up calling the same test method
47- // multiple times with the same parameter values.
48- List <Branch > branches = List .of (Branch .empty ());
49- for (var injectorClass : getAllInjectorClasses (context )) {
50- branches = expandedBranches (branches , injectorClass );
41+ List <Parameter > requiredParameters = new ArrayList <>();
42+ requiredParameters .addAll (asList (context .getRequiredTestMethod ().getParameters ()));
43+ requiredParameters .addAll (asList (context .getRequiredTestClass ().getDeclaredConstructors ()[0 ].getParameters ()));
44+
45+ List <Branch > allPossibleBranches = List .of (Branch .empty ());
46+ var allInjectorClasses = getAllInjectorClasses (context );
47+ for (var injectorClass : allInjectorClasses ) {
48+ allPossibleBranches = expandedBranches (allPossibleBranches , injectorClass );
49+ }
50+
51+ if (allPossibleBranches .isEmpty ()) {
52+ return Stream .empty ();
5153 }
5254
53- List <Parameter > allParameters = new ArrayList <>();
54- allParameters .addAll (asList (context .getRequiredTestMethod ().getParameters ()));
55- allParameters .addAll (asList (context .getRequiredTestClass ().getDeclaredConstructors ()[0 ].getParameters ()));
55+ // At this stage, we have a list of branches that have instantiated
56+ // all the injectors we could possibly have needed for any parameter,
57+ // but some injectors might be for parameters that aren't used
58+ // by the test method or the class constructor.
59+ //
60+ // If we don't prune out the unneeded injectors,
61+ // we will end up calling the test method with the same parameters
62+ // multiple times, varying only the values of parameters
63+ // that aren't even used.
64+ //
65+ // Let's determine which injector classes we actually needed.
66+ // We can do this by picking any Branch (they all have the same types of injectors)
67+ // and seeing which injectors are needed
68+ // to provide values for the requiredParameters.
69+
70+ var neededInjectorClasses = new HashSet <Class <? extends ParameterInjector >>();
71+ Branch someBranch = allPossibleBranches .getFirst ();
72+ requiredParameters .forEach (p -> getNeededInjectorClasses (p , someBranch , neededInjectorClasses ));
5673
57- return branches .stream ().flatMap (branch -> {
74+ // And now we can recalculate the branch list, expanding only the required injector classes
75+ List <Branch > neededBranches = List .of (Branch .empty ());
76+ for (var injectorClass : allInjectorClasses ) { // The order matters here
77+ if (neededInjectorClasses .contains (injectorClass )) {
78+ neededBranches = expandedBranches (neededBranches , injectorClass );
79+ }
80+ }
81+
82+ return neededBranches .stream ().flatMap (branch -> {
5883 var valuesByInjector = new LinkedHashMap <ParameterInjector , List <Object >>();
59- allParameters .forEach (p -> {
84+ requiredParameters .forEach (p -> {
6085 ParameterInjector injector = branch .injectorFor (p );
6186 if (injector != null ) {
62- valuesByInjector .computeIfAbsent (injector , branch .toInject :: get );
87+ valuesByInjector .computeIfAbsent (injector , key -> branch .toInject . get ( key ). values () );
6388 }
6489 });
6590
@@ -69,7 +94,7 @@ public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContex
6994 return combinations .stream ().map (combo -> {
7095 // Swizzle the combo into a useful map from parameter to value
7196 var paramValueMap = new LinkedHashMap <Parameter , Object >();
72- allParameters .forEach (parameter -> {
97+ requiredParameters .forEach (parameter -> {
7398 // TODO: There's essentially a copy of this in Branch.withInjectors
7499 ParameterInjector pi = branch .injectorFor (parameter );
75100 if (pi != null ) {
@@ -105,6 +130,9 @@ public Object resolveParameter(ParameterContext pc, ExtensionContext ec) throws
105130 });
106131 }
107132
133+ /**
134+ * @return the injector classes in the order they should be instantiated
135+ */
108136 private static List <Class <? extends ParameterInjector >> getAllInjectorClasses (ExtensionContext context ) {
109137 List <Class <? >> bottomUp = new ArrayList <>();
110138 for (var c = context .getRequiredTestClass (); c != Object .class ; c = c .getSuperclass ()) {
@@ -118,6 +146,22 @@ private static List<Class<? extends ParameterInjector>> getAllInjectorClasses(Ex
118146 }
119147 return allInjectors ;
120148 }
149+
150+ private void getNeededInjectorClasses (
151+ Parameter param ,
152+ Branch branch ,
153+ Set <Class <? extends ParameterInjector >> needed
154+ ) {
155+ ParameterInjector pi = branch .injectorFor (param );
156+ if (pi != null && needed .add (pi .getClass ())) {
157+ var provenance = branch .toInject .get (pi ).provenance ();
158+ var list = provenance .stream ()
159+ .map (ParameterInjector ::getClass )
160+ .toList ();
161+ needed .addAll (list );
162+ }
163+ }
164+
121165 /**
122166 * Expand each branch by instantiating all injectors of the given type.
123167 */
@@ -147,6 +191,19 @@ private static List<List<Object>> cartesianProduct(Collection<List<Object>> inpu
147191 return result ;
148192 }
149193
194+ /**
195+ * @param values the subset of {@link ParameterInjector#values()} to be injected in this scenario
196+ * @param provenance the set of injectors required, directly or indirectly, to produce these values, with no guarantees on the order
197+ */
198+ record Superposition (
199+ List <Object > values ,
200+ Set <ParameterInjector > provenance
201+ ){
202+ Superposition collapsed (Object singleValue ) {
203+ return new Superposition (List .of (singleValue ), provenance );
204+ }
205+ }
206+
150207 /**
151208 * A list of injectors that have been instantiated so far.
152209 * Represents one "scenario" for cartesian product expansion of parameter values.
@@ -161,7 +218,7 @@ private static List<List<Object>> cartesianProduct(Collection<List<Object>> inpu
161218 * this map will contain just the one value used to construct that injector on this branch.
162219 */
163220 record Branch (
164- Map <ParameterInjector , List < Object > > toInject
221+ Map <ParameterInjector , Superposition > toInject
165222 ) {
166223 static Branch empty () {
167224 return new Branch (Map .of ());
@@ -186,31 +243,45 @@ List<Branch> withInjectors(Class<? extends ParameterInjector> injectorType) {
186243 // and will not further expand the cartesian product.)
187244 List <List <Object >> valueLists = injectorsToUse .stream ()
188245 .map (toInject ::get )
246+ .map (Superposition ::values )
189247 .toList ();
190248
191249 List <Branch > result = new ArrayList <>();
192250 for (List <Object > combos : cartesianProduct (valueLists )) {
193251 try {
194252 List <Object > args = new ArrayList <>();
253+ var injectorsUsed = new LinkedHashSet <ParameterInjector >();
195254 for (var p : ctor .getParameters ()) {
196255 var pi = injectorFor (p );
256+ injectorsUsed .add (pi );
197257 var index = injectorsToUse .indexOf (pi );
258+ assert index >= 0 : "Internal error: injector not found for parameter " + p + " of constructor " + ctor ;
198259 args .add (combos .get (index ));
199260 }
200261
201- // Instantiate the injector and add it to the map
262+ // Instantiate the injector
202263 var injector = (ParameterInjector ) ctor .newInstance (args .toArray ());
264+
265+ // Compute the provenance
266+ var provenance = new HashSet <ParameterInjector >();
267+ injectorsUsed .forEach (pi -> {
268+ provenance .add (pi );
269+ provenance .addAll (toInject .get (pi ).provenance ());
270+ });
271+ provenance .add (injector );
272+
273+ // Compute the new toInject map
203274 var map = new LinkedHashMap <>(this .toInject );
204- map .put (injector , injector .values ());
275+ map .put (injector , new Superposition ( injector .values (), provenance ));
205276
206277 // Collapse the wavefunction for injectors that provided a value
207278 int i = 0 ;
208279 for (var pi : injectorsToUse ) {
209- var possibleValues = toInject .get (pi );
280+ var possibleValues = toInject .get (pi ). values () ;
210281 if (possibleValues .size () >= 2 ) {
211282 // pi has provided a value on this branch.
212283 // Record that fact for future uses of the same injector.
213- map .put (pi , List . of (args .get (i )));
284+ map .put (pi , map . get ( pi ). collapsed (args .get (i )));
214285 }
215286 ++i ;
216287 }
0 commit comments