diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java index 4dffeee0bb906..fffc7b44305de 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java @@ -33,6 +33,8 @@ */ public class InflightDataRescalingDescriptor implements Serializable { + private static final int[] EMPTY_INT_ARRAY = new int[0]; + public static final InflightDataRescalingDescriptor NO_RESCALE = new NoRescalingDescriptor(); private static final long serialVersionUID = -3396674344669796295L; @@ -115,8 +117,8 @@ public static class InflightDataGateOrPartitionRescalingDescriptor implements Se public static final InflightDataGateOrPartitionRescalingDescriptor NO_STATE = new InflightDataGateOrPartitionRescalingDescriptor( - new int[0], - RescaleMappings.identity(0, 0), + EMPTY_INT_ARRAY, + RescaleMappings.SYMMETRIC_IDENTITY, Collections.emptySet(), MappingType.IDENTITY) { @@ -124,14 +126,12 @@ public static class InflightDataGateOrPartitionRescalingDescriptor implements Se @Override public int[] getOldSubtaskInstances() { - throw new UnsupportedOperationException( - "Cannot get old subtasks from a descriptor that represents no state."); + return EMPTY_INT_ARRAY; } @Override public RescaleMappings getRescaleMappings() { - throw new UnsupportedOperationException( - "Cannot get rescale mappings from a descriptor that represents no state."); + return RescaleMappings.SYMMETRIC_IDENTITY; } }; @@ -228,7 +228,7 @@ public NoRescalingDescriptor() { @Override public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) { - return new int[0]; + return EMPTY_INT_ARRAY; } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptorTest.java index 6252d5b21068f..f4123a34dec5e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptorTest.java @@ -27,31 +27,25 @@ import java.util.Collections; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Tests for {@link InflightDataRescalingDescriptor}. */ class InflightDataRescalingDescriptorTest { @Test - void testNoStateDescriptorThrowsOnGetOldSubtaskInstances() { + void testNoStateDescriptorReturnsEmptyOldSubtaskInstances() { InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor = InflightDataGateOrPartitionRescalingDescriptor.NO_STATE; - assertThatThrownBy(noStateDescriptor::getOldSubtaskInstances) - .isInstanceOf(UnsupportedOperationException.class) - .hasMessageContaining( - "Cannot get old subtasks from a descriptor that represents no state"); + assertThat(noStateDescriptor.getOldSubtaskInstances()).isEqualTo(new int[0]); } @Test - void testNoStateDescriptorThrowsOnGetRescaleMappings() { + void testNoStateDescriptorReturnsSymmetricIdentity() { InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor = InflightDataGateOrPartitionRescalingDescriptor.NO_STATE; - assertThatThrownBy(noStateDescriptor::getRescaleMappings) - .isInstanceOf(UnsupportedOperationException.class) - .hasMessageContaining( - "Cannot get rescale mappings from a descriptor that represents no state"); + assertThat(noStateDescriptor.getRescaleMappings()) + .isEqualTo(RescaleMappings.SYMMETRIC_IDENTITY); } @Test @@ -108,11 +102,10 @@ void testInflightDataRescalingDescriptorWithNoStateDescriptor() { InflightDataRescalingDescriptor rescalingDescriptor = new InflightDataRescalingDescriptor(descriptors); - // First gate/partition has NO_STATE - assertThatThrownBy(() -> rescalingDescriptor.getOldSubtaskIndexes(0)) - .isInstanceOf(UnsupportedOperationException.class); - assertThatThrownBy(() -> rescalingDescriptor.getChannelMapping(0)) - .isInstanceOf(UnsupportedOperationException.class); + // First gate/partition has NO_STATE - should return empty array and SYMMETRIC_IDENTITY + assertThat(rescalingDescriptor.getOldSubtaskIndexes(0)).isEqualTo(new int[0]); + assertThat(rescalingDescriptor.getChannelMapping(0)) + .isEqualTo(RescaleMappings.SYMMETRIC_IDENTITY); // Second gate/partition has normal state assertThat(rescalingDescriptor.getOldSubtaskIndexes(1)).isEqualTo(new int[] {0, 1}); diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleWithMixedExchangesITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleWithMixedExchangesITCase.java index 34c3277c4a1f3..d11f787aef6d0 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleWithMixedExchangesITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleWithMixedExchangesITCase.java @@ -27,6 +27,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ExternalizedCheckpointRetention; import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.RestartStrategyOptions; import org.apache.flink.configuration.StateRecoveryOptions; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.connector.datagen.source.DataGeneratorSource; @@ -57,6 +58,8 @@ import java.util.List; import java.util.Random; +import static org.apache.flink.configuration.RestartStrategyOptions.RestartStrategyType.NO_RESTART_STRATEGY; + /** * Integration test for rescaling jobs with mixed (UC-supported and UC-unsupported) exchanges from * an unaligned checkpoint. @@ -80,7 +83,8 @@ public static Collection parameter() { UnalignedCheckpointRescaleWithMixedExchangesITCase::createMultiOutputDAG, UnalignedCheckpointRescaleWithMixedExchangesITCase::createMultiInputDAG, UnalignedCheckpointRescaleWithMixedExchangesITCase::createRescalePartitionerDAG, - UnalignedCheckpointRescaleWithMixedExchangesITCase::createMixedComplexityDAG); + UnalignedCheckpointRescaleWithMixedExchangesITCase::createMixedComplexityDAG, + UnalignedCheckpointRescaleWithMixedExchangesITCase::createPartEmptyHashExchangeDAG); } @Before @@ -137,6 +141,7 @@ private StreamExecutionEnvironment getUnalignedCheckpointEnv(@Nullable String re conf.set(CheckpointingOptions.CHECKPOINTING_INTERVAL, Duration.ofSeconds(1)); // Disable aligned timeout to ensure it works with unaligned checkpoint directly conf.set(CheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT, Duration.ofSeconds(0)); + conf.set(RestartStrategyOptions.RESTART_STRATEGY, NO_RESTART_STRATEGY.getMainValue()); conf.set( CheckpointingOptions.EXTERNALIZED_CHECKPOINT_RETENTION, ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION); @@ -336,6 +341,53 @@ private static JobClient createMixedComplexityDAG(StreamExecutionEnvironment env return env.executeAsync(); } + /** + * Creates a DAG where the downstream MapAfterKeyBy task receives input from two hash exchanges: + * one with actual data and one that is empty due to filtering. This tests unaligned checkpoint + * rescaling with mixed empty and non-empty hash partitions. + */ + private static JobClient createPartEmptyHashExchangeDAG(StreamExecutionEnvironment env) + throws Exception { + int source1Parallelism = getRandomParallelism(); + DataGeneratorSource source1 = + new DataGeneratorSource<>( + index -> index, + Long.MAX_VALUE, + RateLimiterStrategy.perSecond(5000), + Types.LONG); + DataStream sourceStream1 = + env.fromSource(source1, WatermarkStrategy.noWatermarks(), "Source 1") + .setParallelism(source1Parallelism); + + int source2Parallelism = getRandomParallelism(); + DataGeneratorSource source2 = + new DataGeneratorSource<>( + index -> index, + Long.MAX_VALUE, + RateLimiterStrategy.perSecond(5000), + Types.LONG); + + // Filter all records to simulate empty state exchange + DataStream sourceStream2 = + env.fromSource(source2, WatermarkStrategy.noWatermarks(), "Source 2") + .setParallelism(source2Parallelism) + .filter(value -> false) + .setParallelism(source2Parallelism); + + sourceStream1 + .union(sourceStream2) + .keyBy((KeySelector) value -> value) + .map( + x -> { + Thread.sleep(5); + return x; + }) + .name("MapAfterKeyBy") + .setParallelism(getRandomParallelism()); + + return env.executeAsync(); + } + private static int getRandomParallelism() { return RANDOM.nextInt(MAX_SLOTS) + 1; }