2222import org .apache .flink .api .common .functions .RichFlatMapFunction ;
2323import org .apache .flink .api .common .state .ValueState ;
2424import org .apache .flink .api .common .state .ValueStateDescriptor ;
25+ import org .apache .flink .api .connector .sink2 .Sink ;
26+ import org .apache .flink .api .connector .sink2 .SinkWriter ;
27+ import org .apache .flink .api .connector .sink2 .WriterInitContext ;
2528import org .apache .flink .api .java .functions .KeySelector ;
2629import org .apache .flink .api .java .tuple .Tuple2 ;
2730import org .apache .flink .configuration .CheckpointingOptions ;
3841import org .apache .flink .streaming .api .checkpoint .CheckpointedFunction ;
3942import org .apache .flink .streaming .api .datastream .DataStream ;
4043import org .apache .flink .streaming .api .environment .StreamExecutionEnvironment ;
41- import org .apache .flink .streaming .api .functions .sink .legacy .SinkFunction ;
4244import org .apache .flink .streaming .api .functions .source .legacy .RichParallelSourceFunction ;
4345import org .apache .flink .test .util .MiniClusterWithClientResource ;
4446import org .apache .flink .testutils .junit .SharedObjects ;
@@ -152,6 +154,7 @@ private String runJobAndGetCheckpoint(
152154 int maxParallelism ,
153155 MiniCluster miniCluster )
154156 throws Exception {
157+ JobID jobID = null ;
155158 try {
156159 JobGraph jobGraph =
157160 createJobGraphWithKeyedState (
@@ -163,15 +166,18 @@ private String runJobAndGetCheckpoint(
163166 true ,
164167 100 ,
165168 miniCluster );
169+ jobID = jobGraph .getJobID ();
166170 miniCluster .submitJob (jobGraph ).get ();
167- miniCluster .requestJobResult (jobGraph . getJobID () ).get ();
168- return getLatestCompletedCheckpointPath (jobGraph . getJobID () , miniCluster )
171+ miniCluster .requestJobResult (jobID ).get ();
172+ return getLatestCompletedCheckpointPath (jobID , miniCluster )
169173 .orElseThrow (
170174 () ->
171175 new IllegalStateException (
172176 "Cannot get completed checkpoint, job failed before completing checkpoint" ));
173177 } finally {
174- CollectionSink .clearElementsSet ();
178+ if (jobID != null ) {
179+ CollectionSink .clearElementsSet (jobID );
180+ }
175181 }
176182 }
177183
@@ -184,6 +190,7 @@ private void restoreAndAssert(
184190 MiniCluster miniCluster ,
185191 String restorePath )
186192 throws Exception {
193+ JobID jobID = null ;
187194 try {
188195 JobGraph scaledJobGraph =
189196 createJobGraphWithKeyedState (
@@ -195,13 +202,14 @@ private void restoreAndAssert(
195202 false ,
196203 100 ,
197204 miniCluster );
205+ jobID = scaledJobGraph .getJobID ();
198206
199207 scaledJobGraph .setSavepointRestoreSettings (forPath (restorePath ));
200208
201209 miniCluster .submitJob (scaledJobGraph ).get ();
202- miniCluster .requestJobResult (scaledJobGraph . getJobID () ).get ();
210+ miniCluster .requestJobResult (jobID ).get ();
203211
204- Set <Tuple2 <Integer , Integer >> actualResult = CollectionSink .getElementsSet ();
212+ Set <Tuple2 <Integer , Integer >> actualResult = CollectionSink .getElementsSet (jobID );
205213
206214 Set <Tuple2 <Integer , Integer >> expectedResult = new HashSet <>();
207215
@@ -215,7 +223,9 @@ private void restoreAndAssert(
215223 }
216224 assertEquals (expectedResult , actualResult );
217225 } finally {
218- CollectionSink .clearElementsSet ();
226+ if (jobID != null ) {
227+ CollectionSink .clearElementsSet (jobID );
228+ }
219229 }
220230 }
221231
@@ -282,7 +292,7 @@ public Integer getKey(Integer value) {
282292 DataStream <Tuple2 <Integer , Integer >> result =
283293 input .flatMap (new SubtaskIndexFlatMapper (numberElementsExpect ));
284294
285- result .addSink (new CollectionSink <>());
295+ result .sinkTo (new CollectionSink <>());
286296
287297 return env .getStreamGraph ().getJobGraph (env .getClass ().getClassLoader (), jobID .get ());
288298 }
@@ -389,25 +399,59 @@ public void initializeState(FunctionInitializationContext context) throws Except
389399 }
390400 }
391401
392- private static class CollectionSink <IN > implements SinkFunction <IN > {
402+ private static class CollectionSink <IN > implements Sink <IN > {
393403
394- private static final Set < Object > elements =
395- Collections . newSetFromMap ( new ConcurrentHashMap <>() );
404+ private static final ConcurrentHashMap < JobID , CollectionSinkWriter <?>> writers =
405+ new ConcurrentHashMap <>();
396406
397407 private static final long serialVersionUID = 1L ;
398408
399409 @ SuppressWarnings ("unchecked" )
400- public static <IN > Set <IN > getElementsSet () {
401- return (Set <IN >) elements ;
410+ public static <IN > Set <IN > getElementsSet (JobID jobID ) {
411+ CollectionSinkWriter <IN > writer = (CollectionSinkWriter <IN >) writers .get (jobID );
412+ if (writer == null ) {
413+ return Collections .emptySet ();
414+ } else {
415+ return writer .getElementsSet ();
416+ }
402417 }
403418
404- public static void clearElementsSet () {
405- elements . clear ( );
419+ public static void clearElementsSet (JobID jobID ) {
420+ writers . remove ( jobID );
406421 }
407422
408423 @ Override
409- public void invoke (IN value ) throws Exception {
410- elements .add (value );
424+ @ SuppressWarnings ("unchecked" )
425+ public SinkWriter <IN > createWriter (WriterInitContext context ) throws IOException {
426+ final CollectionSinkWriter <IN > writer =
427+ (CollectionSinkWriter <IN >)
428+ writers .computeIfAbsent (
429+ context .getJobInfo ().getJobId (),
430+ (k ) -> new CollectionSinkWriter <IN >());
431+ return writer ;
432+ }
433+
434+ private static class CollectionSinkWriter <IN > implements SinkWriter <IN > {
435+
436+ private final Set <Object > elements =
437+ Collections .newSetFromMap (new ConcurrentHashMap <>());
438+
439+ @ Override
440+ public void write (IN element , Context context )
441+ throws IOException , InterruptedException {
442+ elements .add (element );
443+ }
444+
445+ @ Override
446+ public void flush (boolean endOfInput ) throws IOException , InterruptedException {}
447+
448+ @ Override
449+ public void close () throws Exception {}
450+
451+ @ SuppressWarnings ("unchecked" )
452+ public <IN > Set <IN > getElementsSet () {
453+ return (Set <IN >) elements ;
454+ }
411455 }
412456 }
413457}
0 commit comments