diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 7e23182042c9..19919672e96c 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -683,6 +683,10 @@ private List<PTransformOverride> getOverrides(boolean streaming) {
             PTransformMatchers.groupWithShardableStates(),
             new GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(
                 this)));
+    overridesBuilder.add(
+        PTransformOverride.of(
+            KafkaIO.Read.KEYED_BY_PARTITION_MATCHER,
+            new KeyedByPartitionOverride.StreamingKeyedByPartitionOverrideFactory(this)));
     overridesBuilder
         .add(
 
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/KeyedByPartitionOverride.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/KeyedByPartitionOverride.java
new file mode 100644
index 000000000000..68d97d7a4de5
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/KeyedByPartitionOverride.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow;
+
+import java.util.Map;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
+import org.apache.beam.sdk.io.kafka.KafkaRecord;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.construction.PTransformReplacements;
+import org.apache.beam.sdk.util.construction.ReplacementOutputs;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TupleTag;
+
+@SuppressWarnings({
+  "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+})
+/**
+ * Dataflow-specific override for {@link KafkaIO.Read.KeyedByPartition} that records the
+ * key-preserving property of the transform's output on the {@link DataflowRunner}.
+ */
+public class KeyedByPartitionOverride {
+
+  /** Replaces {@code KeyedByPartition} with {@link StreamingKeyedByPartition} in streaming jobs. */
+  static class StreamingKeyedByPartitionOverrideFactory<K, V>
+      implements PTransformOverrideFactory<
+          PCollection<KafkaRecord<K, V>>,
+          PCollection<KV<Integer, V>>,
+          KafkaIO.Read.KeyedByPartition<K, V>> {
+
+    private final DataflowRunner runner;
+
+    StreamingKeyedByPartitionOverrideFactory(DataflowRunner runner) {
+      this.runner = runner;
+    }
+
+    @Override
+    public PTransformReplacement<PCollection<KafkaRecord<K, V>>, PCollection<KV<Integer, V>>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KafkaRecord<K, V>>,
+                    PCollection<KV<Integer, V>>,
+                    KafkaIO.Read.KeyedByPartition<K, V>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StreamingKeyedByPartition<>(
+              runner,
+              transform.getTransform(),
+              PTransformReplacements.getSingletonMainOutput(transform)));
+    }
+
+    @Override
+    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<Integer, V>> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
+  /**
+   * Wrapper that applies the original transform unchanged, but first tells the runner that the
+   * original output PCollection is keyed by the Kafka partition and preserves keys.
+   */
+  static class StreamingKeyedByPartition<K, V>
+      extends PTransform<PCollection<KafkaRecord<K, V>>, PCollection<KV<Integer, V>>> {
+
+    private final transient DataflowRunner runner;
+    private final KafkaIO.Read.KeyedByPartition<K, V> originalTransform;
+    private final transient PCollection<KV<Integer, V>> originalOutput;
+
+    public StreamingKeyedByPartition(
+        DataflowRunner runner,
+        KafkaIO.Read.KeyedByPartition<K, V> original,
+        PCollection<KV<Integer, V>> output) {
+      this.runner = runner;
+      this.originalTransform = original;
+      this.originalOutput = output;
+    }
+
+    @Override
+    public PCollection<KV<Integer, V>> expand(PCollection<KafkaRecord<K, V>> input) {
+      // Record the output PCollection of the original transform since the new output will be
+      // replaced by the original one when the replacement transform is wired to other nodes in the
+      // graph, although the old and the new outputs are effectively the same.
+      runner.maybeRecordPCollectionPreservedKeys(originalOutput);
+      return input.apply(originalTransform);
+    }
+  }
+}
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 568fe49217b3..57e46af3caf6 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -63,6 +63,7 @@
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformMatcher;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.schemas.JavaFieldSchema;
@@ -1575,6 +1576,14 @@ public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() {
     return new TypedWithoutMetadata<>(this);
   }
 
+  /**
+   * Returns a {@link PTransform} that emits each record's value keyed by the number of the Kafka
+   * partition the record was read from, dropping all other metadata.
+   */
+  public PTransform<PBegin, PCollection<KV<Integer, V>>> keyedByPartition() {
+    return new ValuesKeyedByPartition<>(this);
+  }
+
   public PTransform<PBegin, PCollection<Row>> externalWithMetadata() {
     return new RowsWithMetadata<>(this);
   }
@@ -1818,6 +1827,30 @@ public Map<PCollection<?>, ReplacementOutput> mapOutputs(
     }
   }
 
+  /** Matches an application of {@link KeyedByPartition}, for runner-level overrides. */
+  @Internal
+  public static final PTransformMatcher KEYED_BY_PARTITION_MATCHER =
+      PTransformMatchers.classEqualTo(KeyedByPartition.class);
+
+  /** Re-keys each {@link KafkaRecord} to a {@code KV} of (partition number, record value). */
+  public static class KeyedByPartition<K, V>
+      extends PTransform<PCollection<KafkaRecord<K, V>>, PCollection<KV<Integer, V>>> {
+
+    @Override
+    public PCollection<KV<Integer, V>> expand(PCollection<KafkaRecord<K, V>> input) {
+      return input.apply(
+          "Repartition",
+          ParDo.of(
+              new DoFn<KafkaRecord<K, V>, KV<Integer, V>>() {
+                @ProcessElement
+                public void processElement(ProcessContext ctx) {
+                  ctx.output(
+                      KV.of(ctx.element().getPartition(), ctx.element().getKV().getValue()));
+                }
+              }));
+    }
+  }
+
   private abstract static class AbstractReadFromKafka<K, V>
       extends PTransform<PBegin, PCollection<KafkaRecord<K, V>>> {
     Read<K, V> kafkaRead;
@@ -2170,6 +2203,45 @@ public void populateDisplayData(DisplayData.Builder builder) {
     }
   }
 
+  /**
+   * A {@link PTransform} to read from Kafka topics, returning each record's value keyed by the
+   * number of the partition the record was read from. See {@link KafkaIO} for more information.
+   */
+  public static class ValuesKeyedByPartition<K, V>
+      extends PTransform<PBegin, PCollection<KV<Integer, V>>> {
+    private final Read<K, V> read;
+
+    ValuesKeyedByPartition(Read<K, V> read) {
+      super("KafkaIO.Read");
+      this.read = read;
+    }
+
+    static class Builder<K, V>
+        implements ExternalTransformBuilder<
+            Read.External.Configuration, PBegin, PCollection<KV<Integer, V>>> {
+
+      @Override
+      public PTransform<PBegin, PCollection<KV<Integer, V>>> buildExternal(
+          Read.External.Configuration config) {
+        Read.Builder<K, V> readBuilder = new AutoValue_KafkaIO_Read.Builder<>();
+        Read.Builder.setupExternalBuilder(readBuilder, config);
+
+        return readBuilder.build().keyedByPartition();
+      }
+    }
+
+    @Override
+    public PCollection<KV<Integer, V>> expand(PBegin begin) {
+      return begin.apply(read).apply("KeyedByPartition", new Read.KeyedByPartition<>());
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      super.populateDisplayData(builder);
+      read.populateDisplayData(builder);
+    }
+  }
+
   /**
    * A {@link PTransform} to read from Kafka topics. Similar to {@link KafkaIO.Read}, but removes
    * Kafka metatdata and returns a {@link PCollection} of {@link KV}. See {@link KafkaIO} for more