1717 */
1818package org .apache .beam .sdk .testing ;
1919
20- import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Preconditions .checkState ;
20+ import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Preconditions .checkNotNull ;
2121
2222import java .lang .annotation .Annotation ;
2323import java .lang .reflect .Method ;
24- import java .util .Arrays ;
24+ import java .util .Collection ;
2525import java .util .Optional ;
2626import org .apache .beam .sdk .options .ApplicationNameOptions ;
2727import org .apache .beam .sdk .options .PipelineOptions ;
28- import org .apache .beam .sdk .testing .TestPipeline .PipelineAbandonedNodeEnforcement ;
2928import org .apache .beam .sdk .testing .TestPipeline .PipelineRunEnforcement ;
30- import org .junit .experimental .categories .Category ;
29+ import org .apache .beam .vendor .grpc .v1p69p0 .com .google .common .collect .ImmutableList ;
30+ import org .checkerframework .checker .nullness .qual .Nullable ;
3131import org .junit .jupiter .api .extension .AfterEachCallback ;
3232import org .junit .jupiter .api .extension .BeforeEachCallback ;
3333import org .junit .jupiter .api .extension .ExtensionContext ;
@@ -86,16 +86,16 @@ public static TestPipelineExtension fromOptions(PipelineOptions options) {
8686 return new TestPipelineExtension (options );
8787 }
8888
89- private TestPipeline testPipeline ;
89+ private @ Nullable PipelineOptions options ;
9090
9191 /** Creates a TestPipelineExtension with default options. */
9292 public TestPipelineExtension () {
93- this .testPipeline = TestPipeline . create () ;
93+ this .options = null ;
9494 }
9595
9696 /** Creates a TestPipelineExtension with custom options. */
9797 public TestPipelineExtension (PipelineOptions options ) {
98- this .testPipeline = TestPipeline . fromOptions ( options ) ;
98+ this .options = options ;
9999 }
100100
101101 @ Override
@@ -107,43 +107,38 @@ public boolean supportsParameter(
107107 @ Override
108108 public Object resolveParameter (
109109 ParameterContext parameterContext , ExtensionContext extensionContext ) {
110- if (this .testPipeline == null ) {
111- return getOrCreateTestPipeline (extensionContext );
112- } else {
113- return this .testPipeline ;
114- }
110+ return getOrCreateTestPipeline (extensionContext );
115111 }
116112
117113 @ Override
118- public void beforeEach (ExtensionContext context ) throws Exception {
119- TestPipeline pipeline ;
120-
121- if (this .testPipeline != null ) {
122- pipeline = this .testPipeline ;
123- } else {
124- pipeline = getOrCreateTestPipeline (context );
125- }
114+ public void beforeEach (ExtensionContext context ) {
115+ TestPipeline pipeline = getOrCreateTestPipeline (context );
126116
127117 // Set application name based on test method
128118 String appName = getAppName (context );
129119 pipeline .getOptions ().as (ApplicationNameOptions .class ).setAppName (appName );
130120
131121 // Set up enforcement based on annotations
132- setDeducedEnforcementLevel (context , pipeline );
122+ pipeline . setDeducedEnforcementLevel (getAnnotations ( context ) );
133123 }
134124
135125 @ Override
136- public void afterEach (ExtensionContext context ) throws Exception {
137- Optional <PipelineRunEnforcement > enforcement = getEnforcement (context );
138- if (enforcement .isPresent ()) {
139- enforcement .get ().afterUserCodeFinished ();
140- }
126+ public void afterEach (ExtensionContext context ) {
127+ TestPipeline pipeline = getRequiredTestPipeline (context );
128+ pipeline .afterUserCodeFinished ();
141129 }
142130
143131 private TestPipeline getOrCreateTestPipeline (ExtensionContext context ) {
144132 return context
145133 .getStore (NAMESPACE )
146- .getOrComputeIfAbsent (PIPELINE_KEY , key -> TestPipeline .create (), TestPipeline .class );
134+ .getOrComputeIfAbsent (
135+ PIPELINE_KEY ,
136+ key -> options == null ? TestPipeline .create () : TestPipeline .fromOptions (options ),
137+ TestPipeline .class );
138+ }
139+
140+ private TestPipeline getRequiredTestPipeline (ExtensionContext context ) {
141+ return checkNotNull (context .getStore (NAMESPACE ).get (PIPELINE_KEY , TestPipeline .class ));
147142 }
148143
149144 private Optional <PipelineRunEnforcement > getEnforcement (ExtensionContext context ) {
@@ -161,53 +156,10 @@ private String getAppName(ExtensionContext context) {
161156 return className + "-" + methodName ;
162157 }
163158
164- private void setDeducedEnforcementLevel (ExtensionContext context , TestPipeline pipeline ) {
165- // If enforcement level has not been set, do auto-inference
166- if (!getEnforcement (context ).isPresent ()) {
167- boolean annotatedWithNeedsRunner = hasNeedsRunnerAnnotation (context );
168-
169- PipelineOptions options = pipeline .getOptions ();
170- boolean crashingRunner = CrashingRunner .class .isAssignableFrom (options .getRunner ());
171-
172- checkState (
173- !(annotatedWithNeedsRunner && crashingRunner ),
174- "The test was annotated with a [@%s] / [@%s] while the runner "
175- + "was set to [%s]. Please re-check your configuration." ,
176- NeedsRunner .class .getSimpleName (),
177- ValidatesRunner .class .getSimpleName (),
178- CrashingRunner .class .getSimpleName ());
179-
180- if (annotatedWithNeedsRunner || !crashingRunner ) {
181- setEnforcement (context , new PipelineAbandonedNodeEnforcement (pipeline ));
182- }
183- }
184- }
185-
186- private boolean hasNeedsRunnerAnnotation (ExtensionContext context ) {
187- // Check method annotations
188- Method testMethod = context .getTestMethod ().orElse (null );
189- if (testMethod != null ) {
190- if (hasNeedsRunnerCategory (testMethod .getAnnotations ())) {
191- return true ;
192- }
193- }
194-
195- // Check class annotations
196- Class <?> testClass = context .getTestClass ().orElse (null );
197- if (testClass != null ) {
198- if (hasNeedsRunnerCategory (testClass .getAnnotations ())) {
199- return true ;
200- }
201- }
202-
203- return false ;
204- }
205-
206- private boolean hasNeedsRunnerCategory (Annotation [] annotations ) {
207- return Arrays .stream (annotations )
208- .filter (annotation -> annotation instanceof Category )
209- .map (annotation -> (Category ) annotation )
210- .flatMap (category -> Arrays .stream (category .value ()))
211- .anyMatch (categoryClass -> NeedsRunner .class .isAssignableFrom (categoryClass ));
159+ private static Collection <Annotation > getAnnotations (ExtensionContext context ) {
160+ ImmutableList .Builder <Annotation > builder = ImmutableList .builder ();
161+ context .getTestMethod ().ifPresent (testMethod -> builder .add (testMethod .getAnnotations ()));
162+ context .getTestClass ().ifPresent (testClass -> builder .add (testClass .getAnnotations ()));
163+ return builder .build ();
212164 }
213165}
0 commit comments