diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java
index fb0f8184b591..4c56b13e5041 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/FlinkSink.java
@@ -70,7 +70,6 @@
 import static org.apache.paimon.flink.FlinkConnectorOptions.generateCustomUid;
 import static org.apache.paimon.flink.utils.ManagedMemoryUtils.declareManagedMemory;
 import static org.apache.paimon.flink.utils.ParallelismUtils.forwardParallelism;
-import static org.apache.paimon.flink.utils.ParallelismUtils.setParallelism;
 import static org.apache.paimon.utils.Preconditions.checkArgument;
 
 /** Abstract sink of paimon. */
@@ -227,7 +226,7 @@ public DataStream<Committable> doWrite(
                                 hasSinkMaterializer(input)),
                         commitUser));
         if (parallelism == null) {
-            setParallelism(written, input.getParallelism(), false);
+            forwardParallelism(written, input);
         } else {
             written.setParallelism(parallelism);
         }
diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/ReadWriteTableITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/ReadWriteTableITCase.java
index dc07817b45fd..285b245c4a28 100644
--- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/ReadWriteTableITCase.java
+++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/ReadWriteTableITCase.java
@@ -56,7 +56,9 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.ArrayList;
@@ -68,6 +70,7 @@
 import java.util.Objects;
 import java.util.UUID;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow;
 import static org.apache.paimon.CoreOptions.BUCKET;
@@ -1360,10 +1363,12 @@ public void testInferParallelism() throws Exception {
                 .isEqualTo(2);
     }
 
-    @Test
-    public void testSinkParallelism() throws Exception {
-        testSinkParallelism(null, bExeEnv.getParallelism());
-        testSinkParallelism(23, 23);
+    @ParameterizedTest
+    @MethodSource("testSinkParallelismParameters")
+    public void testSinkParallelism(
+            boolean isFixedBucket, boolean hasPrimaryKey, boolean isSinkParallelismSet)
+            throws Exception {
+        testSinkParallelism(isFixedBucket, hasPrimaryKey, isSinkParallelismSet ? 23 : null);
     }
 
     @Test
@@ -1946,7 +1951,8 @@ private int sourceParallelismStreaming(String sql) {
         return stream.getParallelism();
     }
 
-    private void testSinkParallelism(Integer configParallelism, int expectedParallelism)
+    private void testSinkParallelism(
+            boolean isFixedBucket, boolean hasPrimaryKey, Integer configParallelism)
             throws Exception {
         // 1. create a mock table sink
         Map<String, String> options = new HashMap<>();
@@ -1954,8 +1960,10 @@ private void testSinkParallelism(Integer configParallelism, int expectedParallel
             options.put(SINK_PARALLELISM.key(), configParallelism.toString());
         }
         options.put("path", getTempFilePath(UUID.randomUUID().toString()));
-        options.put("bucket", "1");
-        options.put("bucket-key", "a");
+        if (isFixedBucket) {
+            options.put("bucket", "1");
+            options.put("bucket-key", "a");
+        }
 
         DynamicTableFactory.Context context =
                 new FactoryUtil.DefaultDynamicTableContext(
@@ -1966,7 +1974,9 @@ private void testSinkParallelism(Integer configParallelism, int expectedParallel
                                         new LogicalType[] {new VarCharType(Integer.MAX_VALUE)},
                                         new String[] {"a"}),
                                 Collections.emptyList(),
-                                Collections.emptyList()),
+                                hasPrimaryKey
+                                        ? Collections.singletonList("a")
+                                        : Collections.emptyList()),
                         Collections.emptyMap(),
                         new Configuration(),
                         Thread.currentThread().getContextClassLoader(),
@@ -1996,13 +2006,52 @@ private void testSinkParallelism(Integer configParallelism, int expectedParallel
         // 3. assert parallelism from transformation
         DataStream<RowData> mockSource =
                 bExeEnv.fromCollection(Collections.singletonList(GenericRowData.of()));
+        mockSource.getTransformation().setParallelism(mockSource.getParallelism(), false);
         DataStreamSink<?> sink = sinkProvider.consumeDataStream(null, mockSource);
+
+        boolean hasPartitionTransformation = isFixedBucket || hasPrimaryKey;
+        boolean expectedIsParallelismConfigured =
+                (configParallelism != null) || hasPartitionTransformation;
+
         Transformation<?> transformation = sink.getTransformation();
+        boolean isPartitionTransformationFound = true;
+        boolean isWriterFound = false;
         // until a PartitionTransformation, see FlinkSinkBuilder.build()
         while (!(transformation instanceof PartitionTransformation)) {
-            assertThat(transformation.getParallelism()).isIn(1, expectedParallelism);
-            transformation = transformation.getInputs().get(0);
+            if (transformation.getName().contains("Writer")) {
+                isWriterFound = true;
+                assertThat(transformation.isParallelismConfigured())
+                        .isEqualTo(expectedIsParallelismConfigured);
+            }
+            assertThat(transformation.getParallelism())
+                    .isIn(
+                            1,
+                            configParallelism == null
+                                    ? bExeEnv.getParallelism()
+                                    : configParallelism);
+            List<Transformation<?>> inputTransformations = transformation.getInputs();
+            if (inputTransformations.isEmpty()) {
+                isPartitionTransformationFound = false;
+                break;
+            }
+            transformation = inputTransformations.get(0);
+        }
+        assertThat(isPartitionTransformationFound).isEqualTo(hasPartitionTransformation);
+        assertThat(isWriterFound).isTrue();
+    }
+
+    private static Stream<Arguments> testSinkParallelismParameters() {
+        List<Boolean> allBooleans = Arrays.asList(false, true);
+        List<Arguments> parameters = new ArrayList<>();
+        for (boolean isFixedBucket : allBooleans) {
+            for (boolean hasPrimaryKey : allBooleans) {
+                for (boolean isSinkParallelismSet : allBooleans) {
+                    parameters.add(
+                            Arguments.of(isFixedBucket, hasPrimaryKey, isSinkParallelismSet));
+                }
+            }
         }
+        return parameters.stream();
+    }
 
     private void assertChangeBucketWithoutRescale(String table, int bucketNum) throws Exception {
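
Reviewer note (not part of the patch): the core change is that doWrite() now forwards both the upstream parallelism and Flink's isParallelismConfigured flag to the writer, instead of pinning the writer via setParallelism(written, input.getParallelism(), false). Transformation#setParallelism(int, boolean) lets a parallelism value be set while marking whether it counts as explicitly configured, which is what adaptive schedulers consult when deciding if they may override it; the new Writer assertion in testSinkParallelism checks exactly that flag. Below is a minimal sketch of what ParallelismUtils.forwardParallelism is assumed to do, inferred only from the call sites in this diff; the helper and its callers are real, but this body and the class name ForwardParallelismSketch are assumptions, not the actual Paimon implementation.

    // Sketch only: assumed behavior of ParallelismUtils.forwardParallelism,
    // inferred from the call sites in this diff, NOT the actual Paimon code.
    import org.apache.flink.api.dag.Transformation;
    import org.apache.flink.streaming.api.datastream.DataStream;

    public final class ForwardParallelismSketch {

        private ForwardParallelismSketch() {}

        /**
         * Copies the parallelism of {@code source} onto {@code target} together
         * with the isParallelismConfigured flag, so the target inherits not just
         * the upstream parallelism value but also whether that value was
         * explicitly configured.
         */
        public static void forwardParallelism(DataStream<?> target, DataStream<?> source) {
            Transformation<?> from = source.getTransformation();
            // setParallelism(int, boolean) records whether the value is treated
            // as user-configured; unconfigured values remain overridable by
            // Flink's adaptive schedulers.
            target.getTransformation()
                    .setParallelism(from.getParallelism(), from.isParallelismConfigured());
        }
    }

This also explains the test setup line mockSource.getTransformation().setParallelism(mockSource.getParallelism(), false): it marks the mock source's parallelism as not explicitly configured, so the writer is expected to report isParallelismConfigured only when sink.parallelism is set or a partition transformation (fixed bucket or primary key) forces it.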