diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapper.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapper.java index f2120568b..6bc811a46 100644 --- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapper.java +++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapper.java @@ -61,17 +61,12 @@ class KafkaSerializerWrapper implements SerializationSchema { this(serializerClass, isKey, Collections.emptyMap(), topicSelector); } - @SuppressWarnings("unchecked") @Override public void open(InitializationContext context) throws Exception { final ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader(); try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(userCodeClassLoader)) { - serializer = - InstantiationUtil.instantiate( - serializerClass.getName(), - Serializer.class, - getClass().getClassLoader()); + initializeSerializer(userCodeClassLoader); if (serializer instanceof Configurable) { ((Configurable) serializer).configure(config); @@ -88,4 +83,11 @@ public byte[] serialize(IN element) { checkState(serializer != null, "Call open() once before trying to serialize elements."); return serializer.serialize(topicSelector.apply(element), element); } + + @SuppressWarnings("unchecked") + protected void initializeSerializer(ClassLoader classLoader) throws Exception { + serializer = + InstantiationUtil.instantiate( + serializerClass.getName(), Serializer.class, classLoader); + } } diff --git a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapper.java b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapper.java index 8c8095b6b..4d320fcc8 100644 --- a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapper.java +++ b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapper.java @@ -55,17 +55,11 @@ class KafkaValueOnlyDeserializerWrapper implements KafkaRecordDeserialization } @Override - @SuppressWarnings("unchecked") public void open(DeserializationSchema.InitializationContext context) throws Exception { ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader(); try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(userCodeClassLoader)) { - deserializer = - (Deserializer) - InstantiationUtil.instantiate( - deserializerClass.getName(), - Deserializer.class, - getClass().getClassLoader()); + initializeDeserializer(userCodeClassLoader); if (deserializer instanceof Configurable) { ((Configurable) deserializer).configure(config); @@ -103,4 +97,11 @@ public void deserialize(ConsumerRecord record, Collector coll public TypeInformation getProducedType() { return TypeExtractor.createTypeInfo(Deserializer.class, deserializerClass, 0, null, null); } + + @SuppressWarnings("unchecked") + protected void initializeDeserializer(ClassLoader classLoader) throws Exception { + deserializer = + InstantiationUtil.instantiate( + deserializerClass.getName(), Deserializer.class, classLoader); + } } diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapperTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapperTest.java new file mode 100644 index 000000000..2f8d872a8 --- /dev/null +++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/sink/KafkaSerializerWrapperTest.java @@ -0,0 +1,74 @@ +package org.apache.flink.connector.kafka.sink; + +import org.apache.flink.api.common.serialization.SerializationSchema; +import org.apache.flink.connector.testutils.formats.DummyInitializationContext; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; +import org.apache.flink.util.FlinkUserCodeClassLoaders; +import org.apache.flink.util.SimpleUserCodeClassLoader; +import org.apache.flink.util.UserCodeClassLoader; + +import org.apache.kafka.common.serialization.StringSerializer; +import org.junit.Test; + +import java.net.URL; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link KafkaSerializerWrapper}. */ +public class KafkaSerializerWrapperTest { + @Test + public void testUserCodeClassLoaderIsUsed() throws Exception { + final KafkaSerializerWrapperCaptureForTest wrapper = + new KafkaSerializerWrapperCaptureForTest(); + final ClassLoader classLoader = + FlinkUserCodeClassLoaders.childFirst( + new URL[0], + getClass().getClassLoader(), + new String[0], + throwable -> {}, + true); + wrapper.open( + new SerializationSchema.InitializationContext() { + @Override + public MetricGroup getMetricGroup() { + return new UnregisteredMetricsGroup(); + } + + @Override + public UserCodeClassLoader getUserCodeClassLoader() { + return SimpleUserCodeClassLoader.create(classLoader); + } + }); + + assertEquals(classLoader, wrapper.getClassLoaderUsed()); + } + + @Test + public void testDefaultClassLoaderIsUsed() throws Exception { + final KafkaSerializerWrapperCaptureForTest wrapper = + new KafkaSerializerWrapperCaptureForTest(); + wrapper.open(new DummyInitializationContext()); + + assertEquals( + DummyInitializationContext.class.getClassLoader(), wrapper.getClassLoaderUsed()); + } + + static class KafkaSerializerWrapperCaptureForTest extends KafkaSerializerWrapper { + private ClassLoader classLoaderUsed; + + KafkaSerializerWrapperCaptureForTest() { + super(StringSerializer.class, true, (value) -> "topic"); + } + + public ClassLoader getClassLoaderUsed() { + return classLoaderUsed; + } + + @Override + protected void initializeSerializer(ClassLoader classLoader) throws Exception { + classLoaderUsed = classLoader; + super.initializeSerializer(classLoader); + } + } +} diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapperTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapperTest.java new file mode 100644 index 000000000..312bfbcdf --- /dev/null +++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/reader/deserializer/KafkaValueOnlyDeserializerWrapperTest.java @@ -0,0 +1,76 @@ +package org.apache.flink.connector.kafka.source.reader.deserializer; + +import org.apache.flink.api.common.serialization.DeserializationSchema; +import org.apache.flink.connector.testutils.formats.DummyInitializationContext; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; +import org.apache.flink.util.FlinkUserCodeClassLoaders; +import org.apache.flink.util.SimpleUserCodeClassLoader; +import org.apache.flink.util.UserCodeClassLoader; + +import org.apache.kafka.common.serialization.StringDeserializer; +import org.junit.Test; + +import java.net.URL; +import java.util.HashMap; + +import static org.junit.Assert.assertEquals; + +/** Tests for {@link KafkaValueOnlyDeserializerWrapper}. */ +public class KafkaValueOnlyDeserializerWrapperTest { + @Test + public void testUserCodeClassLoaderIsUsed() throws Exception { + final KafkaValueOnlyDeserializerWrapperCaptureForTest wrapper = + new KafkaValueOnlyDeserializerWrapperCaptureForTest(); + final ClassLoader classLoader = + FlinkUserCodeClassLoaders.childFirst( + new URL[0], + getClass().getClassLoader(), + new String[0], + throwable -> {}, + true); + wrapper.open( + new DeserializationSchema.InitializationContext() { + @Override + public MetricGroup getMetricGroup() { + return new UnregisteredMetricsGroup(); + } + + @Override + public UserCodeClassLoader getUserCodeClassLoader() { + return SimpleUserCodeClassLoader.create(classLoader); + } + }); + + assertEquals(classLoader, wrapper.getClassLoaderUsed()); + } + + @Test + public void testDefaultClassLoaderIsUsed() throws Exception { + final KafkaValueOnlyDeserializerWrapperCaptureForTest wrapper = + new KafkaValueOnlyDeserializerWrapperCaptureForTest(); + wrapper.open(new DummyInitializationContext()); + + assertEquals( + DummyInitializationContext.class.getClassLoader(), wrapper.getClassLoaderUsed()); + } + + static class KafkaValueOnlyDeserializerWrapperCaptureForTest + extends KafkaValueOnlyDeserializerWrapper { + private ClassLoader classLoaderUsed; + + KafkaValueOnlyDeserializerWrapperCaptureForTest() { + super(StringDeserializer.class, new HashMap<>()); + } + + public ClassLoader getClassLoaderUsed() { + return classLoaderUsed; + } + + @Override + protected void initializeDeserializer(ClassLoader classLoader) throws Exception { + classLoaderUsed = classLoader; + super.initializeDeserializer(classLoader); + } + } +}