diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java
new file mode 100644
index 0000000000000..cbcfd035f12c5
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.StatefulSinkWriter;
+import org.apache.flink.api.connector.sink2.SupportsWriterState;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Collection;
+
+/**
+ * A sink that dynamically routes elements to different underlying sinks based on a routing
+ * function.
+ *
+ *
+ * <p>The {@code DemultiplexingSink} allows elements to be routed to different sink instances at
+ * runtime based on the content of each element. This is useful for scenarios such as:
+ *
+ * <ul>
+ *   <li>Routing messages to different Kafka topics based on message type or contents
+ *   <li>Writing to different databases based on a tenant identifier
+ *   <li>Sending data to different Elasticsearch clusters or indices based on data characteristics
+ * </ul>
+ *
+ * <p>The sink maintains an internal cache of sink instances, creating new sinks on-demand when
+ * previously unseen routes are encountered. This provides good performance while supporting dynamic
+ * routing scenarios.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * // Route messages to different Kafka topics
+ * SinkRouter<MyMessage, String> router = new SinkRouter<MyMessage, String>() {
+ * @Override
+ * public String getRoute(MyMessage element) {
+ * return element.getTopicName();
+ * }
+ *
+ * @Override
+ * public Sink<MyMessage> createSink(String topicName, MyMessage element) {
+ * return KafkaSink.builder()
+ * .setBootstrapServers("localhost:9092")
+ * .setRecordSerializer(...)
+ * .setTopics(topicName)
+ * .build();
+ * }
+ * };
+ *
+ * DemultiplexingSink<MyMessage, String> demuxSink =
+ * new DemultiplexingSink<>(router);
+ *
+ * dataStream.sinkTo(demuxSink);
+ * }</pre>
+ *
+ * <p>The sink supports checkpointing and recovery through the {@link SupportsWriterState}
+ * interface. State from all underlying sink writers is collected and restored appropriately during
+ * recovery.
+ *
+ * @param <InputT> The type of input elements
+ * @param <RouteT> The type of route keys used for sink selection
+ */
+@PublicEvolving
+public class DemultiplexingSink<InputT, RouteT>
+ implements Sink<InputT>, SupportsWriterState<InputT, DemultiplexingSinkState<RouteT>> {
+
+ private static final long serialVersionUID = 1L;
+
+ /** The router that determines how elements are routed to sinks. */
+ private final SinkRouter<InputT, RouteT> sinkRouter;
+
+ /**
+ * Creates a new demultiplexing sink with the given router.
+ *
+ * @param sinkRouter The router that determines how elements are routed to different sinks
+ */
+ public DemultiplexingSink(SinkRouter<InputT, RouteT> sinkRouter) {
+ this.sinkRouter = Preconditions.checkNotNull(sinkRouter, "sinkRouter must not be null");
+ }
+
+ @Override
+ public SinkWriter<InputT> createWriter(WriterInitContext context) throws IOException {
+ return new DemultiplexingSinkWriter<>(sinkRouter, context);
+ }
+
+ @Override
+ public StatefulSinkWriter<InputT, DemultiplexingSinkState<RouteT>> restoreWriter(
+ WriterInitContext context, Collection<DemultiplexingSinkState<RouteT>> recoveredState) {
+
+ return new DemultiplexingSinkWriter<>(sinkRouter, context, recoveredState);
+ }
+
+ @Override
+ public SimpleVersionedSerializer<DemultiplexingSinkState<RouteT>> getWriterStateSerializer() {
+ return new DemultiplexingSinkStateSerializer<>();
+ }
+
+ /**
+ * Gets the sink router used by this demultiplexing sink.
+ *
+ * @return The sink router
+ */
+ public SinkRouter<InputT, RouteT> getSinkRouter() {
+ return sinkRouter;
+ }
+}
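
For a self-contained picture of how the pieces fit together, the following sketch (not part of the patch) wires a DemultiplexingSink into a job, routing strings by their first character to a trivial per-route stdout sink. The router and the inline sink are illustrative only; a real job would build e.g. a KafkaSink or FileSink inside createSink.

```java
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.connector.sink2.SinkWriter;
import org.apache.flink.api.connector.sink2.WriterInitContext;
import org.apache.flink.connector.base.sink.DemultiplexingSink;
import org.apache.flink.connector.base.sink.SinkRouter;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class DemultiplexingSinkExample {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        SinkRouter<String, String> router =
                new SinkRouter<String, String>() {
                    @Override
                    public String getRoute(String element) {
                        // One route (and therefore one cached sink) per first character.
                        return element.substring(0, 1);
                    }

                    @Override
                    public Sink<String> createSink(String route, String element) {
                        // Illustrative per-route sink: a real job would configure e.g. a
                        // KafkaSink or FileSink here; a stdout sink keeps the sketch small.
                        return new Sink<String>() {
                            @Override
                            public SinkWriter<String> createWriter(WriterInitContext context) {
                                return new SinkWriter<String>() {
                                    @Override
                                    public void write(String value, Context ctx) {
                                        System.out.println(route + " -> " + value);
                                    }

                                    @Override
                                    public void flush(boolean endOfInput) {}

                                    @Override
                                    public void close() {}
                                };
                            }
                        };
                    }
                };

        env.fromData("apple", "banana", "avocado", "blueberry")
                .sinkTo(new DemultiplexingSink<>(router));

        env.execute("demultiplexing-sink-example");
    }
}
```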
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java
new file mode 100644
index 0000000000000..c35b3ce9d0801
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * State class for {@link DemultiplexingSink} that tracks the state of individual sink writers for
+ * each route during checkpointing and recovery.
+ *
+ * <p>This state contains:
+ *
+ * <ul>
+ *   <li>A mapping of route keys to their corresponding sink writer states
+ *   <li>Metadata about which routes are currently active
+ * </ul>
+ *
+ * @param <RouteT> The type of route keys
+ */
+@PublicEvolving
+public class DemultiplexingSinkState<RouteT> implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ /** Map of route keys to their serialized sink writer states. */
+ private final Map<RouteT, byte[]> routeStates;
+
+ /** Creates a new empty demultiplexing sink state. */
+ public DemultiplexingSinkState() {
+ this.routeStates = new HashMap<>();
+ }
+
+ /**
+ * Creates a new demultiplexing sink state with the given route states.
+ *
+ * @param routeStates Map of route keys to their serialized sink writer states
+ */
+ public DemultiplexingSinkState(Map<RouteT, byte[]> routeStates) {
+ this.routeStates = new HashMap<>(Preconditions.checkNotNull(routeStates));
+ }
+
+ /**
+ * Gets the serialized state for a specific route.
+ *
+ * @param route The route key
+ * @return The serialized state for the route, or null if no state exists
+ */
+ public byte[] getRouteState(RouteT route) {
+ return routeStates.get(route);
+ }
+
+ /**
+ * Sets the serialized state for a specific route.
+ *
+ * @param route The route key
+ * @param state The serialized state data
+ */
+ public void setRouteState(RouteT route, byte[] state) {
+ if (state != null) {
+ routeStates.put(route, state);
+ } else {
+ routeStates.remove(route);
+ }
+ }
+
+ /**
+ * Gets all route keys that have associated state.
+ *
+ * @return An unmodifiable set of route keys
+ */
+ public java.util.Set<RouteT> getRoutes() {
+ return Collections.unmodifiableSet(routeStates.keySet());
+ }
+
+ /**
+ * Gets a copy of all route states.
+ *
+ * @return An unmodifiable map of route keys to their serialized states
+ */
+ public Map<RouteT, byte[]> getRouteStates() {
+ return Collections.unmodifiableMap(routeStates);
+ }
+
+ /**
+ * Checks if this state contains any route states.
+ *
+ * @return true if there are no route states, false otherwise
+ */
+ public boolean isEmpty() {
+ return routeStates.isEmpty();
+ }
+
+ /**
+ * Gets the number of routes with associated state.
+ *
+ * @return The number of routes
+ */
+ public int size() {
+ return routeStates.size();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ DemultiplexingSinkState<?> that = (DemultiplexingSinkState<?>) o;
+
+ // Compare route states with proper byte array comparison
+ if (routeStates.size() != that.routeStates.size()) {
+ return false;
+ }
+
+ for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+ RouteT key = entry.getKey();
+ byte[] value = entry.getValue();
+ byte[] otherValue = that.routeStates.get(key);
+
+ if (!java.util.Arrays.equals(value, otherValue)) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = 1;
+ for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+ result = 31 * result + Objects.hashCode(entry.getKey());
+ result = 31 * result + java.util.Arrays.hashCode(entry.getValue());
+ }
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return "DemultiplexingSinkState{"
+ + "routeCount="
+ + routeStates.size()
+ + ", routes="
+ + routeStates.keySet()
+ + '}';
+ }
+}
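
A short sketch of why the byte-array-aware equals()/hashCode() above matter: byte[] uses identity equality, so a plain map comparison would treat two states holding identical bytes in distinct arrays as different. Only classes from this patch are used; the main wrapper is illustrative.

```java
import org.apache.flink.connector.base.sink.DemultiplexingSinkState;

public class StateEqualityExample {
    public static void main(String[] args) {
        DemultiplexingSinkState<String> a = new DemultiplexingSinkState<>();
        DemultiplexingSinkState<String> b = new DemultiplexingSinkState<>();

        // Different byte[] instances, identical contents.
        a.setRouteState("orders", new byte[] {1, 2, 3});
        b.setRouteState("orders", new byte[] {1, 2, 3});

        System.out.println(a.equals(b));                  // true: contents compared via Arrays.equals
        System.out.println(a.hashCode() == b.hashCode()); // true: hash derived from array contents
    }
}
```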
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java
new file mode 100644
index 0000000000000..36db571368b92
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Serializer for {@link DemultiplexingSinkState}.
+ *
+ * <p>This serializer handles the serialization of the demultiplexing sink state, which contains a
+ * mapping of route keys to their corresponding sink writer states. The route keys are serialized
+ * using Java serialization, while the sink writer states are stored as raw byte arrays.
+ *
+ * <p>Serialization format:
+ *
+ * <ol>
+ *   <li>Version (int, prepended by the {@code SimpleVersionedSerializer} framework, not written by
+ *       {@link #serialize})
+ *   <li>Number of routes (int)
+ *   <li>For each route:
+ *       <ol>
+ *         <li>Route key length (int)
+ *         <li>Route key bytes (serialized using Java serialization)
+ *         <li>State length (int)
+ *         <li>State bytes
+ *       </ol>
+ * </ol>
+ *
+ * @param <RouteT> The type of route keys
+ */
+@PublicEvolving
+public class DemultiplexingSinkStateSerializer<RouteT>
+ implements SimpleVersionedSerializer<DemultiplexingSinkState<RouteT>> {
+
+ private static final int VERSION = 1;
+
+ @Override
+ public int getVersion() {
+ return VERSION;
+ }
+
+ @Override
+ public byte[] serialize(DemultiplexingSinkState<RouteT> state) throws IOException {
+ final DataOutputSerializer out = new DataOutputSerializer(256);
+
+ // Write the number of routes
+ final Map<RouteT, byte[]> routeStates = state.getRouteStates();
+ out.writeInt(routeStates.size());
+
+ // Write each route and its state
+ for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+ // Serialize the route key using Java serialization
+ final byte[] routeBytes = serializeRouteKey(entry.getKey());
+ out.writeInt(routeBytes.length);
+ out.write(routeBytes);
+
+ // Write the state bytes
+ final byte[] stateBytes = entry.getValue();
+ out.writeInt(stateBytes.length);
+ out.write(stateBytes);
+ }
+
+ return out.getCopyOfBuffer();
+ }
+
+ @Override
+ public DemultiplexingSinkState<RouteT> deserialize(int version, byte[] serialized)
+ throws IOException {
+ if (version != VERSION) {
+ throw new IOException(
+ "Unsupported version: " + version + ". Supported version: " + VERSION);
+ }
+
+ final DataInputDeserializer in = new DataInputDeserializer(serialized);
+
+ // Read the number of routes
+ final int numRoutes = in.readInt();
+ final Map<RouteT, byte[]> routeStates = new HashMap<>(numRoutes);
+
+ // Read each route and its state
+ for (int i = 0; i < numRoutes; i++) {
+ // Read the route key
+ final int routeLength = in.readInt();
+ final byte[] routeBytes = new byte[routeLength];
+ in.readFully(routeBytes);
+ final RouteT route = deserializeRouteKey(routeBytes);
+
+ // Read the state bytes
+ final int stateLength = in.readInt();
+ final byte[] stateBytes = new byte[stateLength];
+ in.readFully(stateBytes);
+
+ // Store the route states
+ routeStates.put(route, stateBytes);
+ }
+
+ return new DemultiplexingSinkState<>(routeStates);
+ }
+
+ /**
+ * Serializes a route key using Java serialization.
+ *
+ * @param routeKey The route key to serialize
+ * @return The serialized route key bytes
+ * @throws IOException If serialization fails
+ */
+ private byte[] serializeRouteKey(RouteT routeKey) throws IOException {
+ try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+ oos.writeObject(routeKey);
+ oos.flush();
+ return baos.toByteArray();
+ } catch (Exception e) {
+ throw new IOException("Failed to serialize route key: " + routeKey, e);
+ }
+ }
+
+ /**
+ * Deserializes a route key using Java serialization.
+ *
+ * @param routeBytes The serialized route key bytes
+ * @return The deserialized route key
+ * @throws IOException If deserialization fails
+ */
+ @SuppressWarnings("unchecked")
+ private RouteT deserializeRouteKey(byte[] routeBytes) throws IOException {
+ try (final ByteArrayInputStream bais = new ByteArrayInputStream(routeBytes);
+ final ObjectInputStream ois = new ObjectInputStream(bais)) {
+ return (RouteT) ois.readObject();
+ } catch (Exception e) {
+ throw new IOException("Failed to deserialize route key", e);
+ }
+ }
+}
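
A minimal round-trip sketch for the serializer, using only the classes added in this patch plus JDK types; the route names and payload bytes are made up for illustration.

```java
import org.apache.flink.connector.base.sink.DemultiplexingSinkState;
import org.apache.flink.connector.base.sink.DemultiplexingSinkStateSerializer;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class StateSerializerRoundTrip {
    public static void main(String[] args) throws IOException {
        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
        // Route keys map to the opaque bytes produced by each per-route writer.
        state.setRouteState("orders", "writer-state-1".getBytes(StandardCharsets.UTF_8));
        state.setRouteState("users", "writer-state-2".getBytes(StandardCharsets.UTF_8));

        DemultiplexingSinkStateSerializer<String> serializer =
                new DemultiplexingSinkStateSerializer<>();

        byte[] bytes = serializer.serialize(state);
        DemultiplexingSinkState<String> restored =
                serializer.deserialize(serializer.getVersion(), bytes);

        System.out.println(restored.equals(state)); // true: byte-array aware equals
        System.out.println(restored.getRoutes());   // [orders, users]
    }
}
```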
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java
new file mode 100644
index 0000000000000..51c48b0e88cb3
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.eventtime.Watermark;
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.StatefulSinkWriter;
+import org.apache.flink.api.connector.sink2.SupportsWriterState;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A sink writer that routes elements to different underlying sink writers based on a routing
+ * function.
+ *
+ * <p>This writer maintains a cache of sink writers, creating new ones on-demand when previously
+ * unseen routes are encountered. Each underlying sink writer is managed independently, including
+ * its lifecycle, state management, and error handling.
+ *
+ * <p>The writer supports state management by collecting state from all underlying sink writers
+ * during checkpointing and restoring them appropriately during recovery.
+ *
+ * @param <InputT> The type of input elements
+ * @param <RouteT> The type of route keys used for sink selection
+ */
+@PublicEvolving
+public class DemultiplexingSinkWriter<InputT, RouteT>
+ implements StatefulSinkWriter<InputT, DemultiplexingSinkState<RouteT>> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(DemultiplexingSinkWriter.class);
+
+ /** The router that determines how elements are routed to sinks. */
+ private final SinkRouter<InputT, RouteT> sinkRouter;
+
+ /** The writer initialization context. */
+ private final WriterInitContext context;
+
+ /** Cache of sink writers by route key. */
+ private final Map<RouteT, SinkWriter<InputT>> sinkWriters;
+
+ /** Cache of sink instances by route key for creating writers. */
+ private final Map<RouteT, Sink<InputT>> sinks;
+
+ /** Cache of recovered states by route key for lazy restoration. */
+ private final Map<RouteT, List<Object>> recoveredStates;
+
+ /**
+ * Creates a new demultiplexing sink writer.
+ *
+ * @param sinkRouter The router that determines how elements are routed to different sinks
+ * @param context The writer initialization context
+ */
+ public DemultiplexingSinkWriter(
+ SinkRouter<InputT, RouteT> sinkRouter, WriterInitContext context) {
+ this.sinkRouter = Preconditions.checkNotNull(sinkRouter);
+ this.context = Preconditions.checkNotNull(context);
+ this.sinkWriters = new HashMap<>();
+ this.sinks = new HashMap<>();
+ this.recoveredStates = new HashMap<>();
+ }
+
+ /**
+ * Creates a new demultiplexing sink writer and restores state from a previous checkpoint.
+ *
+ * @param sinkRouter The router that determines how elements are routed to different sinks
+ * @param context The writer initialization context
+ * @param recoveredStates The recovered states from previous checkpoints
+ */
+ public DemultiplexingSinkWriter(
+ SinkRouter<InputT, RouteT> sinkRouter,
+ WriterInitContext context,
+ Collection<DemultiplexingSinkState<RouteT>> recoveredStates) {
+ this(sinkRouter, context);
+
+ // Process recovered states and prepare them for lazy restoration
+ for (DemultiplexingSinkState<RouteT> state : recoveredStates) {
+ for (RouteT route : state.getRoutes()) {
+ byte[] routeStateBytes = state.getRouteState(route);
+ if (routeStateBytes != null && routeStateBytes.length > 0) {
+ try {
+ // Deserialize the writer states for this route
+ List<Object> writerStates = deserializeWriterStates(route, routeStateBytes);
+ this.recoveredStates.put(route, writerStates);
+ LOG.debug(
+ "Prepared state restoration for route: {} with {} states",
+ route,
+ writerStates.size());
+ } catch (Exception e) {
+ LOG.warn(
+ "Failed to deserialize state for route: {}, will start with empty state",
+ route,
+ e);
+ }
+ }
+ }
+ }
+ }
+
+ @Override
+ public void write(InputT element, Context context) throws IOException, InterruptedException {
+ // Determine the route for this element
+ final RouteT route = sinkRouter.getRoute(element);
+
+ // Get or create the sink writer for this route
+ SinkWriter<InputT> writer = getOrCreateSinkWriter(route, element);
+
+ // Delegate to the appropriate sink writer
+ writer.write(element, context);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) throws IOException, InterruptedException {
+ // Flush all active sink writers
+ IOException lastException = null;
+ for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+ try {
+ entry.getValue().flush(endOfInput);
+ } catch (IOException e) {
+ LOG.warn("Failed to flush sink writer for route: {}", entry.getKey(), e);
+ lastException = e;
+ }
+ }
+
+ // Re-throw the last exception if any occurred
+ if (lastException != null) {
+ throw lastException;
+ }
+ }
+
+ @Override
+ public void writeWatermark(Watermark watermark) throws IOException, InterruptedException {
+ // Propagate watermark to all active sink writers
+ IOException lastException = null;
+ for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+ try {
+ entry.getValue().writeWatermark(watermark);
+ } catch (IOException e) {
+ LOG.warn(
+ "Failed to write watermark to sink writer for route: {}",
+ entry.getKey(),
+ e);
+ lastException = e;
+ }
+ }
+
+ // Re-throw the last exception if any occurred
+ if (lastException != null) {
+ throw lastException;
+ }
+ }
+
+ @Override
+ public List<DemultiplexingSinkState<RouteT>> snapshotState(long checkpointId)
+ throws IOException {
+ final Map<RouteT, byte[]> routeStates = new HashMap<>();
+
+ // Collect state from all active sink writers
+ for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+ final RouteT route = entry.getKey();
+ final SinkWriter<InputT> writer = entry.getValue();
+
+ // Only collect state from stateful sink writers
+ if (writer instanceof StatefulSinkWriter) {
+ StatefulSinkWriter<InputT, ?> statefulWriter =
+ (StatefulSinkWriter<InputT, ?>) writer;
+
+ try {
+ List<?> writerStates = statefulWriter.snapshotState(checkpointId);
+ if (!writerStates.isEmpty()) {
+ // Serialize the writer states and store them
+ byte[] serializedState = serializeWriterStates(route, writerStates);
+ routeStates.put(route, serializedState);
+ }
+ } catch (Exception e) {
+ LOG.warn("Failed to snapshot state for route: {}", route, e);
+ throw new IOException("Failed to snapshot state for route: " + route, e);
+ }
+ }
+ }
+
+ // Return a single state object containing all route states
+ final List<DemultiplexingSinkState<RouteT>> states = new ArrayList<>();
+ if (!routeStates.isEmpty()) {
+ states.add(new DemultiplexingSinkState<>(routeStates));
+ }
+
+ return states;
+ }
+
+ @Override
+ public void close() throws Exception {
+ // Close all active sink writers
+ Exception lastException = null;
+ for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+ try {
+ entry.getValue().close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close sink writer for route: {}", entry.getKey(), e);
+ lastException = e;
+ }
+ }
+
+ // Clear the caches
+ sinkWriters.clear();
+ sinks.clear();
+
+ // Re-throw the last exception if any occurred
+ if (lastException != null) {
+ throw lastException;
+ }
+ }
+
+ /**
+ * Gets or creates a sink writer for the given route.
+ *
+ * @param route The route key
+ * @param element The element that triggered this route (used for sink creation)
+ * @return The sink writer for the route
+ * @throws IOException If sink or writer creation fails
+ */
+ private SinkWriter<InputT> getOrCreateSinkWriter(RouteT route, InputT element)
+ throws IOException {
+ SinkWriter<InputT> writer = sinkWriters.get(route);
+ if (writer == null) {
+ // Create a new sink for this route
+ Sink<InputT> sink = sinkRouter.createSink(route, element);
+ sinks.put(route, sink);
+
+ // Check if we have recovered state for this route
+ List<Object> routeRecoveredStates = recoveredStates.remove(route);
+
+ if (routeRecoveredStates != null && !routeRecoveredStates.isEmpty()) {
+ // Restore writer with state if the sink supports it
+ if (sink instanceof SupportsWriterState) {
+ @SuppressWarnings("unchecked")
+ SupportsWriterState<InputT, Object> statefulSink =
+ (SupportsWriterState<InputT, Object>) sink;
+
+ try {
+ // Check if we need to deserialize raw bytes using the sink's serializer
+ List<Object> processedStates =
+ processRecoveredStates(statefulSink, routeRecoveredStates);
+
+ writer = statefulSink.restoreWriter(context, processedStates);
+ LOG.debug(
+ "Restored sink writer for route: {} with {} states",
+ route,
+ processedStates.size());
+ } catch (Exception e) {
+ LOG.warn(
+ "Failed to restore writer state for route: {}, creating new writer",
+ route,
+ e);
+ writer = sink.createWriter(context);
+ }
+ } else {
+ // Sink doesn't support state, just create a new writer
+ writer = sink.createWriter(context);
+ LOG.debug("Sink for route {} doesn't support state, created new writer", route);
+ }
+ } else {
+ // No recovered state, create a new writer
+ writer = sink.createWriter(context);
+ LOG.debug("Created new sink writer for route: {} (no recovered state)", route);
+ }
+
+ sinkWriters.put(route, writer);
+ }
+ return writer;
+ }
+
+ /**
+ * Processes recovered states, deserializing raw bytes if necessary.
+ *
+ * @param statefulSink The stateful sink that can provide a state serializer
+ * @param recoveredStates The recovered states (may contain raw bytes)
+ * @return Processed states ready for restoration
+ */
+ private List<Object> processRecoveredStates(
+ SupportsWriterState<InputT, Object> statefulSink, List<Object> recoveredStates) {
+
+ List<Object> processedStates = new ArrayList<>();
+ SimpleVersionedSerializer<Object> serializer = statefulSink.getWriterStateSerializer();
+
+ for (Object state : recoveredStates) {
+ if (state instanceof byte[]) {
+ // This is raw bytes that need to be deserialized using the sink's serializer
+ byte[] rawBytes = (byte[]) state;
+ try {
+ List<Object> deserializedStates =
+ deserializeWithSinkSerializer(serializer, rawBytes);
+ processedStates.addAll(deserializedStates);
+ LOG.debug(
+ "Successfully deserialized {} states from raw bytes using sink serializer",
+ deserializedStates.size());
+ } catch (Exception e) {
+ LOG.warn(
+ "Failed to deserialize raw bytes using sink serializer, skipping state",
+ e);
+ }
+ } else {
+ // This is already a deserialized state object
+ processedStates.add(state);
+ }
+ }
+
+ return processedStates;
+ }
+
+ /**
+ * Deserializes states using the sink's state serializer.
+ *
+ * @param serializer The sink's state serializer
+ * @param rawBytes The raw serialized bytes
+ * @return List of deserialized state objects
+ * @throws IOException If deserialization fails
+ */
+ private List<Object> deserializeWithSinkSerializer(
+ SimpleVersionedSerializer<Object> serializer, byte[] rawBytes) throws IOException {
+
+ List<Object> states = new ArrayList<>();
+
+ try (final ByteArrayInputStream bais = new ByteArrayInputStream(rawBytes);
+ final java.io.DataInputStream dis = new java.io.DataInputStream(bais)) {
+
+ // Read the number of states
+ int numStates = dis.readInt();
+
+ // Read each state
+ for (int i = 0; i < numStates; i++) {
+ int stateLength = dis.readInt();
+ byte[] stateBytes = new byte[stateLength];
+ dis.readFully(stateBytes);
+
+ Object state = serializer.deserialize(serializer.getVersion(), stateBytes);
+ states.add(state);
+ }
+ }
+
+ return states;
+ }
+
+ /**
+ * Serializes the writer states for a given route.
+ *
+ * <p>This implementation attempts to use the proper state serializer from the underlying sink
+ * if available, otherwise falls back to Java serialization.
+ *
+ * @param route The route key
+ * @param writerStates The writer states to serialize
+ * @return The serialized state bytes
+ * @throws IOException If serialization fails
+ */
+ private byte[] serializeWriterStates(RouteT route, List<?> writerStates) throws IOException {
+ if (writerStates == null || writerStates.isEmpty()) {
+ return new byte[0];
+ }
+
+ Sink<InputT> sink = sinks.get(route);
+ if (sink instanceof SupportsWriterState) {
+ try {
+ @SuppressWarnings("unchecked")
+ SupportsWriterState<InputT, Object> statefulSink =
+ (SupportsWriterState<InputT, Object>) sink;
+ SimpleVersionedSerializer<Object> serializer =
+ statefulSink.getWriterStateSerializer();
+
+ // Serialize each state and combine them
+ try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ final DataOutputStream dos = new DataOutputStream(baos)) {
+
+ // Write the number of states
+ dos.writeInt(writerStates.size());
+
+ // Write each state
+ for (Object state : writerStates) {
+ byte[] stateBytes = serializer.serialize(state);
+ dos.writeInt(stateBytes.length);
+ dos.write(stateBytes);
+ }
+
+ dos.flush();
+ return baos.toByteArray();
+ }
+ } catch (Exception e) {
+ LOG.warn(
+ "Failed to serialize state using sink serializer for route: {}, "
+ + "falling back to Java serialization",
+ route,
+ e);
+ }
+ }
+
+ // Fallback to Java serialization
+ try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ final ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+ oos.writeObject(writerStates);
+ oos.flush();
+ return baos.toByteArray();
+ } catch (Exception e) {
+ throw new IOException("Failed to serialize writer states for route: " + route, e);
+ }
+ }
+
+ /**
+ * Deserializes the writer states for a given route.
+ *
+ * <p>This method attempts to deserialize states using Java serialization as a fallback. The
+ * proper sink-specific deserialization will happen later when the sink is created and we can
+ * access its state serializer.
+ *
+ * @param route The route key
+ * @param stateBytes The serialized state bytes
+ * @return The deserialized writer states (or raw bytes to be deserialized later)
+ */
+ private List<Object> deserializeWriterStates(RouteT route, byte[] stateBytes) {
+ if (stateBytes == null || stateBytes.length == 0) {
+ return new ArrayList<>();
+ }
+
+ // Try Java deserialization first (this handles the fallback case)
+ try (final ByteArrayInputStream bais = new ByteArrayInputStream(stateBytes);
+ final ObjectInputStream ois = new ObjectInputStream(bais)) {
+ @SuppressWarnings("unchecked")
+ List<Object> states = (List<Object>) ois.readObject();
+ LOG.debug(
+ "Successfully deserialized {} states for route {} using Java serialization",
+ states.size(),
+ route);
+ return states;
+ } catch (Exception e) {
+ LOG.debug(
+ "Java deserialization failed for route {}, will store raw bytes for later processing",
+ route);
+ // If Java deserialization fails, store the raw bytes
+ // They will be processed when we have access to the sink's serializer
+ List<Object> rawStates = new ArrayList<>();
+ rawStates.add(stateBytes); // Store as raw bytes
+ return rawStates;
+ }
+ }
+
+ /**
+ * Gets the number of currently active sink writers.
+ *
+ * @return The number of active sink writers
+ */
+ public int getActiveSinkWriterCount() {
+ return sinkWriters.size();
+ }
+
+ /**
+ * Gets the routes for all currently active sink writers.
+ *
+ * @return A collection of route keys for active sink writers
+ */
+ public Collection<RouteT> getActiveRoutes() {
+ return new ArrayList<>(sinkWriters.keySet());
+ }
+}
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java
new file mode 100644
index 0000000000000..10837e628bb53
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.connector.sink2.Sink;
+
+import java.io.Serializable;
+
+/**
+ * Interface for routing elements to different sinks in a {@link DemultiplexingSink}.
+ *
+ * <p>The router is responsible for two key operations:
+ *
+ * <ul>
+ *   <li>Extracting a route key from each input element that determines which sink to use
+ *   <li>Creating new sink instances when a previously unseen route is encountered
+ * </ul>
+ *
+ * <p>Route keys should be deterministic and consistent - the same logical destination should always
+ * produce the same route key to ensure proper sink reuse and state management.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * // Route messages to different Kafka topics based on message type
+ * SinkRouter<MyMessage, String> router = new SinkRouter<MyMessage, String>() {
+ * @Override
+ * public String getRoute(MyMessage element) {
+ * return element.getMessageType(); // e.g., "orders", "users", "events"
+ * }
+ *
+ * @Override
+ * public Sink<MyMessage> createSink(String route, MyMessage element) {
+ * return KafkaSink.builder()
+ * .setBootstrapServers("localhost:9092")
+ * .setRecordSerializer(...)
+ * .setTopics(route) // Use route as topic name
+ * .build();
+ * }
+ * };
+ * }</pre>
+ *
+ * @param <InputT> The type of input elements to route
+ * @param <RouteT> The type of route keys used for sink selection and caching
+ */
+@PublicEvolving
+public interface SinkRouter<InputT, RouteT> extends Serializable {
+
+ /**
+ * Extract the route key from an input element.
+ *
+ * <p>This method is called for every element and should be efficient. The returned route key is
+ * used to:
+ *
+ * <ul>
+ *   <li>Look up the appropriate sink instance in the cache
+ *   <li>Create a new sink if this route hasn't been seen before
+ *   <li>Group elements by route for state management during checkpointing
+ * </ul>
+ *
+ * <p>Route keys must implement {@link Object#equals(Object)} and {@link Object#hashCode()}
+ * properly as they are used as keys in hash-based collections.
+ *
+ * @param element The input element to route
+ * @return The route key that determines which sink to use for this element
+ */
+ RouteT getRoute(InputT element);
+
+ /**
+ * Create a new sink instance for the given route.
+ *
+ *
+ * <p>This method is called when a route is encountered for the first time. The created sink
+ * will be cached and reused for all subsequent elements with the same route key.
+ *
+ *
+ * <p>The element parameter provides access to the specific element that triggered the creation
+ * of this route, which can be useful for extracting configuration information (e.g., connection
+ * details, authentication credentials) that may be embedded in the element.
+ *
+ *
+ * <p>The created sink should be fully configured and ready to use. It will be initialized by
+ * the DemultiplexingSink framework using the standard Sink API.
+ *
+ * @param route The route key for which to create a sink
+ * @param element The element that triggered the creation of this route (for configuration
+ * extraction)
+ * @return A new sink instance configured for the given route
+ */
+ Sink<InputT> createSink(RouteT route, InputT element);
+}
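
Where a single field is not enough to identify a destination, a small Serializable value object with explicit equals()/hashCode() works as the route key, since keys are used in hash maps and are Java-serialized into checkpoint state. The sketch below is illustrative: the Event type is hypothetical and the body of createSink (building the actual per-destination sink) is left out.

```java
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.connector.base.sink.SinkRouter;

import java.io.Serializable;
import java.util.Objects;

public class ClusterRouteExample {

    /** Composite route key: one cached sink per (host, port) pair. */
    public static final class ClusterRoute implements Serializable {
        private static final long serialVersionUID = 1L;

        private final String host;
        private final int port;

        public ClusterRoute(String host, int port) {
            this.host = host;
            this.port = port;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ClusterRoute)) {
                return false;
            }
            ClusterRoute that = (ClusterRoute) o;
            return port == that.port && Objects.equals(host, that.host);
        }

        @Override
        public int hashCode() {
            return Objects.hash(host, port);
        }
    }

    /** Hypothetical input element carrying its destination; not part of this patch. */
    public static final class Event implements Serializable {
        private static final long serialVersionUID = 1L;

        public String host;
        public int port;
        public String payload;
    }

    public static SinkRouter<Event, ClusterRoute> clusterRouter() {
        return new SinkRouter<Event, ClusterRoute>() {
            @Override
            public ClusterRoute getRoute(Event element) {
                // The same destination always yields an equal key, so the cached sink is reused.
                return new ClusterRoute(element.host, element.port);
            }

            @Override
            public Sink<Event> createSink(ClusterRoute route, Event element) {
                // Build the real per-destination sink here (e.g. a Kafka or Elasticsearch sink
                // configured from the route); intentionally left out of this sketch.
                throw new UnsupportedOperationException(
                        "plug in a sink for " + route.host + ":" + route.port);
            }
        };
    }
}
```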
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java
new file mode 100644
index 0000000000000..d1755518d61dd
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.test.junit5.MiniClusterExtension;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Integration tests for {@link DemultiplexingSink}. */
+class DemultiplexingSinkIT {
+
+ @RegisterExtension
+ private static final MiniClusterExtension MINI_CLUSTER_RESOURCE =
+ new MiniClusterExtension(
+ new MiniClusterResourceConfiguration.Builder()
+ .setNumberTaskManagers(1)
+ .setNumberSlotsPerTaskManager(1)
+ .build());
+
+ final StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment().setParallelism(1);
+
+ @Test
+ void testBasicElementRouting() throws Exception {
+ // Create a simple router that routes by first character
+ final TestSinkRouter router = new TestSinkRouter();
+
+ // Create the demultiplexing sink
+ final DemultiplexingSink<String, String> demuxSink = new DemultiplexingSink<>(router);
+
+ // Create a data stream with elements that will route to different sinks
+ env.fromData("apple", "banana", "apricot", "blueberry", "avocado", "cherry")
+ .sinkTo(demuxSink);
+
+ // Execute the job
+ env.execute("Integration Test: DemultiplexingSinkIT");
+
+ // Verify results (should contain three routes: "a", "b", "c")
+ ConcurrentMap<String, List<String>> results = TestSinkRouter.getResults();
+ assertThat(results).hasSize(3);
+ assertThat(results.get("a")).containsExactlyInAnyOrder("apple", "apricot", "avocado");
+ assertThat(results.get("b")).containsExactlyInAnyOrder("banana", "blueberry");
+ assertThat(results.get("c")).containsExactlyInAnyOrder("cherry");
+ }
+
+ /** A serializable test router that routes by first character. */
+ private static class TestSinkRouter implements SinkRouter<String, String> {
+ private static final long serialVersionUID = 1L;
+
+ // Static shared results map for collecting test results
+ private static final ConcurrentMap<String, List<String>> results =
+ new ConcurrentHashMap<>();
+
+ @Override
+ public String getRoute(String element) {
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public org.apache.flink.api.connector.sink2.Sink<String> createSink(
+ String route, String element) {
+ return new CollectingSink(route, results);
+ }
+
+ public static ConcurrentMap<String, List<String>> getResults() {
+ return results;
+ }
+
+ public static void clearResults() {
+ results.clear();
+ }
+ }
+
+ /** A simple collecting sink for testing. */
+ private static class CollectingSink
+ implements org.apache.flink.api.connector.sink2.Sink<String> {
+ private static final long serialVersionUID = 1L;
+
+ private final String route;
+ private final ConcurrentMap<String, List<String>> results;
+
+ public CollectingSink(String route, ConcurrentMap<String, List<String>> results) {
+ this.route = route;
+ this.results = results;
+ }
+
+ @Override
+ public org.apache.flink.api.connector.sink2.SinkWriter<String> createWriter(
+ org.apache.flink.api.connector.sink2.WriterInitContext context) {
+ return new CollectingSinkWriter(route, results);
+ }
+ }
+
+ /** A simple collecting sink writer for testing. */
+ private static class CollectingSinkWriter implements SinkWriter<String>, Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private final String route;
+ private final ConcurrentMap<String, List<String>> results;
+
+ public CollectingSinkWriter(String route, ConcurrentMap<String, List<String>> results) {
+ this.route = route;
+ this.results = results;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ results.computeIfAbsent(route, k -> new ArrayList<>()).add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ // No-op for this test
+ }
+
+ @Override
+ public void close() {
+ // No-op for this test
+ }
+ }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java
new file mode 100644
index 0000000000000..7dafd2e44ce63
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.StatefulSinkWriter;
+import org.apache.flink.api.connector.sink2.SupportsWriterState;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
+import org.apache.flink.connector.base.sink.writer.TestSinkInitContext;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for state management in {@link DemultiplexingSinkWriter}. */
+class DemultiplexingSinkStateManagementTest {
+
+ private TestStatefulSinkRouter router;
+ private TestSinkInitContext context;
+
+ @BeforeEach
+ void setUp() {
+ router = new TestStatefulSinkRouter();
+ context = new TestSinkInitContext();
+ }
+
+ @Test
+ void testSnapshotAndRestoreState() throws Exception {
+ // Create writer and write to multiple routes
+ DemultiplexingSinkWriter<String, String> writer =
+ new DemultiplexingSinkWriter<>(router, context);
+
+ // Write elements to create sink writers ("apple" and "apricot" resolve to same sink)
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+ writer.write("apricot", createContext());
+
+ // Verify sink writers were created
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(2);
+ assertThat(writer.getActiveRoutes()).containsExactlyInAnyOrder("a", "b");
+
+ // Add some state to the sink writers
+ TestStatefulSinkWriter writerA = router.getStatefulSinkWriter("a");
+ TestStatefulSinkWriter writerB = router.getStatefulSinkWriter("b");
+ writerA.addState("state-a-1");
+ writerA.addState("state-a-2");
+ writerB.addState("state-b-1");
+
+ // Verify state was added
+ assertThat(writerA.getStates()).containsExactly("state-a-1", "state-a-2");
+ assertThat(writerB.getStates()).containsExactly("state-b-1");
+
+ // Snapshot state
+ List<DemultiplexingSinkState<String>> snapshotStates = writer.snapshotState(1L);
+ assertThat(snapshotStates).hasSize(1);
+
+ DemultiplexingSinkState<String> state = snapshotStates.get(0);
+ assertThat(state.getRoutes()).containsExactlyInAnyOrder("a", "b");
+
+ // Close the original writer
+ writer.close();
+
+ // Create a new router for the restored writer to avoid confusion with old sinks
+ TestStatefulSinkRouter restoredRouter = new TestStatefulSinkRouter();
+
+ // Create a new writer with restored state
+ DemultiplexingSinkWriter<String, String> restoredWriter =
+ new DemultiplexingSinkWriter<>(restoredRouter, context, snapshotStates);
+
+ // Write to the same routes to trigger restoration ("avocado" -> "a", "blueberry" -> "b")
+ restoredWriter.write("avocado", createContext());
+ restoredWriter.write("blueberry", createContext());
+
+ // Verify that state was restored
+ TestStatefulSinkWriter restoredWriterA = restoredRouter.getStatefulSinkWriter("a");
+ TestStatefulSinkWriter restoredWriterB = restoredRouter.getStatefulSinkWriter("b");
+
+ // Verify that the writers were restored from state
+ assertThat(restoredWriterA.wasRestoredFromState()).isTrue();
+ assertThat(restoredWriterB.wasRestoredFromState()).isTrue();
+
+ // Verify that the restored state contains the expected data
+ assertThat(restoredWriterA.getStates()).containsExactly("state-a-1", "state-a-2");
+ assertThat(restoredWriterB.getStates()).containsExactly("state-b-1");
+
+ restoredWriter.close();
+ }
+
+ @Test
+ void testSnapshotEmptyState() throws Exception {
+ DemultiplexingSinkWriter<String, String> writer =
+ new DemultiplexingSinkWriter<>(router, context);
+
+ // Snapshot without any active writers
+ List<DemultiplexingSinkState<String>> states = writer.snapshotState(1L);
+ assertThat(states).isEmpty();
+
+ writer.close();
+ }
+
+ @Test
+ void testRestoreWithNonStatefulSinks() throws Exception {
+ // Use a router that creates non-stateful sinks
+ TestNonStatefulSinkRouter nonStatefulRouter = new TestNonStatefulSinkRouter();
+
+ // Create some dummy state
+ DemultiplexingSinkState<String> dummyState = new DemultiplexingSinkState<>();
+ dummyState.setRouteState("a", new byte[] {1, 2, 3});
+
+ // Create writer with restored state
+ DemultiplexingSinkWriter<String, String> writer =
+ new DemultiplexingSinkWriter<>(nonStatefulRouter, context, List.of(dummyState));
+
+ // Write to trigger sink creation
+ writer.write("apple", createContext());
+
+ // Should work fine even though the sink doesn't support state
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(1);
+
+ writer.close();
+ }
+
+ private SinkWriter.Context createContext() {
+ return new SinkWriter.Context() {
+ @Override
+ public long currentWatermark() {
+ return 0;
+ }
+
+ @Override
+ public Long timestamp() {
+ return null;
+ }
+ };
+ }
+
+ /** Test router that creates stateful sinks. */
+ private static class TestStatefulSinkRouter implements SinkRouter<String, String> {
+ private final AtomicInteger sinkCreationCount = new AtomicInteger(0);
+ private final List<TestStatefulSink> createdSinks = new ArrayList<>();
+
+ @Override
+ public String getRoute(String element) {
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public Sink<String> createSink(String route, String element) {
+ sinkCreationCount.incrementAndGet();
+ TestStatefulSink sink = new TestStatefulSink(route);
+ createdSinks.add(sink);
+ return sink;
+ }
+
+ public TestStatefulSinkWriter getStatefulSinkWriter(String route) {
+ return createdSinks.stream()
+ .filter(sink -> sink.getRoute().equals(route))
+ .findFirst()
+ .map(TestStatefulSink::getCreatedWriter)
+ .orElse(null);
+ }
+ }
+
+ /** Test router that creates non-stateful sinks. */
+ private static class TestNonStatefulSinkRouter implements SinkRouter<String, String> {
+ @Override
+ public String getRoute(String element) {
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public Sink<String> createSink(String route, String element) {
+ return new TestNonStatefulSink(route);
+ }
+ }
+
+ /** Test stateful sink implementation. */
+ private static class TestStatefulSink
+ implements Sink<String>, SupportsWriterState<String, String> {
+ private final String route;
+ private TestStatefulSinkWriter createdWriter;
+
+ public TestStatefulSink(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public SinkWriter<String> createWriter(WriterInitContext context) {
+ createdWriter = new TestStatefulSinkWriter(route);
+ return createdWriter;
+ }
+
+ @Override
+ public StatefulSinkWriter<String, String> restoreWriter(
+ WriterInitContext context, Collection<String> recoveredState) {
+ createdWriter = new TestStatefulSinkWriter(route, recoveredState);
+ return createdWriter;
+ }
+
+ @Override
+ public SimpleVersionedSerializer<String> getWriterStateSerializer() {
+ return new TestStringSerializer();
+ }
+
+ public String getRoute() {
+ return route;
+ }
+
+ public TestStatefulSinkWriter getCreatedWriter() {
+ return createdWriter;
+ }
+ }
+
+ /** Test non-stateful sink implementation. */
+ private static class TestNonStatefulSink implements Sink<String> {
+ private final String route;
+
+ public TestNonStatefulSink(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public SinkWriter<String> createWriter(WriterInitContext context) {
+ return new TestNonStatefulSinkWriter(route);
+ }
+ }
+
+ /** Test stateful sink writer. */
+ private static class TestStatefulSinkWriter implements StatefulSinkWriter<String, String> {
+ private final String route;
+ private final List<String> elements = new ArrayList<>();
+ private final List<String> states = new ArrayList<>();
+ private final boolean restoredFromState;
+
+ public TestStatefulSinkWriter(String route) {
+ this.route = route;
+ this.restoredFromState = false;
+ }
+
+ public TestStatefulSinkWriter(String route, Collection<String> recoveredStates) {
+ this.route = route;
+ this.states.addAll(recoveredStates);
+ this.restoredFromState = true;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ elements.add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ // No-op
+ }
+
+ @Override
+ public void close() {
+ // No-op
+ }
+
+ @Override
+ public List<String> snapshotState(long checkpointId) {
+ return new ArrayList<>(states);
+ }
+
+ public void addState(String state) {
+ states.add(state);
+ }
+
+ public boolean wasRestoredFromState() {
+ return restoredFromState;
+ }
+
+ public List<String> getElements() {
+ return new ArrayList<>(elements);
+ }
+
+ public List<String> getStates() {
+ return new ArrayList<>(states);
+ }
+ }
+
+ /** Test non-stateful sink writer. */
+ private static class TestNonStatefulSinkWriter implements SinkWriter<String> {
+ private final String route;
+ private final List<String> elements = new ArrayList<>();
+
+ public TestNonStatefulSinkWriter(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ elements.add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ // No-op
+ }
+
+ @Override
+ public void close() {
+ // No-op
+ }
+
+ public List<String> getElements() {
+ return new ArrayList<>(elements);
+ }
+ }
+
+ /** Simple string serializer for testing. */
+ private static class TestStringSerializer implements SimpleVersionedSerializer<String> {
+ @Override
+ public int getVersion() {
+ return 1;
+ }
+
+ @Override
+ public byte[] serialize(String obj) {
+ return obj.getBytes();
+ }
+
+ @Override
+ public String deserialize(int version, byte[] serialized) {
+ return new String(serialized);
+ }
+ }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java
new file mode 100644
index 0000000000000..c1eb9ea23d848
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link DemultiplexingSinkStateSerializer}. */
+class DemultiplexingSinkStateSerializerTest {
+
+ @Test
+ void testSerializeDeserializeEmptyState() throws IOException {
+ final DemultiplexingSinkStateSerializer<String> serializer =
+ new DemultiplexingSinkStateSerializer<>();
+ final DemultiplexingSinkState<String> originalState = new DemultiplexingSinkState<>();
+
+ final byte[] serialized = serializer.serialize(originalState);
+ final DemultiplexingSinkState<String> deserializedState =
+ serializer.deserialize(serializer.getVersion(), serialized);
+
+ assertThat(deserializedState).isEqualTo(originalState);
+ assertThat(deserializedState.isEmpty()).isTrue();
+ }
+
+ @Test
+ void testSerializeDeserializeStateWithRoutes() throws IOException {
+ final DemultiplexingSinkStateSerializer<String> serializer =
+ new DemultiplexingSinkStateSerializer<>();
+
+ // Create state with multiple routes ("route3" -> empty state)
+ final Map<String, byte[]> routeStates = new HashMap<>();
+ routeStates.put("route1", new byte[] {1, 2, 3});
+ routeStates.put("route2", new byte[] {4, 5, 6, 7});
+ routeStates.put("route3", new byte[0]);
+
+ final DemultiplexingSinkState<String> originalState =
+ new DemultiplexingSinkState<>(routeStates);
+
+ final byte[] serialized = serializer.serialize(originalState);
+ final DemultiplexingSinkState<String> deserializedState =
+ serializer.deserialize(serializer.getVersion(), serialized);
+
+ assertThat(deserializedState).isEqualTo(originalState);
+ assertThat(deserializedState.getRoutes())
+ .containsExactlyInAnyOrder("route1", "route2", "route3");
+ assertThat(deserializedState.getRouteState("route1")).containsExactly(1, 2, 3);
+ assertThat(deserializedState.getRouteState("route2")).containsExactly(4, 5, 6, 7);
+ assertThat(deserializedState.getRouteState("route3")).isEmpty();
+ }
+
+ @Test
+ void testSerializeDeserializeWithComplexRouteKeys() throws IOException {
+ final DemultiplexingSinkStateSerializer<ComplexRouteKey> serializer =
+ new DemultiplexingSinkStateSerializer<>();
+
+ // Create state with complex route keys
+ final Map<ComplexRouteKey, byte[]> routeStates = new HashMap<>();
+ routeStates.put(new ComplexRouteKey("cluster1", 9092), new byte[] {1, 2});
+ routeStates.put(new ComplexRouteKey("cluster2", 9093), new byte[] {3, 4});
+
+ final DemultiplexingSinkState<ComplexRouteKey> originalState =
+ new DemultiplexingSinkState<>(routeStates);
+
+ final byte[] serialized = serializer.serialize(originalState);
+ final DemultiplexingSinkState<ComplexRouteKey> deserializedState =
+ serializer.deserialize(serializer.getVersion(), serialized);
+
+ assertThat(deserializedState).isEqualTo(originalState);
+ assertThat(deserializedState.size()).isEqualTo(2);
+ }
+
+ @Test
+ void testDeserializeWithWrongVersion() {
+ final DemultiplexingSinkStateSerializer<String> serializer =
+ new DemultiplexingSinkStateSerializer<>();
+ final byte[] serialized = new byte[] {1, 2, 3, 4};
+
+ assertThatThrownBy(() -> serializer.deserialize(999, serialized))
+ .isInstanceOf(IOException.class)
+ .hasMessageContaining("Unsupported version: 999");
+ }
+
+ @Test
+ void testGetVersion() {
+ final DemultiplexingSinkStateSerializer<String> serializer =
+ new DemultiplexingSinkStateSerializer<>();
+
+ assertThat(serializer.getVersion()).isEqualTo(1);
+ }
+
+ @Test
+ void testSetRouteStateWithNullState() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+ // Add a route first
+ state.setRouteState("route1", new byte[] {1, 2, 3});
+ assertThat(state.getRoutes()).containsExactly("route1");
+
+ // Setting null state should remove the route
+ state.setRouteState("route1", null);
+ assertThat(state.getRoutes()).isEmpty();
+ assertThat(state.getRouteState("route1")).isNull();
+ }
+
+ @Test
+ void testSetRouteStateWithNullStateOnNonExistentRoute() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+ // Setting null state on a non-existent route should be a no-op
+ state.setRouteState("nonExistent", null);
+ assertThat(state.getRoutes()).isEmpty();
+ }
+
+ @Test
+ void testEqualsWithSameInstance() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+ state.setRouteState("route1", new byte[] {1, 2, 3});
+
+ assertThat(state.equals(state)).isTrue();
+ }
+
+ @Test
+ void testEqualsWithNull() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+ assertThat(state.equals(null)).isFalse();
+ }
+
+ @Test
+ void testEqualsWithDifferentClass() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+ assertThat(state.equals("not a state")).isFalse();
+ }
+
+ @Test
+ void testEqualsWithDifferentRouteSizes() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("route1", new byte[] {1, 2, 3});
+ state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+ assertThat(state1.equals(state2)).isFalse();
+ assertThat(state2.equals(state1)).isFalse();
+ }
+
+ @Test
+ void testEqualsWithSameRoutesButDifferentStates() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("route1", new byte[] {4, 5, 6});
+
+ assertThat(state1.equals(state2)).isFalse();
+ }
+
+ @Test
+ void testEqualsWithSameRoutesAndStates() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("route1", new byte[] {1, 2, 3});
+ state1.setRouteState("route2", new byte[] {4, 5, 6});
+
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("route1", new byte[] {1, 2, 3});
+ state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+ assertThat(state1.equals(state2)).isTrue();
+ assertThat(state2.equals(state1)).isTrue();
+ }
+
+ @Test
+ void testHashCodeConsistency() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("route1", new byte[] {1, 2, 3});
+ state1.setRouteState("route2", new byte[] {4, 5, 6});
+
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("route1", new byte[] {1, 2, 3});
+ state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+ // Equal objects must have equal hash codes
+ assertThat(state1.equals(state2)).isTrue();
+ assertThat(state1.hashCode()).isEqualTo(state2.hashCode());
+ }
+
+ @Test
+ void testHashCodeWithEmptyState() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+
+ assertThat(state1.hashCode()).isEqualTo(state2.hashCode());
+ }
+
+ @Test
+ void testHashCodeWithDifferentStates() {
+ DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+ DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("route1", new byte[] {4, 5, 6});
+
+ // Not strictly required by the hashCode contract, but states carrying different route
+ // state are expected to hash differently
+ assertThat(state1.hashCode()).isNotEqualTo(state2.hashCode());
+ }
+
+ @Test
+ void testToStringWithEmptyState() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+ String toString = state.toString();
+ assertThat(toString).contains("DemultiplexingSinkState");
+ assertThat(toString).contains("routeCount=0");
+ assertThat(toString).contains("routes=[]");
+ }
+
+ @Test
+ void testToStringWithSingleRoute() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+ state.setRouteState("route1", new byte[] {1, 2, 3});
+
+ String toString = state.toString();
+ assertThat(toString).contains("DemultiplexingSinkState");
+ assertThat(toString).contains("routeCount=1");
+ assertThat(toString).contains("route1");
+ }
+
+ @Test
+ void testToStringWithMultipleRoutes() {
+ DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+ state.setRouteState("route1", new byte[] {1, 2, 3});
+ state.setRouteState("route2", new byte[] {4, 5, 6});
+
+ String toString = state.toString();
+ assertThat(toString).contains("DemultiplexingSinkState");
+ assertThat(toString).contains("routeCount=2");
+ assertThat(toString).contains("route1");
+ assertThat(toString).contains("route2");
+ }
+
+ /** A composite route key used to test serialization of non-String route keys. */
+ private static class ComplexRouteKey implements java.io.Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private final String host;
+ private final int port;
+
+ public ComplexRouteKey(String host, int port) {
+ this.host = host;
+ this.port = port;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ComplexRouteKey that = (ComplexRouteKey) o;
+ return port == that.port && java.util.Objects.equals(host, that.host);
+ }
+
+ @Override
+ public int hashCode() {
+ return java.util.Objects.hash(host, port);
+ }
+
+ @Override
+ public String toString() {
+ return "ComplexRouteKey{host='" + host + "', port=" + port + '}';
+ }
+ }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java
new file mode 100644
index 0000000000000..33e956588416d
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.connector.base.sink.writer.TestSinkInitContext;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link DemultiplexingSink}. */
+class DemultiplexingSinkTest {
+
+ @Test
+ void testSinkCreation() {
+ final TestSinkRouter router = new TestSinkRouter();
+ final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+
+ assertThat(sink.getSinkRouter()).isSameAs(router);
+ }
+
+ @Test
+ void testSinkCreationWithNullRouter() {
+ assertThatThrownBy(() -> new DemultiplexingSink<String, String>(null))
+ .isInstanceOf(NullPointerException.class)
+ .hasMessageContaining("sinkRouter must not be null");
+ }
+
+ @Test
+ void testCreateWriter() throws IOException {
+ final TestSinkRouter router = new TestSinkRouter();
+ final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+ final TestSinkInitContext context = new TestSinkInitContext();
+
+ final SinkWriter<String> writer = sink.createWriter(context);
+
+ assertThat(writer).isInstanceOf(DemultiplexingSinkWriter.class);
+ }
+
+ @Test
+ void testWriterStateSerializer() {
+ final TestSinkRouter router = new TestSinkRouter();
+ final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+
+ assertThat(sink.getWriterStateSerializer()).isNotNull();
+ assertThat(sink.getWriterStateSerializer())
+ .isInstanceOf(DemultiplexingSinkStateSerializer.class);
+ }
+
+ @Test
+ void testRestoreWriter() {
+ final TestSinkRouter router = new TestSinkRouter();
+ final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+ final TestSinkInitContext context = new TestSinkInitContext();
+
+ // Create some state to restore from
+ final DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+ state1.setRouteState("a", new byte[] {1, 2, 3});
+ state1.setRouteState("b", new byte[] {4, 5, 6});
+
+ final DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+ state2.setRouteState("c", new byte[] {7, 8, 9});
+
+ final List<DemultiplexingSinkState<String>> recoveredStates =
+ java.util.Arrays.asList(state1, state2);
+
+ // Restore the writer with the states
+ final var restoredWriter = sink.restoreWriter(context, recoveredStates);
+
+ // Verify that the restored writer is the correct type
+ assertThat(restoredWriter).isInstanceOf(DemultiplexingSinkWriter.class);
+
+ // Verify that the writer was created successfully
+ assertThat(restoredWriter).isNotNull();
+ }
+
+ @Test
+ void testRestoreWriterWithEmptyState() {
+ final TestSinkRouter router = new TestSinkRouter();
+ final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+ final TestSinkInitContext context = new TestSinkInitContext();
+
+ // Create an empty state
+ final DemultiplexingSinkState<String> emptyState = new DemultiplexingSinkState<>();
+ final List<DemultiplexingSinkState<String>> recoveredStates = List.of(emptyState);
+
+ // Restore the writer with empty state
+ final var restoredWriter = sink.restoreWriter(context, recoveredStates);
+
+ // Verify that the restored writer is created successfully even with empty state
+ assertThat(restoredWriter).isNotNull();
+ assertThat(restoredWriter).isInstanceOf(DemultiplexingSinkWriter.class);
+ }
+
+ /** Test implementation of {@link SinkRouter}. */
+ private static class TestSinkRouter implements SinkRouter<String, String> {
+ private final AtomicInteger sinkCreationCount = new AtomicInteger(0);
+
+ @Override
+ public String getRoute(String element) {
+ // Route based on first character
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public Sink<String> createSink(String route, String element) {
+ sinkCreationCount.incrementAndGet();
+ return new TestSink(route);
+ }
+
+ public int getSinkCreationCount() {
+ return sinkCreationCount.get();
+ }
+ }
+
+ /** Test implementation of {@link Sink}. */
+ private static class TestSink implements Sink<String> {
+ private final String route;
+
+ public TestSink(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public SinkWriter<String> createWriter(
+ org.apache.flink.api.connector.sink2.WriterInitContext context) {
+ return new TestSinkWriter(route);
+ }
+
+ public String getRoute() {
+ return route;
+ }
+ }
+
+ /** Test implementation of {@link SinkWriter}. */
+ private static class TestSinkWriter implements SinkWriter<String> {
+ private final String route;
+ private final List<String> elements = new ArrayList<>();
+ private boolean closed = false;
+
+ public TestSinkWriter(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ if (closed) {
+ throw new IllegalStateException("Writer is closed");
+ }
+ elements.add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ // No-op for test
+ }
+
+ @Override
+ public void close() {
+ closed = true;
+ }
+
+ public List<String> getElements() {
+ return new ArrayList<>(elements);
+ }
+
+ public String getRoute() {
+ return route;
+ }
+
+ public boolean isClosed() {
+ return closed;
+ }
+ }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java
new file mode 100644
index 0000000000000..54a03f72c9b3e
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.apache.flink.api.common.eventtime.Watermark;
+import org.apache.flink.api.connector.sink2.Sink;
+import org.apache.flink.api.connector.sink2.SinkWriter;
+import org.apache.flink.api.connector.sink2.StatefulSinkWriter;
+import org.apache.flink.api.connector.sink2.SupportsWriterState;
+import org.apache.flink.api.connector.sink2.WriterInitContext;
+import org.apache.flink.connector.base.sink.writer.TestSinkInitContext;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link DemultiplexingSinkWriter}. */
+class DemultiplexingSinkWriterTest {
+
+ private TestSinkRouter router;
+ private TestSinkInitContext context;
+ private DemultiplexingSinkWriter<String, String> writer;
+
+ @BeforeEach
+ void setUp() {
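+ // A fresh router, init context, and writer are created for every test so that routing
+ // and element state cannot leak between test cases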
+ router = new TestSinkRouter();
+ context = new TestSinkInitContext();
+ writer = new DemultiplexingSinkWriter<>(router, context);
+ }
+
+ @Test
+ void testWriteToSingleRoute() throws IOException, InterruptedException {
+ // Write elements that all route to the same destination
+ writer.write("apple", createContext());
+ writer.write("avocado", createContext());
+ writer.write("apricot", createContext());
+
+ // Should have created only one sink
+ assertThat(router.getSinkCreationCount()).isEqualTo(1);
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(1);
+ assertThat(writer.getActiveRoutes()).containsExactly("a");
+
+ // All elements should be in the same sink writer
+ TestSinkWriter sinkWriter = router.getSinkWriter("a");
+ assertThat(sinkWriter.getElements()).containsExactly("apple", "avocado", "apricot");
+ }
+
+ @Test
+ void testWriteToMultipleRoutes() throws IOException, InterruptedException {
+ // Write elements that route to different destinations ("apple" and "apricot" share the
+ // same route)
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+ writer.write("cherry", createContext());
+ writer.write("apricot", createContext());
+
+ // Should have created three sinks (a, b, c)
+ assertThat(router.getSinkCreationCount()).isEqualTo(3);
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(3);
+ assertThat(writer.getActiveRoutes()).containsExactlyInAnyOrder("a", "b", "c");
+
+ // Check elements are routed correctly
+ assertThat(router.getSinkWriter("a").getElements()).containsExactly("apple", "apricot");
+ assertThat(router.getSinkWriter("b").getElements()).containsExactly("banana");
+ assertThat(router.getSinkWriter("c").getElements()).containsExactly("cherry");
+ }
+
+ @Test
+ void testFlush() throws IOException, InterruptedException {
+ // Write to multiple routes
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+
+ // Flush should be called on all sink writers
+ writer.flush(false);
+
+ // Verify flush was called (our test implementation tracks this)
+ assertThat(router.getSinkWriter("a").getFlushCount()).isEqualTo(1);
+ assertThat(router.getSinkWriter("b").getFlushCount()).isEqualTo(1);
+ }
+
+ @Test
+ void testWriteWatermark() throws IOException, InterruptedException {
+ // Write to multiple routes
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+
+ // Write watermark should be propagated to all sink writers
+ Watermark watermark = new Watermark(12345L);
+ writer.writeWatermark(watermark);
+
+ // Verify watermark was written (our test implementation tracks this)
+ assertThat(router.getSinkWriter("a").getWatermarksReceived()).containsExactly(watermark);
+ assertThat(router.getSinkWriter("b").getWatermarksReceived()).containsExactly(watermark);
+ }
+
+ @Test
+ void testClose() throws Exception {
+ // Write to multiple routes
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+
+ // Close should close all sink writers
+ writer.close();
+
+ // Verify all writers were closed
+ assertThat(router.getSinkWriter("a").isClosed()).isTrue();
+ assertThat(router.getSinkWriter("b").isClosed()).isTrue();
+
+ // Active count should be zero after close
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(0);
+ }
+
+ @Test
+ void testSnapshotState() throws IOException, InterruptedException {
+ // Write to multiple routes
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+
+ // Snapshot state
+ List<DemultiplexingSinkState<String>> states = writer.snapshotState(1L);
+
+ // The plain TestSinkWriter keeps no checkpointable state, so the snapshot may be empty,
+ // but the call must still succeed and return a non-null list
+ assertThat(states).isNotNull();
+ }
+
+ @Test
+ void testJavaSerializationFallback() throws Exception {
+ // Create a router that creates sinks without proper state serializers
+ JavaSerializationFallbackRouter fallbackRouter = new JavaSerializationFallbackRouter();
+ DemultiplexingSinkWriter<String, String> writer =
+ new DemultiplexingSinkWriter<>(fallbackRouter, context);
+
+ // Write elements to create stateful sink writers
+ writer.write("apple", createContext());
+ writer.write("banana", createContext());
+
+ // Verify sink writers were created
+ assertThat(writer.getActiveSinkWriterCount()).isEqualTo(2);
+
+ // Look up the stateful sinks that were created for routes "a" and "b"
+ JavaSerializationFallbackSink sinkA = fallbackRouter.getCreatedSink("a");
+ JavaSerializationFallbackSink sinkB = fallbackRouter.getCreatedSink("b");
+
+ assertThat(sinkA).isNotNull();
+ assertThat(sinkB).isNotNull();
+
+ // The stateful writers should have received the elements
+ assertThat(sinkA.getCreatedWriter().getElements()).containsExactly("apple");
+ assertThat(sinkB.getCreatedWriter().getElements()).containsExactly("banana");
+
+ // Snapshot state - this should trigger Java serialization fallback
+ // since our test sink has a failing state serializer
+ List<DemultiplexingSinkState<String>> snapshotStates = writer.snapshotState(1L);
+
+ // Should have successfully created state using Java serialization fallback
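+ // The writer folds all per-route state into a single DemultiplexingSinkState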
+ assertThat(snapshotStates).hasSize(1);
+ DemultiplexingSinkState<String> state = snapshotStates.get(0);
+ assertThat(state.size()).isEqualTo(2);
+ assertThat(state.getRoutes()).containsExactlyInAnyOrder("a", "b");
+
+ writer.close();
+ }
+
+ /** Creates a minimal {@link SinkWriter.Context} with no timestamp and a zero watermark. */
+ private SinkWriter.Context createContext() {
+ return new SinkWriter.Context() {
+ @Override
+ public long currentWatermark() {
+ return 0;
+ }
+
+ @Override
+ public Long timestamp() {
+ return null;
+ }
+ };
+ }
+
+ /** Test implementation of {@link SinkRouter}. */
+ private static class TestSinkRouter implements SinkRouter<String, String> {
+ private final AtomicInteger sinkCreationCount = new AtomicInteger(0);
+ private final List<TestSink> createdSinks = new ArrayList<>();
+
+ @Override
+ public String getRoute(String element) {
+ // Route based on first character
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public Sink<String> createSink(String route, String element) {
+ sinkCreationCount.incrementAndGet();
+ TestSink sink = new TestSink(route);
+ createdSinks.add(sink);
+ return sink;
+ }
+
+ public int getSinkCreationCount() {
+ return sinkCreationCount.get();
+ }
+
+ public TestSinkWriter getSinkWriter(String route) {
+ return createdSinks.stream()
+ .filter(sink -> sink.getRoute().equals(route))
+ .findFirst()
+ .map(TestSink::getCreatedWriter)
+ .orElse(null);
+ }
+ }
+
+ /** Test implementation of {@link Sink}. */
+ private static class TestSink implements Sink<String> {
+ private final String route;
+ private TestSinkWriter createdWriter;
+
+ public TestSink(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public SinkWriter<String> createWriter(
+ org.apache.flink.api.connector.sink2.WriterInitContext context) {
+ createdWriter = new TestSinkWriter(route);
+ return createdWriter;
+ }
+
+ public String getRoute() {
+ return route;
+ }
+
+ public TestSinkWriter getCreatedWriter() {
+ return createdWriter;
+ }
+ }
+
+ /** Test implementation of {@link SinkWriter}. */
+ private static class TestSinkWriter implements SinkWriter<String> {
+ private final String route;
+ private final List<String> elements = new ArrayList<>();
+ private final List<Watermark> watermarksReceived = new ArrayList<>();
+ private int flushCount = 0;
+ private boolean closed = false;
+
+ public TestSinkWriter(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ if (closed) {
+ throw new IllegalStateException("Writer is closed");
+ }
+ elements.add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ flushCount++;
+ }
+
+ @Override
+ public void writeWatermark(Watermark watermark) {
+ watermarksReceived.add(watermark);
+ }
+
+ @Override
+ public void close() {
+ closed = true;
+ }
+
+ public List<String> getElements() {
+ return new ArrayList<>(elements);
+ }
+
+ public List<Watermark> getWatermarksReceived() {
+ return new ArrayList<>(watermarksReceived);
+ }
+
+ public int getFlushCount() {
+ return flushCount;
+ }
+
+ public String getRoute() {
+ return route;
+ }
+
+ public boolean isClosed() {
+ return closed;
+ }
+ }
+
+ /** Router whose sinks declare a failing state serializer, forcing the Java serialization fallback. */
+ private static class JavaSerializationFallbackRouter implements SinkRouter<String, String> {
+ private final java.util.Map<String, JavaSerializationFallbackSink> createdSinks =
+ new java.util.HashMap<>();
+
+ @Override
+ public String getRoute(String element) {
+ return element.substring(0, 1);
+ }
+
+ @Override
+ public Sink<String> createSink(String route, String element) {
+ JavaSerializationFallbackSink sink = new JavaSerializationFallbackSink(route);
+ createdSinks.put(route, sink);
+ return sink;
+ }
+
+ public JavaSerializationFallbackSink getCreatedSink(String route) {
+ return createdSinks.get(route);
+ }
+ }
+
+ /** Stateful test {@link Sink} whose writer-state serializer always fails. */
+ private static class JavaSerializationFallbackSink
+ implements Sink<String>, SupportsWriterState<String, String> {
+ private final String route;
+ private JavaSerializationFallbackWriter createdWriter;
+
+ public JavaSerializationFallbackSink(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public SinkWriter<String> createWriter(WriterInitContext context) {
+ createdWriter = new JavaSerializationFallbackWriter(route);
+ return createdWriter;
+ }
+
+ @Override
+ public StatefulSinkWriter<String, String> restoreWriter(
+ WriterInitContext context, Collection<String> recoveredState) {
+ createdWriter = new JavaSerializationFallbackWriter(route);
+ // Restore the state
+ for (String state : recoveredState) {
+ createdWriter.addElement(state);
+ }
+ return createdWriter;
+ }
+
+ @Override
+ public SimpleVersionedSerializer<String> getWriterStateSerializer() {
+ // Return a serializer that will fail, forcing fallback to Java serialization
+ return new SimpleVersionedSerializer<String>() {
+ @Override
+ public int getVersion() {
+ return 1;
+ }
+
+ @Override
+ public byte[] serialize(String obj) throws IOException {
+ throw new IOException("Intentional serialization failure to test fallback");
+ }
+
+ @Override
+ public String deserialize(int version, byte[] serialized) throws IOException {
+ throw new IOException("Intentional deserialization failure to test fallback");
+ }
+ };
+ }
+
+ public JavaSerializationFallbackWriter getCreatedWriter() {
+ return createdWriter;
+ }
+ }
+
+ /** Stateful test writer whose checkpointed state is the list of elements it has written. */
+ private static class JavaSerializationFallbackWriter
+ implements StatefulSinkWriter<String, String> {
+ private final String route;
+ private final List<String> elements = new ArrayList<>();
+
+ public JavaSerializationFallbackWriter(String route) {
+ this.route = route;
+ }
+
+ @Override
+ public void write(String element, Context context) {
+ elements.add(element);
+ }
+
+ @Override
+ public void flush(boolean endOfInput) {
+ // No-op for test
+ }
+
+ @Override
+ public void close() {
+ // No-op for test
+ }
+
+ @Override
+ public List<String> snapshotState(long checkpointId) {
+ // Return the elements as state (this will be Java-serialized as fallback)
+ return new ArrayList<>(elements);
+ }
+
+ public List<String> getElements() {
+ return new ArrayList<>(elements);
+ }
+
+ public void addElement(String element) {
+ elements.add(element);
+ }
+ }
+}