diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java new file mode 100644 index 0000000000000..cbcfd035f12c5 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSink.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.api.connector.sink2.StatefulSinkWriter; +import org.apache.flink.api.connector.sink2.SupportsWriterState; +import org.apache.flink.api.connector.sink2.WriterInitContext; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Collection; + +/** + * A sink that dynamically routes elements to different underlying sinks based on a routing + * function. + * + *

The {@code DemultiplexingSink} allows elements to be routed to different sink instances at + * runtime based on the content of each element. This is useful for scenarios such as: + * + *

+ * + *

The sink maintains an internal cache of sink instances, creating new sinks on-demand when + * previously unseen routes are encountered. This provides good performance while supporting dynamic + * routing scenarios. + * + *

Example usage: + * + *

{@code
+ * // Route messages to different Kafka topics
+ * SinkRouter<MyMessage, String> router = new SinkRouter<MyMessage, String>() {
+ *     @Override
+ *     public String getRoute(MyMessage element) {
+ *         return element.getTopicName();
+ *     }
+ *
+ *     @Override
+ *     public Sink<MyMessage> createSink(String topicName, MyMessage element) {
+ *         return KafkaSink.<MyMessage>builder()
+ *             .setBootstrapServers("localhost:9092")
+ *             .setRecordSerializer(...)
+ *             .setTopics(topicName)
+ *             .build();
+ *     }
+ * };
+ *
+ * DemultiplexingSink<MyMessage, String> demuxSink =
+ *     new DemultiplexingSink<>(router);
+ *
+ * dataStream.sinkTo(demuxSink);
+ * }
+ * + *

The sink supports checkpointing and recovery through the {@link SupportsWriterState}
+ * interface. State from all underlying sink writers is collected and restored appropriately during
+ * recovery.
+ *
+ * @param <InputT> The type of input elements
+ * @param <RouteT> The type of route keys used for sink selection
+ */
+@PublicEvolving
+public class DemultiplexingSink<InputT, RouteT>
+        implements Sink<InputT>,
+                SupportsWriterState<InputT, DemultiplexingSinkState<RouteT>> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The router that determines how elements are routed to sinks. */
+    private final SinkRouter<InputT, RouteT> sinkRouter;
+
+    /**
+     * Creates a new demultiplexing sink with the given router.
+     *
+     * @param sinkRouter The router that determines how elements are routed to different sinks
+     * @throws NullPointerException if {@code sinkRouter} is null
+     */
+    public DemultiplexingSink(SinkRouter<InputT, RouteT> sinkRouter) {
+        this.sinkRouter = Preconditions.checkNotNull(sinkRouter, "sinkRouter must not be null");
+    }
+
+    /**
+     * Creates a writer that starts without any recovered state.
+     *
+     * @param context The writer initialization context handed on to the per-route writers
+     * @return A new {@link DemultiplexingSinkWriter}
+     * @throws IOException declared by the {@link Sink} contract
+     */
+    @Override
+    public SinkWriter<InputT> createWriter(WriterInitContext context) throws IOException {
+        return new DemultiplexingSinkWriter<>(sinkRouter, context);
+    }
+
+    /**
+     * Creates a writer that is seeded with the state recovered from a checkpoint.
+     *
+     * @param context The writer initialization context handed on to the per-route writers
+     * @param recoveredState The demultiplexing states recovered from the checkpoint
+     * @return A new {@link DemultiplexingSinkWriter} holding the recovered state
+     */
+    @Override
+    public StatefulSinkWriter<InputT, DemultiplexingSinkState<RouteT>> restoreWriter(
+            WriterInitContext context,
+            Collection<DemultiplexingSinkState<RouteT>> recoveredState) {
+
+        return new DemultiplexingSinkWriter<>(sinkRouter, context, recoveredState);
+    }
+
+    /**
+     * Returns a new serializer for the writer state of this sink.
+     *
+     * @return A {@link DemultiplexingSinkStateSerializer}
+     */
+    @Override
+    public SimpleVersionedSerializer<DemultiplexingSinkState<RouteT>> getWriterStateSerializer() {
+        return new DemultiplexingSinkStateSerializer<>();
+    }
+
+    /**
+     * Gets the sink router used by this demultiplexing sink.
+     *
+     * @return The sink router
+     */
+    public SinkRouter<InputT, RouteT> getSinkRouter() {
+        return sinkRouter;
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java new file mode 100644 index 0000000000000..c35b3ce9d0801 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkState.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * State class for {@link DemultiplexingSink} that tracks the state of individual sink writers for + * each route during checkpointing and recovery. + * + *

This state contains: + * + *

+ *
+ * @param <RouteT> The type of route keys
+ */
+@PublicEvolving
+public class DemultiplexingSinkState<RouteT> implements Serializable {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Map of route keys to their serialized sink writer states; never holds null values. */
+    private final Map<RouteT, byte[]> routeStates;
+
+    /** Creates a new empty demultiplexing sink state. */
+    public DemultiplexingSinkState() {
+        this.routeStates = new HashMap<>();
+    }
+
+    /**
+     * Creates a new demultiplexing sink state with the given route states.
+     *
+     * <p>The map itself is copied; the byte arrays are not. Entries whose value is {@code null}
+     * are dropped, mirroring {@link #setRouteState(Object, byte[])} where a {@code null} state
+     * means "no state for this route".
+     *
+     * @param routeStates Map of route keys to their serialized sink writer states
+     * @throws NullPointerException if {@code routeStates} is null
+     */
+    public DemultiplexingSinkState(Map<RouteT, byte[]> routeStates) {
+        this.routeStates = new HashMap<>(Preconditions.checkNotNull(routeStates));
+        // Keep the "no null values" invariant that setRouteState() enforces.
+        this.routeStates.values().removeIf(Objects::isNull);
+    }
+
+    /**
+     * Gets the serialized state for a specific route.
+     *
+     * @param route The route key
+     * @return The serialized state for the route, or null if no state exists
+     */
+    public byte[] getRouteState(RouteT route) {
+        return routeStates.get(route);
+    }
+
+    /**
+     * Sets the serialized state for a specific route.
+     *
+     * @param route The route key
+     * @param state The serialized state data; {@code null} removes the route
+     */
+    public void setRouteState(RouteT route, byte[] state) {
+        if (state != null) {
+            routeStates.put(route, state);
+        } else {
+            routeStates.remove(route);
+        }
+    }
+
+    /**
+     * Gets all route keys that have associated state.
+     *
+     * @return An unmodifiable view of the route keys
+     */
+    public java.util.Set<RouteT> getRoutes() {
+        return Collections.unmodifiableSet(routeStates.keySet());
+    }
+
+    /**
+     * Gets all route states.
+     *
+     * <p>The returned map is a read-only view, not a copy: later calls to {@link
+     * #setRouteState(Object, byte[])} are visible through it.
+     *
+     * @return An unmodifiable view of route keys to their serialized states
+     */
+    public Map<RouteT, byte[]> getRouteStates() {
+        return Collections.unmodifiableMap(routeStates);
+    }
+
+    /**
+     * Checks if this state contains any route states.
+     *
+     * @return true if there are no route states, false otherwise
+     */
+    public boolean isEmpty() {
+        return routeStates.isEmpty();
+    }
+
+    /**
+     * Gets the number of routes with associated state.
+     *
+     * @return The number of routes
+     */
+    public int size() {
+        return routeStates.size();
+    }
+
+    /**
+     * Two states are equal when they hold the same routes and, per route, byte arrays with the
+     * same content.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        DemultiplexingSinkState<?> that = (DemultiplexingSinkState<?>) o;
+
+        if (routeStates.size() != that.routeStates.size()) {
+            return false;
+        }
+
+        for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+            // Values are never null here, so a missing key on the other side (null) can never
+            // compare equal; arrays are compared by content, not by identity.
+            byte[] otherValue = that.routeStates.get(entry.getKey());
+            if (otherValue == null || !java.util.Arrays.equals(entry.getValue(), otherValue)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     * Hash code consistent with {@link #equals(Object)}.
+     *
+     * <p>Per-entry hashes are summed so that the result does not depend on the iteration order of
+     * the backing map; two equal states may iterate their entries in different orders.
+     */
+    @Override
+    public int hashCode() {
+        int result = 0;
+        for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+            result +=
+                    Objects.hashCode(entry.getKey()) ^ java.util.Arrays.hashCode(entry.getValue());
+        }
+        return result;
+    }
+
+    @Override
+    public String toString() {
+        return "DemultiplexingSinkState{"
+                + "routeCount="
+                + routeStates.size()
+                + ", routes="
+                + routeStates.keySet()
+                + '}';
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java new file mode 100644 index 0000000000000..36db571368b92 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializer.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.HashMap; +import java.util.Map; + +/** + * Serializer for {@link DemultiplexingSinkState}. + * + *

This serializer handles the serialization of the demultiplexing sink state, which contains a + * mapping of route keys to their corresponding sink writer states. The route keys are serialized + * using Java serialization, while the sink writer states are stored as raw byte arrays. + * + *

Serialization format: + * + *

    + *
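  • Version (int) — reported by getVersion() and passed back to deserialize(); it is not part of the byte array produced by serialize() + *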
  • Number of routes (int) + *
  • For each route: + *
      + *
    • Route key length (int) + *
    • Route key bytes (serialized using Java serialization) + *
    • State length (int) + *
    • State bytes + *
    + *
+ *
+ * @param <RouteT> The type of route keys
+ */
+@PublicEvolving
+public class DemultiplexingSinkStateSerializer<RouteT>
+        implements SimpleVersionedSerializer<DemultiplexingSinkState<RouteT>> {
+
+    private static final int VERSION = 1;
+
+    @Override
+    public int getVersion() {
+        return VERSION;
+    }
+
+    /**
+     * Serializes the state as: route count, then per route a length-prefixed route key followed
+     * by a length-prefixed state blob. The version is not written into the payload.
+     *
+     * @param state The state to serialize
+     * @return The serialized bytes
+     * @throws IOException If a route key cannot be serialized
+     */
+    @Override
+    public byte[] serialize(DemultiplexingSinkState<RouteT> state) throws IOException {
+        final DataOutputSerializer out = new DataOutputSerializer(256);
+
+        // Write the number of routes
+        final Map<RouteT, byte[]> routeStates = state.getRouteStates();
+        out.writeInt(routeStates.size());
+
+        // Write each route and its state
+        for (Map.Entry<RouteT, byte[]> entry : routeStates.entrySet()) {
+            // Serialize the route key using Java serialization
+            final byte[] routeBytes = serializeRouteKey(entry.getKey());
+            out.writeInt(routeBytes.length);
+            out.write(routeBytes);
+
+            // A null blob is written as a zero-length blob instead of failing with a
+            // NullPointerException; the entry count written above stays correct either way.
+            final byte[] stateBytes = entry.getValue() != null ? entry.getValue() : new byte[0];
+            out.writeInt(stateBytes.length);
+            out.write(stateBytes);
+        }
+
+        return out.getCopyOfBuffer();
+    }
+
+    /**
+     * Reads back a state written by {@link #serialize(DemultiplexingSinkState)}.
+     *
+     * @param version The version the bytes were written with; only {@code 1} is supported
+     * @param serialized The serialized bytes
+     * @return The deserialized state
+     * @throws IOException If the version is unsupported, the data is truncated or carries an
+     *     invalid length, or a route key cannot be deserialized
+     */
+    @Override
+    public DemultiplexingSinkState<RouteT> deserialize(int version, byte[] serialized)
+            throws IOException {
+        if (version != VERSION) {
+            throw new IOException(
+                    "Unsupported version: " + version + ". Supported version: " + VERSION);
+        }
+
+        final DataInputDeserializer in = new DataInputDeserializer(serialized);
+
+        // Read the number of routes
+        final int numRoutes = in.readInt();
+        if (numRoutes < 0) {
+            throw new IOException("Corrupt state: negative number of routes: " + numRoutes);
+        }
+        // Not pre-sized from numRoutes: the count comes from the input and is unverified.
+        final Map<RouteT, byte[]> routeStates = new HashMap<>();
+
+        // Read each route and its state
+        for (int i = 0; i < numRoutes; i++) {
+            final RouteT route = deserializeRouteKey(readLengthPrefixed(in, "route key"));
+            final byte[] stateBytes = readLengthPrefixed(in, "route state");
+            routeStates.put(route, stateBytes);
+        }
+
+        return new DemultiplexingSinkState<>(routeStates);
+    }
+
+    /**
+     * Reads an int length followed by that many bytes, rejecting lengths that are negative or
+     * larger than the remaining input before any array is allocated.
+     *
+     * @param in The input positioned at the length prefix
+     * @param what Short description of the field, used in the error message
+     * @return The bytes that followed the length prefix
+     * @throws IOException If the length is invalid or the input ends early
+     */
+    private static byte[] readLengthPrefixed(DataInputDeserializer in, String what)
+            throws IOException {
+        final int length = in.readInt();
+        final int remaining = in.available();
+        if (length < 0 || length > remaining) {
+            throw new IOException(
+                    "Corrupt state: invalid "
+                            + what
+                            + " length "
+                            + length
+                            + " ("
+                            + remaining
+                            + " bytes remaining)");
+        }
+        final byte[] bytes = new byte[length];
+        in.readFully(bytes);
+        return bytes;
+    }
+
+    /**
+     * Serializes a route key using Java serialization.
+     *
+     * @param routeKey The route key to serialize
+     * @return The serialized route key bytes
+     * @throws IOException If serialization fails
+     */
+    private byte[] serializeRouteKey(RouteT routeKey) throws IOException {
+        try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                final ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+            oos.writeObject(routeKey);
+            oos.flush();
+            return baos.toByteArray();
+        } catch (Exception e) {
+            throw new IOException("Failed to serialize route key: " + routeKey, e);
+        }
+    }
+
+    /**
+     * Deserializes a route key using Java serialization.
+     *
+     * <p>NOTE(review): this runs {@link ObjectInputStream#readObject()} on checkpoint bytes
+     * without an {@code ObjectInputFilter} — confirm that checkpoint storage is trusted.
+     *
+     * @param routeBytes The serialized route key bytes
+     * @return The deserialized route key
+     * @throws IOException If deserialization fails
+     */
+    @SuppressWarnings("unchecked")
+    private RouteT deserializeRouteKey(byte[] routeBytes) throws IOException {
+        try (final ByteArrayInputStream bais = new ByteArrayInputStream(routeBytes);
+                final ObjectInputStream ois = new ObjectInputStream(bais)) {
+            return (RouteT) ois.readObject();
+        } catch (Exception e) {
+            throw new IOException("Failed to deserialize route key", e);
+        }
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java new file mode 100644 index 0000000000000..51c48b0e88cb3 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriter.java @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.eventtime.Watermark; +import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.api.connector.sink2.StatefulSinkWriter; +import org.apache.flink.api.connector.sink2.SupportsWriterState; +import org.apache.flink.api.connector.sink2.WriterInitContext; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A sink writer that routes elements to different underlying sink writers based on a routing + * function. + * + *

This writer maintains a cache of sink writers, creating new ones on-demand when previously + * unseen routes are encountered. Each underlying sink writer is managed independently, including + * their lifecycle, state management, and error handling. + * + *

The writer supports state management by collecting state from all underlying sink writers
+ * during checkpointing and restoring them appropriately during recovery.
+ *
+ * @param <InputT> The type of input elements
+ * @param <RouteT> The type of route keys used for sink selection
+ */
+@PublicEvolving
+public class DemultiplexingSinkWriter<InputT, RouteT>
+        implements StatefulSinkWriter<InputT, DemultiplexingSinkState<RouteT>> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(DemultiplexingSinkWriter.class);
+
+    /** The router that determines how elements are routed to sinks. */
+    private final SinkRouter<InputT, RouteT> sinkRouter;
+
+    /** The writer initialization context. */
+    private final WriterInitContext context;
+
+    /** Cache of sink writers by route key. */
+    private final Map<RouteT, SinkWriter<InputT>> sinkWriters;
+
+    /** Cache of sink instances by route key for creating writers. */
+    private final Map<RouteT, Sink<InputT>> sinks;
+
+    /** Cache of recovered states by route key for lazy restoration. */
+    private final Map<RouteT, List<Object>> recoveredStates;
+
+    /**
+     * Recovered state blobs, exactly as read from the checkpoint, for routes that have not been
+     * activated since the restore. They are written back unchanged by {@link
+     * #snapshotState(long)} until the route's writer exists and snapshots its own state.
+     */
+    private final Map<RouteT, List<byte[]>> pendingRouteStateBytes;
+
+    /**
+     * Creates a new demultiplexing sink writer.
+     *
+     * @param sinkRouter The router that determines how elements are routed to different sinks
+     * @param context The writer initialization context
+     */
+    public DemultiplexingSinkWriter(
+            SinkRouter<InputT, RouteT> sinkRouter, WriterInitContext context) {
+        this.sinkRouter = Preconditions.checkNotNull(sinkRouter);
+        this.context = Preconditions.checkNotNull(context);
+        this.sinkWriters = new HashMap<>();
+        this.sinks = new HashMap<>();
+        this.recoveredStates = new HashMap<>();
+        this.pendingRouteStateBytes = new HashMap<>();
+    }
+
+    /**
+     * Creates a new demultiplexing sink writer and restores state from a previous checkpoint.
+     *
+     * <p>Several recovered states may carry state for the same route; their writer states are
+     * merged rather than the later one replacing the earlier one.
+     *
+     * @param sinkRouter The router that determines how elements are routed to different sinks
+     * @param context The writer initialization context
+     * @param recoveredStates The recovered states from previous checkpoints
+     */
+    public DemultiplexingSinkWriter(
+            SinkRouter<InputT, RouteT> sinkRouter,
+            WriterInitContext context,
+            Collection<DemultiplexingSinkState<RouteT>> recoveredStates) {
+        this(sinkRouter, context);
+
+        // Process recovered states and prepare them for lazy restoration
+        for (DemultiplexingSinkState<RouteT> state : recoveredStates) {
+            for (RouteT route : state.getRoutes()) {
+                byte[] routeStateBytes = state.getRouteState(route);
+                if (routeStateBytes != null && routeStateBytes.length > 0) {
+                    try {
+                        // Deserialize the writer states for this route
+                        List<Object> writerStates = deserializeWriterStates(route, routeStateBytes);
+                        this.recoveredStates
+                                .computeIfAbsent(route, r -> new ArrayList<>())
+                                .addAll(writerStates);
+                        // Keep the original bytes so they can be re-emitted by snapshotState()
+                        // while the route is still inactive.
+                        this.pendingRouteStateBytes
+                                .computeIfAbsent(route, r -> new ArrayList<>())
+                                .add(routeStateBytes);
+                        LOG.debug(
+                                "Prepared state restoration for route: {} with {} states",
+                                route,
+                                writerStates.size());
+                    } catch (Exception e) {
+                        LOG.warn(
+                                "Failed to deserialize state for route: {}, will start with empty state",
+                                route,
+                                e);
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Routes the element via the {@link SinkRouter} and hands it to the writer of that route,
+     * creating the writer first if the route has not been seen by this instance.
+     */
+    @Override
+    public void write(InputT element, Context context) throws IOException, InterruptedException {
+        // Determine the route for this element
+        final RouteT route = sinkRouter.getRoute(element);
+
+        // Get or create the sink writer for this route
+        SinkWriter<InputT> writer = getOrCreateSinkWriter(route, element);
+
+        // Delegate to the appropriate sink writer
+        writer.write(element, context);
+    }
+
+    /**
+     * Flushes every active writer. An {@link IOException} from one writer does not stop the
+     * others from being flushed; the last failure is thrown with the earlier ones attached as
+     * suppressed exceptions.
+     */
+    @Override
+    public void flush(boolean endOfInput) throws IOException, InterruptedException {
+        IOException lastException = null;
+        for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+            try {
+                entry.getValue().flush(endOfInput);
+            } catch (IOException e) {
+                LOG.warn("Failed to flush sink writer for route: {}", entry.getKey(), e);
+                if (lastException != null) {
+                    e.addSuppressed(lastException);
+                }
+                lastException = e;
+            }
+        }
+
+        // Re-throw the last exception if any occurred
+        if (lastException != null) {
+            throw lastException;
+        }
+    }
+
+    /**
+     * Forwards the watermark to every active writer. An {@link IOException} from one writer does
+     * not stop the others; the last failure is thrown with the earlier ones attached as
+     * suppressed exceptions.
+     */
+    @Override
+    public void writeWatermark(Watermark watermark) throws IOException, InterruptedException {
+        IOException lastException = null;
+        for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+            try {
+                entry.getValue().writeWatermark(watermark);
+            } catch (IOException e) {
+                LOG.warn(
+                        "Failed to write watermark to sink writer for route: {}",
+                        entry.getKey(),
+                        e);
+                if (lastException != null) {
+                    e.addSuppressed(lastException);
+                }
+                lastException = e;
+            }
+        }
+
+        // Re-throw the last exception if any occurred
+        if (lastException != null) {
+            throw lastException;
+        }
+    }
+
+    /**
+     * Snapshots the state of all active stateful writers and carries forward the recovered state
+     * of routes that have not been activated since the restore.
+     *
+     * @param checkpointId The checkpoint id passed on to the underlying writers
+     * @return One state object for the active routes (if any of them produced state) plus one
+     *     state object per carried-forward blob
+     * @throws IOException If an underlying writer fails to snapshot or its state cannot be
+     *     serialized
+     */
+    @Override
+    public List<DemultiplexingSinkState<RouteT>> snapshotState(long checkpointId)
+            throws IOException {
+        final Map<RouteT, byte[]> routeStates = new HashMap<>();
+
+        // Collect state from all active sink writers
+        for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+            final RouteT route = entry.getKey();
+            final SinkWriter<InputT> writer = entry.getValue();
+
+            // Only collect state from stateful sink writers
+            if (writer instanceof StatefulSinkWriter) {
+                StatefulSinkWriter<InputT, ?> statefulWriter =
+                        (StatefulSinkWriter<InputT, ?>) writer;
+
+                try {
+                    List<?> writerStates = statefulWriter.snapshotState(checkpointId);
+                    if (!writerStates.isEmpty()) {
+                        // Serialize the writer states and store them
+                        byte[] serializedState = serializeWriterStates(route, writerStates);
+                        routeStates.put(route, serializedState);
+                    }
+                } catch (Exception e) {
+                    LOG.warn("Failed to snapshot state for route: {}", route, e);
+                    throw new IOException("Failed to snapshot state for route: " + route, e);
+                }
+            }
+        }
+
+        final List<DemultiplexingSinkState<RouteT>> states = new ArrayList<>();
+        if (!routeStates.isEmpty()) {
+            states.add(new DemultiplexingSinkState<>(routeStates));
+        }
+
+        // Routes restored from a checkpoint only get a writer once an element arrives for them.
+        // Until then their recovered bytes must be written back, otherwise this checkpoint would
+        // silently drop them. A state object holds one blob per route, so every blob gets its
+        // own state object.
+        for (Map.Entry<RouteT, List<byte[]>> pending : pendingRouteStateBytes.entrySet()) {
+            for (byte[] rawState : pending.getValue()) {
+                final Map<RouteT, byte[]> carried = new HashMap<>();
+                carried.put(pending.getKey(), rawState);
+                states.add(new DemultiplexingSinkState<>(carried));
+            }
+        }
+
+        return states;
+    }
+
+    /**
+     * Closes every active writer and clears the writer and sink caches. A failure in one writer
+     * does not stop the others from being closed; the last failure is thrown with the earlier
+     * ones attached as suppressed exceptions.
+     */
+    @Override
+    public void close() throws Exception {
+        Exception lastException = null;
+        for (Map.Entry<RouteT, SinkWriter<InputT>> entry : sinkWriters.entrySet()) {
+            try {
+                entry.getValue().close();
+            } catch (Exception e) {
+                LOG.warn("Failed to close sink writer for route: {}", entry.getKey(), e);
+                if (lastException != null) {
+                    e.addSuppressed(lastException);
+                }
+                lastException = e;
+            }
+        }
+
+        // Clear the caches
+        sinkWriters.clear();
+        sinks.clear();
+
+        // Re-throw the last exception if any occurred
+        if (lastException != null) {
+            throw lastException;
+        }
+    }
+
+    /**
+     * Gets or creates a sink writer for the given route.
+     *
+     * <p>If recovered state exists for the route and the sink implements {@link
+     * SupportsWriterState}, the writer is restored from that state. If restoring fails, the
+     * failure is logged and a fresh writer is created instead.
+     *
+     * @param route The route key
+     * @param element The element that triggered this route (used for sink creation)
+     * @return The sink writer for the route
+     * @throws IOException If sink or writer creation fails
+     */
+    private SinkWriter<InputT> getOrCreateSinkWriter(RouteT route, InputT element)
+            throws IOException {
+        SinkWriter<InputT> writer = sinkWriters.get(route);
+        if (writer == null) {
+            // Create a new sink for this route
+            Sink<InputT> sink = sinkRouter.createSink(route, element);
+            sinks.put(route, sink);
+
+            // Check if we have recovered state for this route
+            List<Object> routeRecoveredStates = recoveredStates.remove(route);
+
+            if (routeRecoveredStates != null && !routeRecoveredStates.isEmpty()) {
+                // Restore writer with state if the sink supports it
+                if (sink instanceof SupportsWriterState) {
+                    @SuppressWarnings("unchecked")
+                    SupportsWriterState<InputT, Object> statefulSink =
+                            (SupportsWriterState<InputT, Object>) sink;
+
+                    try {
+                        // Check if we need to deserialize raw bytes using the sink's serializer
+                        List<Object> processedStates =
+                                processRecoveredStates(statefulSink, routeRecoveredStates);
+
+                        writer = statefulSink.restoreWriter(context, processedStates);
+                        LOG.debug(
+                                "Restored sink writer for route: {} with {} states",
+                                route,
+                                processedStates.size());
+                    } catch (Exception e) {
+                        LOG.warn(
+                                "Failed to restore writer state for route: {}, creating new writer",
+                                route,
+                                e);
+                        writer = sink.createWriter(context);
+                    }
+                } else {
+                    // Sink doesn't support state, just create a new writer
+                    writer = sink.createWriter(context);
+                    LOG.debug("Sink for route {} doesn't support state, created new writer", route);
+                }
+            } else {
+                // No recovered state, create a new writer
+                writer = sink.createWriter(context);
+                LOG.debug("Created new sink writer for route: {} (no recovered state)", route);
+            }
+
+            // From here on the route's own writer is the source of its state, so the
+            // carried-forward checkpoint bytes are no longer needed.
+            pendingRouteStateBytes.remove(route);
+            sinkWriters.put(route, writer);
+        }
+        return writer;
+    }
+
+    /**
+     * Processes recovered states, deserializing raw bytes if necessary.
+     *
+     * <p>Entries that are {@code byte[]} are decoded with the sink's state serializer; a blob
+     * that cannot be decoded is logged and skipped. All other entries are passed through.
+     *
+     * @param statefulSink The stateful sink that can provide a state serializer
+     * @param recoveredStates The recovered states (may contain raw bytes)
+     * @return Processed states ready for restoration
+     */
+    private List<Object> processRecoveredStates(
+            SupportsWriterState<InputT, Object> statefulSink, List<Object> recoveredStates) {
+
+        List<Object> processedStates = new ArrayList<>();
+        SimpleVersionedSerializer<Object> serializer = statefulSink.getWriterStateSerializer();
+
+        for (Object state : recoveredStates) {
+            if (state instanceof byte[]) {
+                // This is raw bytes that need to be deserialized using the sink's serializer
+                byte[] rawBytes = (byte[]) state;
+                try {
+                    List<Object> deserializedStates =
+                            deserializeWithSinkSerializer(serializer, rawBytes);
+                    processedStates.addAll(deserializedStates);
+                    LOG.debug(
+                            "Successfully deserialized {} states from raw bytes using sink serializer",
+                            deserializedStates.size());
+                } catch (Exception e) {
+                    LOG.warn(
+                            "Failed to deserialize raw bytes using sink serializer, skipping state",
+                            e);
+                }
+            } else {
+                // This is already a deserialized state object
+                processedStates.add(state);
+            }
+        }
+
+        return processedStates;
+    }
+
+    /**
+     * Deserializes states using the sink's state serializer.
+     *
+     * <p>Expects the layout written by {@link #serializeWriterStates(Object, List)}: a state
+     * count followed by length-prefixed state blobs.
+     *
+     * <p>NOTE(review): the serializer version is not stored with the blobs, so they are decoded
+     * with the serializer's current version — confirm this is acceptable when a sink changes
+     * its state format between checkpoint and restore.
+     *
+     * @param serializer The sink's state serializer
+     * @param rawBytes The raw serialized bytes
+     * @return List of deserialized state objects
+     * @throws IOException If deserialization fails
+     */
+    private List<Object> deserializeWithSinkSerializer(
+            SimpleVersionedSerializer<Object> serializer, byte[] rawBytes) throws IOException {
+
+        List<Object> states = new ArrayList<>();
+
+        try (final ByteArrayInputStream bais = new ByteArrayInputStream(rawBytes);
+                final java.io.DataInputStream dis = new java.io.DataInputStream(bais)) {
+
+            // Read the number of states
+            int numStates = dis.readInt();
+
+            // Read each state
+            for (int i = 0; i < numStates; i++) {
+                int stateLength = dis.readInt();
+                byte[] stateBytes = new byte[stateLength];
+                dis.readFully(stateBytes);
+
+                Object state = serializer.deserialize(serializer.getVersion(), stateBytes);
+                states.add(state);
+            }
+        }
+
+        return states;
+    }
+
+    /**
+     * Serializes the writer states for a given route.
+     *
+     * <p>This implementation attempts to use the proper state serializer from the underlying sink
+     * if available, otherwise falls back to Java serialization.
+     *
+     * @param route The route key
+     * @param writerStates The writer states to serialize
+     * @return The serialized state bytes
+     * @throws IOException If serialization fails
+     */
+    private byte[] serializeWriterStates(RouteT route, List<?> writerStates) throws IOException {
+        if (writerStates == null || writerStates.isEmpty()) {
+            return new byte[0];
+        }
+
+        Sink<InputT> sink = sinks.get(route);
+        if (sink instanceof SupportsWriterState) {
+            try {
+                @SuppressWarnings("unchecked")
+                SupportsWriterState<InputT, Object> statefulSink =
+                        (SupportsWriterState<InputT, Object>) sink;
+                SimpleVersionedSerializer<Object> serializer =
+                        statefulSink.getWriterStateSerializer();
+
+                // Serialize each state and combine them
+                try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                        final DataOutputStream dos = new DataOutputStream(baos)) {
+
+                    // Write the number of states
+                    dos.writeInt(writerStates.size());
+
+                    // Write each state
+                    for (Object state : writerStates) {
+                        byte[] stateBytes = serializer.serialize(state);
+                        dos.writeInt(stateBytes.length);
+                        dos.write(stateBytes);
+                    }
+
+                    dos.flush();
+                    return baos.toByteArray();
+                }
+            } catch (Exception e) {
+                LOG.warn(
+                        "Failed to serialize state using sink serializer for route: {}, "
+                                + "falling back to Java serialization",
+                        route,
+                        e);
+            }
+        }
+
+        // Fallback to Java serialization
+        try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                final ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+            oos.writeObject(writerStates);
+            oos.flush();
+            return baos.toByteArray();
+        } catch (Exception e) {
+            throw new IOException("Failed to serialize writer states for route: " + route, e);
+        }
+    }
+
+    /**
+     * Deserializes the writer states for a given route.
+     *
+     * <p>This method attempts to deserialize states using Java serialization as a fallback. The
+     * proper sink-specific deserialization will happen later when the sink is created and we can
+     * access its state serializer.
+     *
+     * <p>NOTE(review): this runs {@link ObjectInputStream#readObject()} on checkpoint bytes
+     * without an {@code ObjectInputFilter} — confirm that checkpoint storage is trusted.
+     *
+     * @param route The route key
+     * @param stateBytes The serialized state bytes
+     * @return The deserialized writer states (or raw bytes to be deserialized later)
+     */
+    private List<Object> deserializeWriterStates(RouteT route, byte[] stateBytes) {
+        if (stateBytes == null || stateBytes.length == 0) {
+            return new ArrayList<>();
+        }
+
+        // Try Java deserialization first (this handles the fallback case)
+        try (final ByteArrayInputStream bais = new ByteArrayInputStream(stateBytes);
+                final ObjectInputStream ois = new ObjectInputStream(bais)) {
+            @SuppressWarnings("unchecked")
+            List<Object> states = (List<Object>) ois.readObject();
+            LOG.debug(
+                    "Successfully deserialized {} states for route {} using Java serialization",
+                    states.size(),
+                    route);
+            return states;
+        } catch (Exception e) {
+            LOG.debug(
+                    "Java deserialization failed for route {}, will store raw bytes for later processing",
+                    route);
+            // If Java deserialization fails, store the raw bytes
+            // They will be processed when we have access to the sink's serializer
+            List<Object> rawStates = new ArrayList<>();
+            rawStates.add(stateBytes); // Store as raw bytes
+            return rawStates;
+        }
+    }
+
+    /**
+     * Gets the number of currently active sink writers.
+     *
+     * @return The number of active sink writers
+     */
+    public int getActiveSinkWriterCount() {
+        return sinkWriters.size();
+    }
+
+    /**
+     * Gets the routes for all currently active sink writers.
+     *
+     * @return A collection of route keys for active sink writers
+     */
+    public Collection<RouteT> getActiveRoutes() {
+        return new ArrayList<>(sinkWriters.keySet());
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java new file mode 100644 index 0000000000000..10837e628bb53 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/sink/SinkRouter.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.connector.sink2.Sink; + +import java.io.Serializable; + +/** + * Interface for routing elements to different sinks in a {@link DemultiplexingSink}. + * + *

<p>The router is responsible for two key operations: + * + *

<ul> + *
  <li>Extracting a route key from each input element that determines which sink to use + *
  <li>Creating new sink instances when a previously unseen route is encountered + *
</ul> + * + *

<p>Route keys should be deterministic and consistent - the same logical destination should always + * produce the same route key to ensure proper sink reuse and state management. + * + *

<p>Example usage: + * + *

<pre>{@code
+ * // Route messages to different Kafka topics based on message type
+ * SinkRouter<MyMessage, String> router = new SinkRouter<MyMessage, String>() {
+ *     @Override
+ *     public String getRoute(MyMessage element) {
+ *         return element.getMessageType(); // e.g., "orders", "users", "events"
+ *     }
+ *
+ *     @Override
+ *     public Sink<MyMessage> createSink(String route, MyMessage element) {
+ *         return KafkaSink.<MyMessage>builder()
+ *             .setBootstrapServers("localhost:9092")
+ *             .setRecordSerializer(...)
+ *             .setTopics(route) // Use route as topic name
+ *             .build();
+ *     }
+ * };
+ * }</pre>
+ * + * @param <InputT> The type of input elements to route + * @param <RouteT> The type of route keys used for sink selection and caching + */ +@PublicEvolving +public interface SinkRouter<InputT, RouteT> extends Serializable { + + /** + * Extract the route key from an input element. + * + *

<p>This method is called for every element and should be efficient. The returned route key is + * used to: + * + *

<ul> + *
  <li>Look up the appropriate sink instance in the cache + *
  <li>Create a new sink if this route hasn't been seen before + *
  <li>Group elements by route for state management during checkpointing + *
</ul> + * + *

<p>Route keys must implement {@link Object#equals(Object)} and {@link Object#hashCode()} + * properly as they are used as keys in hash-based collections. + * + * @param element The input element to route + * @return The route key that determines which sink to use for this element + */ + RouteT getRoute(InputT element); + + /** + * Create a new sink instance for the given route. + * + *

<p>This method is called when a route is encountered for the first time. The created sink + * will be cached and reused for all subsequent elements with the same route key. + * + *

<p>The element parameter provides access to the specific element that triggered the creation + * of this route, which can be useful for extracting configuration information (e.g., connection + * details, authentication credentials) that may be embedded in the element. + * + *

The created sink should be fully configured and ready to use. It will be initialized by + * the DemultiplexingSink framework using the standard Sink API. + * + * @param route The route key for which to create a sink + * @param element The element that triggered the creation of this route (for configuration + * extraction) + * @return A new sink instance configured for the given route + */ + Sink createSink(RouteT route, InputT element); +} diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java new file mode 100644 index 0000000000000..d1755518d61dd --- /dev/null +++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkIT.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.test.junit5.MiniClusterExtension; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Integration tests for {@link DemultiplexingSink}. */ +class DemultiplexingSinkIT { + + @RegisterExtension + private static final MiniClusterExtension MINI_CLUSTER_RESOURCE = + new MiniClusterExtension( + new MiniClusterResourceConfiguration.Builder() + .setNumberTaskManagers(1) + .setNumberSlotsPerTaskManager(1) + .build()); + + final StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment().setParallelism(1); + + @Test + void testBasicElementRouting() throws Exception { + // Create a simple router that routes by first character + final TestSinkRouter router = new TestSinkRouter(); + + // Create the demultiplexing sink + final DemultiplexingSink demuxSink = new DemultiplexingSink<>(router); + + // Create a data stream with elements that will route to different sinks + env.fromData("apple", "banana", "apricot", "blueberry", "avocado", "cherry") + .sinkTo(demuxSink); + + // Execute the job + env.execute("Integration Test: DemultiplexingSinkIT"); + + // Verify results (should contain three routes: "a", "b", "c") + ConcurrentMap> results = TestSinkRouter.getResults(); + assertThat(results).hasSize(3); + assertThat(results.get("a")).containsExactlyInAnyOrder("apple", "apricot", "avocado"); + assertThat(results.get("b")).containsExactlyInAnyOrder("banana", "blueberry"); + 
assertThat(results.get("c")).containsExactlyInAnyOrder("cherry"); + } + + /** A serializable test router that routes by first character. */ + private static class TestSinkRouter implements SinkRouter { + private static final long serialVersionUID = 1L; + + // Static shared results map for collecting test results + private static final ConcurrentMap> results = + new ConcurrentHashMap<>(); + + @Override + public String getRoute(String element) { + return element.substring(0, 1); + } + + @Override + public org.apache.flink.api.connector.sink2.Sink createSink( + String route, String element) { + return new CollectingSink(route, results); + } + + public static ConcurrentMap> getResults() { + return results; + } + + public static void clearResults() { + results.clear(); + } + } + + /** A simple collecting sink for testing. */ + private static class CollectingSink + implements org.apache.flink.api.connector.sink2.Sink { + private static final long serialVersionUID = 1L; + + private final String route; + private final ConcurrentMap> results; + + public CollectingSink(String route, ConcurrentMap> results) { + this.route = route; + this.results = results; + } + + @Override + public org.apache.flink.api.connector.sink2.SinkWriter createWriter( + org.apache.flink.api.connector.sink2.WriterInitContext context) { + return new CollectingSinkWriter(route, results); + } + } + + /** A simple collecting sink writer for testing. 
*/ + private static class CollectingSinkWriter implements SinkWriter, Serializable { + private static final long serialVersionUID = 1L; + + private final String route; + private final ConcurrentMap> results; + + public CollectingSinkWriter(String route, ConcurrentMap> results) { + this.route = route; + this.results = results; + } + + @Override + public void write(String element, Context context) { + results.computeIfAbsent(route, k -> new ArrayList<>()).add(element); + } + + @Override + public void flush(boolean endOfInput) { + // No-op for this test + } + + @Override + public void close() { + // No-op for this test + } + } +} diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java new file mode 100644 index 0000000000000..7dafd2e44ce63 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateManagementTest.java @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.api.connector.sink2.StatefulSinkWriter; +import org.apache.flink.api.connector.sink2.SupportsWriterState; +import org.apache.flink.api.connector.sink2.WriterInitContext; +import org.apache.flink.connector.base.sink.writer.TestSinkInitContext; +import org.apache.flink.core.io.SimpleVersionedSerializer; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for state management in {@link DemultiplexingSinkWriter}. */ +class DemultiplexingSinkStateManagementTest { + + private TestStatefulSinkRouter router; + private TestSinkInitContext context; + + @BeforeEach + void setUp() { + router = new TestStatefulSinkRouter(); + context = new TestSinkInitContext(); + } + + @Test + void testSnapshotAndRestoreState() throws Exception { + // Create writer and write to multiple routes + DemultiplexingSinkWriter writer = + new DemultiplexingSinkWriter<>(router, context); + + // Write elements to create sink writers ("apple" and "apricot" resolve to same sink) + writer.write("apple", createContext()); + writer.write("banana", createContext()); + writer.write("apricot", createContext()); + + // Verify sink writers were created + assertThat(writer.getActiveSinkWriterCount()).isEqualTo(2); + assertThat(writer.getActiveRoutes()).containsExactlyInAnyOrder("a", "b"); + + // Add some state to the sink writers + TestStatefulSinkWriter writerA = router.getStatefulSinkWriter("a"); + TestStatefulSinkWriter writerB = router.getStatefulSinkWriter("b"); + writerA.addState("state-a-1"); + writerA.addState("state-a-2"); + writerB.addState("state-b-1"); + + // Verify state was added + 
assertThat(writerA.getStates()).containsExactly("state-a-1", "state-a-2"); + assertThat(writerB.getStates()).containsExactly("state-b-1"); + + // Snapshot state + List> snapshotStates = writer.snapshotState(1L); + assertThat(snapshotStates).hasSize(1); + + DemultiplexingSinkState state = snapshotStates.get(0); + assertThat(state.getRoutes()).containsExactlyInAnyOrder("a", "b"); + + // Close the original writer + writer.close(); + + // Create a new router for the restored writer to avoid confusion with old sinks + TestStatefulSinkRouter restoredRouter = new TestStatefulSinkRouter(); + + // Create a new writer with restored state + DemultiplexingSinkWriter restoredWriter = + new DemultiplexingSinkWriter<>(restoredRouter, context, snapshotStates); + + // Write to the same routes to trigger restoration ("avocado" -> "a", "blueberry" -> "b") + restoredWriter.write("avocado", createContext()); + restoredWriter.write("blueberry", createContext()); + + // Verify that state was restored + TestStatefulSinkWriter restoredWriterA = restoredRouter.getStatefulSinkWriter("a"); + TestStatefulSinkWriter restoredWriterB = restoredRouter.getStatefulSinkWriter("b"); + + // Verify that the writers were restored from state + assertThat(restoredWriterA.wasRestoredFromState()).isTrue(); + assertThat(restoredWriterB.wasRestoredFromState()).isTrue(); + + // Verify that the restored state contains the expected data + assertThat(restoredWriterA.getStates()).containsExactly("state-a-1", "state-a-2"); + assertThat(restoredWriterB.getStates()).containsExactly("state-b-1"); + + restoredWriter.close(); + } + + @Test + void testSnapshotEmptyState() throws Exception { + DemultiplexingSinkWriter writer = + new DemultiplexingSinkWriter<>(router, context); + + // Snapshot without any active writers + List> states = writer.snapshotState(1L); + assertThat(states).isEmpty(); + + writer.close(); + } + + @Test + void testRestoreWithNonStatefulSinks() throws Exception { + // Use a router that creates 
non-stateful sinks + TestNonStatefulSinkRouter nonStatefulRouter = new TestNonStatefulSinkRouter(); + + // Create some dummy state + DemultiplexingSinkState dummyState = new DemultiplexingSinkState<>(); + dummyState.setRouteState("a", new byte[] {1, 2, 3}); + + // Create writer with restored state + DemultiplexingSinkWriter writer = + new DemultiplexingSinkWriter<>(nonStatefulRouter, context, List.of(dummyState)); + + // Write to trigger sink creation + writer.write("apple", createContext()); + + // Should work fine even though the sink doesn't support state + assertThat(writer.getActiveSinkWriterCount()).isEqualTo(1); + + writer.close(); + } + + private SinkWriter.Context createContext() { + return new SinkWriter.Context() { + @Override + public long currentWatermark() { + return 0; + } + + @Override + public Long timestamp() { + return null; + } + }; + } + + /** Test router that creates stateful sinks. */ + private static class TestStatefulSinkRouter implements SinkRouter { + private final AtomicInteger sinkCreationCount = new AtomicInteger(0); + private final List createdSinks = new ArrayList<>(); + + @Override + public String getRoute(String element) { + return element.substring(0, 1); + } + + @Override + public Sink createSink(String route, String element) { + sinkCreationCount.incrementAndGet(); + TestStatefulSink sink = new TestStatefulSink(route); + createdSinks.add(sink); + return sink; + } + + public TestStatefulSinkWriter getStatefulSinkWriter(String route) { + return createdSinks.stream() + .filter(sink -> sink.getRoute().equals(route)) + .findFirst() + .map(TestStatefulSink::getCreatedWriter) + .orElse(null); + } + } + + /** Test router that creates non-stateful sinks. 
*/ + private static class TestNonStatefulSinkRouter implements SinkRouter { + @Override + public String getRoute(String element) { + return element.substring(0, 1); + } + + @Override + public Sink createSink(String route, String element) { + return new TestNonStatefulSink(route); + } + } + + /** Test stateful sink implementation. */ + private static class TestStatefulSink + implements Sink, SupportsWriterState { + private final String route; + private TestStatefulSinkWriter createdWriter; + + public TestStatefulSink(String route) { + this.route = route; + } + + @Override + public SinkWriter createWriter(WriterInitContext context) { + createdWriter = new TestStatefulSinkWriter(route); + return createdWriter; + } + + @Override + public StatefulSinkWriter restoreWriter( + WriterInitContext context, Collection recoveredState) { + createdWriter = new TestStatefulSinkWriter(route, recoveredState); + return createdWriter; + } + + @Override + public SimpleVersionedSerializer getWriterStateSerializer() { + return new TestStringSerializer(); + } + + public String getRoute() { + return route; + } + + public TestStatefulSinkWriter getCreatedWriter() { + return createdWriter; + } + } + + /** Test non-stateful sink implementation. */ + private static class TestNonStatefulSink implements Sink { + private final String route; + + public TestNonStatefulSink(String route) { + this.route = route; + } + + @Override + public SinkWriter createWriter(WriterInitContext context) { + return new TestNonStatefulSinkWriter(route); + } + } + + /** Test stateful sink writer. 
*/ + private static class TestStatefulSinkWriter implements StatefulSinkWriter { + private final String route; + private final List elements = new ArrayList<>(); + private final List states = new ArrayList<>(); + private final boolean restoredFromState; + + public TestStatefulSinkWriter(String route) { + this.route = route; + this.restoredFromState = false; + } + + public TestStatefulSinkWriter(String route, Collection recoveredStates) { + this.route = route; + this.states.addAll(recoveredStates); + this.restoredFromState = true; + } + + @Override + public void write(String element, Context context) { + elements.add(element); + } + + @Override + public void flush(boolean endOfInput) { + // No-op + } + + @Override + public void close() { + // No-op + } + + @Override + public List snapshotState(long checkpointId) { + return new ArrayList<>(states); + } + + public void addState(String state) { + states.add(state); + } + + public boolean wasRestoredFromState() { + return restoredFromState; + } + + public List getElements() { + return new ArrayList<>(elements); + } + + public List getStates() { + return new ArrayList<>(states); + } + } + + /** Test non-stateful sink writer. */ + private static class TestNonStatefulSinkWriter implements SinkWriter { + private final String route; + private final List elements = new ArrayList<>(); + + public TestNonStatefulSinkWriter(String route) { + this.route = route; + } + + @Override + public void write(String element, Context context) { + elements.add(element); + } + + @Override + public void flush(boolean endOfInput) { + // No-op + } + + @Override + public void close() { + // No-op + } + + public List getElements() { + return new ArrayList<>(elements); + } + } + + /** Simple string serializer for testing. 
*/ + private static class TestStringSerializer implements SimpleVersionedSerializer { + @Override + public int getVersion() { + return 1; + } + + @Override + public byte[] serialize(String obj) { + return obj.getBytes(); + } + + @Override + public String deserialize(int version, byte[] serialized) { + return new String(serialized); + } + } +} diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java new file mode 100644 index 0000000000000..c1eb9ea23d848 --- /dev/null +++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkStateSerializerTest.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.connector.base.sink;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link DemultiplexingSinkStateSerializer} and {@link DemultiplexingSinkState}. */
+class DemultiplexingSinkStateSerializerTest {
+
+    @Test
+    void testSerializeDeserializeEmptyState() throws IOException {
+        final DemultiplexingSinkStateSerializer<String> serializer =
+                new DemultiplexingSinkStateSerializer<>();
+        final DemultiplexingSinkState<String> originalState = new DemultiplexingSinkState<>();
+
+        final byte[] serialized = serializer.serialize(originalState);
+        final DemultiplexingSinkState<String> deserializedState =
+                serializer.deserialize(serializer.getVersion(), serialized);
+
+        assertThat(deserializedState).isEqualTo(originalState);
+        assertThat(deserializedState.isEmpty()).isTrue();
+    }
+
+    @Test
+    void testSerializeDeserializeStateWithRoutes() throws IOException {
+        final DemultiplexingSinkStateSerializer<String> serializer =
+                new DemultiplexingSinkStateSerializer<>();
+
+        // Create state with multiple routes ("route3" -> empty state)
+        final Map<String, byte[]> routeStates = new HashMap<>();
+        routeStates.put("route1", new byte[] {1, 2, 3});
+        routeStates.put("route2", new byte[] {4, 5, 6, 7});
+        routeStates.put("route3", new byte[0]);
+
+        final DemultiplexingSinkState<String> originalState =
+                new DemultiplexingSinkState<>(routeStates);
+
+        final byte[] serialized = serializer.serialize(originalState);
+        final DemultiplexingSinkState<String> deserializedState =
+                serializer.deserialize(serializer.getVersion(), serialized);
+
+        assertThat(deserializedState).isEqualTo(originalState);
+        assertThat(deserializedState.getRoutes())
+                .containsExactlyInAnyOrder("route1", "route2", "route3");
+        assertThat(deserializedState.getRouteState("route1")).containsExactly(1, 2, 3);
+        assertThat(deserializedState.getRouteState("route2")).containsExactly(4, 5, 6, 7);
+        assertThat(deserializedState.getRouteState("route3")).isEmpty();
+    }
+
+    @Test
+    void testSerializeDeserializeWithComplexRouteKeys() throws IOException {
+        final DemultiplexingSinkStateSerializer<ComplexRouteKey> serializer =
+                new DemultiplexingSinkStateSerializer<>();
+
+        // Create state with complex route keys
+        final Map<ComplexRouteKey, byte[]> routeStates = new HashMap<>();
+        routeStates.put(new ComplexRouteKey("cluster1", 9092), new byte[] {1, 2});
+        routeStates.put(new ComplexRouteKey("cluster2", 9093), new byte[] {3, 4});
+
+        final DemultiplexingSinkState<ComplexRouteKey> originalState =
+                new DemultiplexingSinkState<>(routeStates);
+
+        final byte[] serialized = serializer.serialize(originalState);
+        final DemultiplexingSinkState<ComplexRouteKey> deserializedState =
+                serializer.deserialize(serializer.getVersion(), serialized);
+
+        assertThat(deserializedState).isEqualTo(originalState);
+        assertThat(deserializedState.size()).isEqualTo(2);
+    }
+
+    @Test
+    void testDeserializeWithWrongVersion() {
+        final DemultiplexingSinkStateSerializer<String> serializer =
+                new DemultiplexingSinkStateSerializer<>();
+        final byte[] serialized = new byte[] {1, 2, 3, 4};
+
+        assertThatThrownBy(() -> serializer.deserialize(999, serialized))
+                .isInstanceOf(IOException.class)
+                .hasMessageContaining("Unsupported version: 999");
+    }
+
+    @Test
+    void testGetVersion() {
+        final DemultiplexingSinkStateSerializer<String> serializer =
+                new DemultiplexingSinkStateSerializer<>();
+
+        assertThat(serializer.getVersion()).isEqualTo(1);
+    }
+
+    @Test
+    void testSetRouteStateWithNullState() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+        // Add a route first
+        state.setRouteState("route1", new byte[] {1, 2, 3});
+        assertThat(state.getRoutes()).containsExactly("route1");
+
+        // Setting null state should remove the route
+        state.setRouteState("route1", null);
+        assertThat(state.getRoutes()).isEmpty();
+        assertThat(state.getRouteState("route1")).isNull();
+    }
+
+    @Test
+    void testSetRouteStateWithNullStateOnNonExistentRoute() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+        // Setting null state on non-existent route should be no-op
+        state.setRouteState("nonExistent", null);
+        assertThat(state.getRoutes()).isEmpty();
+    }
+
+    @Test
+    void testEqualsWithSameInstance() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+        state.setRouteState("route1", new byte[] {1, 2, 3});
+
+        assertThat(state.equals(state)).isTrue();
+    }
+
+    @Test
+    void testEqualsWithNull() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+        assertThat(state.equals(null)).isFalse();
+    }
+
+    @Test
+    void testEqualsWithDifferentClass() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+        assertThat(state.equals("not a state")).isFalse();
+    }
+
+    @Test
+    void testEqualsWithDifferentRouteSizes() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("route1", new byte[] {1, 2, 3});
+        state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+        assertThat(state1.equals(state2)).isFalse();
+        assertThat(state2.equals(state1)).isFalse();
+    }
+
+    @Test
+    void testEqualsWithSameRoutesButDifferentStates() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("route1", new byte[] {4, 5, 6});
+
+        assertThat(state1.equals(state2)).isFalse();
+    }
+
+    @Test
+    void testEqualsWithSameRoutesAndStates() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("route1", new byte[] {1, 2, 3});
+        state1.setRouteState("route2", new byte[] {4, 5, 6});
+
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("route1", new byte[] {1, 2, 3});
+        state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+        assertThat(state1.equals(state2)).isTrue();
+        assertThat(state2.equals(state1)).isTrue();
+    }
+
+    @Test
+    void testHashCodeConsistency() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("route1", new byte[] {1, 2, 3});
+        state1.setRouteState("route2", new byte[] {4, 5, 6});
+
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("route1", new byte[] {1, 2, 3});
+        state2.setRouteState("route2", new byte[] {4, 5, 6});
+
+        // Equal objects must have equal hash codes
+        assertThat(state1.equals(state2)).isTrue();
+        assertThat(state1.hashCode()).isEqualTo(state2.hashCode());
+    }
+
+    @Test
+    void testHashCodeWithEmptyState() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+
+        assertThat(state1.hashCode()).isEqualTo(state2.hashCode());
+    }
+
+    @Test
+    void testHashCodeWithDifferentStates() {
+        DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("route1", new byte[] {1, 2, 3});
+
+        DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("route1", new byte[] {4, 5, 6});
+
+        // Different objects should typically have different hash codes
+        assertThat(state1.hashCode()).isNotEqualTo(state2.hashCode());
+    }
+
+    @Test
+    void testToStringWithEmptyState() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+
+        String toString = state.toString();
+        assertThat(toString).contains("DemultiplexingSinkState");
+        assertThat(toString).contains("routeCount=0");
+        assertThat(toString).contains("routes=[]");
+    }
+
+    @Test
+    void testToStringWithSingleRoute() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+        state.setRouteState("route1", new byte[] {1, 2, 3});
+
+        String toString = state.toString();
+        assertThat(toString).contains("DemultiplexingSinkState");
+        assertThat(toString).contains("routeCount=1");
+        assertThat(toString).contains("route1");
+    }
+
+    @Test
+    void testToStringWithMultipleRoutes() {
+        DemultiplexingSinkState<String> state = new DemultiplexingSinkState<>();
+        state.setRouteState("route1", new byte[] {1, 2, 3});
+        state.setRouteState("route2", new byte[] {4, 5, 6});
+
+        String toString = state.toString();
+        assertThat(toString).contains("DemultiplexingSinkState");
+        assertThat(toString).contains("routeCount=2");
+        assertThat(toString).contains("route1");
+        assertThat(toString).contains("route2");
+    }
+
+    /** A complex route key for testing serialization. */
+    private static class ComplexRouteKey implements java.io.Serializable {
+        private static final long serialVersionUID = 1L;
+
+        private final String host;
+        private final int port;
+
+        public ComplexRouteKey(String host, int port) {
+            this.host = host;
+            this.port = port;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            ComplexRouteKey that = (ComplexRouteKey) o;
+            return port == that.port && java.util.Objects.equals(host, that.host);
+        }
+
+        @Override
+        public int hashCode() {
+            return java.util.Objects.hash(host, port);
+        }
+
+        @Override
+        public String toString() {
+            return "ComplexRouteKey{host='" + host + "', port=" + port + '}';
+        }
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java
new file mode 100644
index 0000000000000..33e956588416d
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkTest.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.connector.base.sink.writer.TestSinkInitContext; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link DemultiplexingSink}. 
*/
+class DemultiplexingSinkTest {
+
+    @Test
+    void testSinkCreation() {
+        final TestSinkRouter router = new TestSinkRouter();
+        final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+
+        assertThat(sink.getSinkRouter()).isSameAs(router);
+    }
+
+    @Test
+    void testSinkCreationWithNullRouter() {
+        assertThatThrownBy(() -> new DemultiplexingSink<String, String>(null))
+                .isInstanceOf(NullPointerException.class)
+                .hasMessageContaining("sinkRouter must not be null");
+    }
+
+    @Test
+    void testCreateWriter() throws IOException {
+        final TestSinkRouter router = new TestSinkRouter();
+        final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+        final TestSinkInitContext context = new TestSinkInitContext();
+
+        final SinkWriter<String> writer = sink.createWriter(context);
+
+        assertThat(writer).isInstanceOf(DemultiplexingSinkWriter.class);
+    }
+
+    @Test
+    void testWriterStateSerializer() {
+        final TestSinkRouter router = new TestSinkRouter();
+        final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+
+        assertThat(sink.getWriterStateSerializer()).isNotNull();
+        assertThat(sink.getWriterStateSerializer())
+                .isInstanceOf(DemultiplexingSinkStateSerializer.class);
+    }
+
+    @Test
+    void testRestoreWriter() {
+        final TestSinkRouter router = new TestSinkRouter();
+        final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+        final TestSinkInitContext context = new TestSinkInitContext();
+
+        // Create some state to restore from
+        final DemultiplexingSinkState<String> state1 = new DemultiplexingSinkState<>();
+        state1.setRouteState("a", new byte[] {1, 2, 3});
+        state1.setRouteState("b", new byte[] {4, 5, 6});
+
+        final DemultiplexingSinkState<String> state2 = new DemultiplexingSinkState<>();
+        state2.setRouteState("c", new byte[] {7, 8, 9});
+
+        final java.util.List<DemultiplexingSinkState<String>> recoveredStates =
+                java.util.Arrays.asList(state1, state2);
+
+        // Restore the writer with the states
+        final var restoredWriter = sink.restoreWriter(context, recoveredStates);
+
+        // Verify that the restored writer is the correct type
+        assertThat(restoredWriter).isInstanceOf(DemultiplexingSinkWriter.class);
+
+        // Verify that the writer was created successfully
+        assertThat(restoredWriter).isNotNull();
+    }
+
+    @Test
+    void testRestoreWriterWithEmptyState() {
+        final TestSinkRouter router = new TestSinkRouter();
+        final DemultiplexingSink<String, String> sink = new DemultiplexingSink<>(router);
+        final TestSinkInitContext context = new TestSinkInitContext();
+
+        // Create an empty state
+        final DemultiplexingSinkState<String> emptyState = new DemultiplexingSinkState<>();
+        final java.util.List<DemultiplexingSinkState<String>> recoveredStates = List.of(emptyState);
+
+        // Restore the writer with empty state
+        final var restoredWriter = sink.restoreWriter(context, recoveredStates);
+
+        // Verify that the restored writer is created successfully even with empty state
+        assertThat(restoredWriter).isNotNull();
+        assertThat(restoredWriter).isInstanceOf(DemultiplexingSinkWriter.class);
+    }
+
+    /** Test implementation of {@link SinkRouter}. */
+    private static class TestSinkRouter implements SinkRouter<String, String> {
+        private final AtomicInteger sinkCreationCount = new AtomicInteger(0);
+
+        @Override
+        public String getRoute(String element) {
+            // Route based on first character
+            return element.substring(0, 1);
+        }
+
+        @Override
+        public Sink<String> createSink(String route, String element) {
+            sinkCreationCount.incrementAndGet();
+            return new TestSink(route);
+        }
+
+        public int getSinkCreationCount() {
+            return sinkCreationCount.get();
+        }
+    }
+
+    /** Test implementation of {@link Sink}. */
+    private static class TestSink implements Sink<String> {
+        private final String route;
+
+        public TestSink(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public SinkWriter<String> createWriter(
+                org.apache.flink.api.connector.sink2.WriterInitContext context) {
+            return new TestSinkWriter(route);
+        }
+
+        public String getRoute() {
+            return route;
+        }
+    }
+
+    /** Test implementation of {@link SinkWriter}. */
+    private static class TestSinkWriter implements SinkWriter<String> {
+        private final String route;
+        private final List<String> elements = new ArrayList<>();
+        private boolean closed = false;
+
+        public TestSinkWriter(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public void write(String element, Context context) {
+            if (closed) {
+                throw new IllegalStateException("Writer is closed");
+            }
+            elements.add(element);
+        }
+
+        @Override
+        public void flush(boolean endOfInput) {
+            // No-op for test
+        }
+
+        @Override
+        public void close() {
+            closed = true;
+        }
+
+        public List<String> getElements() {
+            return new ArrayList<>(elements);
+        }
+
+        public String getRoute() {
+            return route;
+        }
+
+        public boolean isClosed() {
+            return closed;
+        }
+    }
+}
diff --git a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java
new file mode 100644
index 0000000000000..54a03f72c9b3e
--- /dev/null
+++ b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/sink/DemultiplexingSinkWriterTest.java
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.connector.base.sink; + +import org.apache.flink.api.common.eventtime.Watermark; +import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.api.connector.sink2.StatefulSinkWriter; +import org.apache.flink.api.connector.sink2.SupportsWriterState; +import org.apache.flink.api.connector.sink2.WriterInitContext; +import org.apache.flink.connector.base.sink.writer.TestSinkInitContext; +import org.apache.flink.core.io.SimpleVersionedSerializer; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link DemultiplexingSinkWriter}. 
*/
+class DemultiplexingSinkWriterTest {
+
+    private TestSinkRouter router;
+    private TestSinkInitContext context;
+    private DemultiplexingSinkWriter<String, String> writer;
+
+    @BeforeEach
+    void setUp() {
+        router = new TestSinkRouter();
+        context = new TestSinkInitContext();
+        writer = new DemultiplexingSinkWriter<>(router, context);
+    }
+
+    @Test
+    void testWriteToSingleRoute() throws IOException, InterruptedException {
+        // Write elements that all route to the same destination
+        writer.write("apple", createContext());
+        writer.write("avocado", createContext());
+        writer.write("apricot", createContext());
+
+        // Should have created only one sink
+        assertThat(router.getSinkCreationCount()).isEqualTo(1);
+        assertThat(writer.getActiveSinkWriterCount()).isEqualTo(1);
+        assertThat(writer.getActiveRoutes()).containsExactly("a");
+
+        // All elements should be in the same sink writer
+        TestSinkWriter sinkWriter = router.getSinkWriter("a");
+        assertThat(sinkWriter.getElements()).containsExactly("apple", "avocado", "apricot");
+    }
+
+    @Test
+    void testWriteToMultipleRoutes() throws IOException, InterruptedException {
+        // Write elements that route to different destinations ("apple" and "apricot" should resolve
+        // to the same)
+        writer.write("apple", createContext());
+        writer.write("banana", createContext());
+        writer.write("cherry", createContext());
+        writer.write("apricot", createContext());
+
+        // Should have created three sinks (a, b, c)
+        assertThat(router.getSinkCreationCount()).isEqualTo(3);
+        assertThat(writer.getActiveSinkWriterCount()).isEqualTo(3);
+        assertThat(writer.getActiveRoutes()).containsExactlyInAnyOrder("a", "b", "c");
+
+        // Check elements are routed correctly
+        assertThat(router.getSinkWriter("a").getElements()).containsExactly("apple", "apricot");
+        assertThat(router.getSinkWriter("b").getElements()).containsExactly("banana");
+        assertThat(router.getSinkWriter("c").getElements()).containsExactly("cherry");
+    }
+
+    @Test
+    void testFlush() throws IOException, InterruptedException {
+        // Write to multiple routes
+        writer.write("apple", createContext());
+        writer.write("banana", createContext());
+
+        // Flush should be called on all sink writers
+        writer.flush(false);
+
+        // Verify flush was called (our test implementation tracks this)
+        assertThat(router.getSinkWriter("a").getFlushCount()).isEqualTo(1);
+        assertThat(router.getSinkWriter("b").getFlushCount()).isEqualTo(1);
+    }
+
+    @Test
+    void testWriteWatermark() throws IOException, InterruptedException {
+        // Write to multiple routes
+        writer.write("apple", createContext());
+        writer.write("banana", createContext());
+
+        // Write watermark should be propagated to all sink writers
+        Watermark watermark = new Watermark(12345L);
+        writer.writeWatermark(watermark);
+
+        // Verify watermark was written (our test implementation tracks this)
+        assertThat(router.getSinkWriter("a").getWatermarksReceived()).containsExactly(watermark);
+        assertThat(router.getSinkWriter("b").getWatermarksReceived()).containsExactly(watermark);
+    }
+
+    @Test
+    void testClose() throws Exception {
+        // Write to multiple routes
+        writer.write("apple", createContext());
+        writer.write("banana", createContext());
+
+        // Close should close all sink writers
+        writer.close();
+
+        // Verify all writers were closed
+        assertThat(router.getSinkWriter("a").isClosed()).isTrue();
+        assertThat(router.getSinkWriter("b").isClosed()).isTrue();
+
+        // Active count should be zero after close
+        assertThat(writer.getActiveSinkWriterCount()).isEqualTo(0);
+    }
+
+    @Test
+    void testSnapshotState() throws IOException, InterruptedException {
+        // Write to multiple routes
+        writer.write("apple", createContext());
+        writer.write("banana", createContext());
+
+        // Snapshot state
+        List<DemultiplexingSinkState<String>> states = writer.snapshotState(1L);
+
+        // Should return state (even if empty for our test implementation)
+        assertThat(states).isNotNull();
+    }
+
+    @Test
+    void testJavaSerializationFallback() throws Exception {
+        // Create a router whose sinks expose a state serializer that always fails
+        JavaSerializationFallbackRouter fallbackRouter = new JavaSerializationFallbackRouter();
+        DemultiplexingSinkWriter<String, String> fallbackWriter =
+                new DemultiplexingSinkWriter<>(fallbackRouter, context);
+
+        // Write elements to create stateful sink writers
+        fallbackWriter.write("apple", createContext());
+        fallbackWriter.write("banana", createContext());
+
+        // Verify sink writers were created
+        assertThat(fallbackWriter.getActiveSinkWriterCount()).isEqualTo(2);
+
+        // Look up the sinks the router created for each route
+        JavaSerializationFallbackSink sinkA = fallbackRouter.getCreatedSink("a");
+        JavaSerializationFallbackSink sinkB = fallbackRouter.getCreatedSink("b");
+
+        assertThat(sinkA).isNotNull();
+        assertThat(sinkB).isNotNull();
+
+        // The stateful writers should have received the elements
+        assertThat(sinkA.getCreatedWriter().getElements()).containsExactly("apple");
+        assertThat(sinkB.getCreatedWriter().getElements()).containsExactly("banana");
+
+        // Snapshot state - this should trigger Java serialization fallback
+        // since our test sink has a failing state serializer
+        List<DemultiplexingSinkState<String>> snapshotStates = fallbackWriter.snapshotState(1L);
+
+        // Should have successfully created state using Java serialization fallback
+        assertThat(snapshotStates).hasSize(1);
+        DemultiplexingSinkState<String> state = snapshotStates.get(0);
+        assertThat(state.size()).isEqualTo(2);
+        assertThat(state.getRoutes()).containsExactlyInAnyOrder("a", "b");
+
+        fallbackWriter.close();
+    }
+
+    /** Test implementation of {@link SinkWriter.Context}. */
+    private SinkWriter.Context createContext() {
+        return new SinkWriter.Context() {
+            @Override
+            public long currentWatermark() {
+                return 0;
+            }
+
+            @Override
+            public Long timestamp() {
+                return null;
+            }
+        };
+    }
+
+    /** Test implementation of {@link SinkRouter}. */
+    private static class TestSinkRouter implements SinkRouter<String, String> {
+        private final AtomicInteger sinkCreationCount = new AtomicInteger(0);
+        private final List<TestSink> createdSinks = new ArrayList<>();
+
+        @Override
+        public String getRoute(String element) {
+            // Route based on first character
+            return element.substring(0, 1);
+        }
+
+        @Override
+        public Sink<String> createSink(String route, String element) {
+            sinkCreationCount.incrementAndGet();
+            TestSink sink = new TestSink(route);
+            createdSinks.add(sink);
+            return sink;
+        }
+
+        public int getSinkCreationCount() {
+            return sinkCreationCount.get();
+        }
+
+        public TestSinkWriter getSinkWriter(String route) {
+            return createdSinks.stream()
+                    .filter(sink -> sink.getRoute().equals(route))
+                    .findFirst()
+                    .map(TestSink::getCreatedWriter)
+                    .orElse(null);
+        }
+    }
+
+    /** Test implementation of {@link Sink}. */
+    private static class TestSink implements Sink<String> {
+        private final String route;
+        private TestSinkWriter createdWriter;
+
+        public TestSink(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public SinkWriter<String> createWriter(
+                org.apache.flink.api.connector.sink2.WriterInitContext context) {
+            createdWriter = new TestSinkWriter(route);
+            return createdWriter;
+        }
+
+        public String getRoute() {
+            return route;
+        }
+
+        public TestSinkWriter getCreatedWriter() {
+            return createdWriter;
+        }
+    }
+
+    /** Test implementation of {@link SinkWriter}. */
+    private static class TestSinkWriter implements SinkWriter<String> {
+        private final String route;
+        private final List<String> elements = new ArrayList<>();
+        private final List<Watermark> watermarksReceived = new ArrayList<>();
+        private int flushCount = 0;
+        private boolean closed = false;
+
+        public TestSinkWriter(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public void write(String element, Context context) {
+            if (closed) {
+                throw new IllegalStateException("Writer is closed");
+            }
+            elements.add(element);
+        }
+
+        @Override
+        public void flush(boolean endOfInput) {
+            flushCount++;
+        }
+
+        @Override
+        public void writeWatermark(Watermark watermark) {
+            watermarksReceived.add(watermark);
+        }
+
+        @Override
+        public void close() {
+            closed = true;
+        }
+
+        public List<String> getElements() {
+            return new ArrayList<>(elements);
+        }
+
+        public List<Watermark> getWatermarksReceived() {
+            return new ArrayList<>(watermarksReceived);
+        }
+
+        public int getFlushCount() {
+            return flushCount;
+        }
+
+        public String getRoute() {
+            return route;
+        }
+
+        public boolean isClosed() {
+            return closed;
+        }
+    }
+
+    /** Test serialization fallback implementation of {@link SinkRouter}. */
+    private static class JavaSerializationFallbackRouter implements SinkRouter<String, String> {
+        private final java.util.Map<String, JavaSerializationFallbackSink> createdSinks =
+                new java.util.HashMap<>();
+
+        @Override
+        public String getRoute(String element) {
+            return element.substring(0, 1);
+        }
+
+        @Override
+        public Sink<String> createSink(String route, String element) {
+            JavaSerializationFallbackSink sink = new JavaSerializationFallbackSink(route);
+            createdSinks.put(route, sink);
+            return sink;
+        }
+
+        public JavaSerializationFallbackSink getCreatedSink(String route) {
+            return createdSinks.get(route);
+        }
+    }
+
+    /** Test serialization fallback implementation of {@link Sink}. */
+    private static class JavaSerializationFallbackSink
+            implements Sink<String>, SupportsWriterState<String, String> {
+        private final String route;
+        private JavaSerializationFallbackWriter createdWriter;
+
+        public JavaSerializationFallbackSink(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public SinkWriter<String> createWriter(WriterInitContext context) {
+            createdWriter = new JavaSerializationFallbackWriter(route);
+            return createdWriter;
+        }
+
+        @Override
+        public StatefulSinkWriter<String, String> restoreWriter(
+                WriterInitContext context, Collection<String> recoveredState) {
+            createdWriter = new JavaSerializationFallbackWriter(route);
+            // Restore the state
+            for (String state : recoveredState) {
+                createdWriter.addElement(state);
+            }
+            return createdWriter;
+        }
+
+        @Override
+        public SimpleVersionedSerializer<String> getWriterStateSerializer() {
+            // Return a serializer that will fail, forcing fallback to Java serialization
+            return new SimpleVersionedSerializer<String>() {
+                @Override
+                public int getVersion() {
+                    return 1;
+                }
+
+                @Override
+                public byte[] serialize(String obj) throws IOException {
+                    throw new IOException("Intentional serialization failure to test fallback");
+                }
+
+                @Override
+                public String deserialize(int version, byte[] serialized) throws IOException {
+                    throw new IOException("Intentional deserialization failure to test fallback");
+                }
+            };
+        }
+
+        public JavaSerializationFallbackWriter getCreatedWriter() {
+            return createdWriter;
+        }
+    }
+
+    /** Test serialization fallback implementation of {@link StatefulSinkWriter}. */
+    private static class JavaSerializationFallbackWriter
+            implements StatefulSinkWriter<String, String> {
+        private final String route;
+        private final List<String> elements = new ArrayList<>();
+
+        public JavaSerializationFallbackWriter(String route) {
+            this.route = route;
+        }
+
+        @Override
+        public void write(String element, Context context) {
+            elements.add(element);
+        }
+
+        @Override
+        public void flush(boolean endOfInput) {
+            // No-op for test
+        }
+
+        @Override
+        public void close() {
+            // No-op for test
+        }
+
+        @Override
+        public List<String> snapshotState(long checkpointId) {
+            // Return the elements as state (this will be Java-serialized as fallback)
+            return new ArrayList<>(elements);
+        }
+
+        public List<String> getElements() {
+            return new ArrayList<>(elements);
+        }
+
+        public void addElement(String element) {
+            elements.add(element);
+        }
+    }
+}