diff --git a/pulsar-spark/src/main/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiver.java b/pulsar-spark/src/main/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiver.java index f9e7d6c..b3736b3 100644 --- a/pulsar-spark/src/main/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiver.java +++ b/pulsar-spark/src/main/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiver.java @@ -22,12 +22,15 @@ import static com.google.common.base.Preconditions.checkNotNull; import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; import org.apache.pulsar.client.api.Authentication; import org.apache.pulsar.client.api.Consumer; import org.apache.pulsar.client.api.MessageListener; import org.apache.pulsar.client.api.PulsarClient; import org.apache.pulsar.client.api.PulsarClientException; +import org.apache.pulsar.client.api.ClientBuilder; import org.apache.pulsar.client.impl.PulsarClientImpl; import org.apache.pulsar.client.impl.conf.ConsumerConfigurationData; import org.apache.spark.storage.StorageLevel; @@ -43,34 +46,52 @@ public class SparkStreamingPulsarReceiver extends Receiver { private static final Logger LOG = LoggerFactory.getLogger(SparkStreamingPulsarReceiver.class); private String serviceUrl; - private ConsumerConfigurationData conf; + private Map clientConfig; + private ConsumerConfigurationData consumerConfig; private Authentication authentication; private PulsarClient pulsarClient; private Consumer consumer; public SparkStreamingPulsarReceiver( String serviceUrl, - ConsumerConfigurationData conf, + ConsumerConfigurationData consumerConfig, Authentication authentication) { - this(StorageLevel.MEMORY_AND_DISK_2(), serviceUrl, conf, authentication); + this(StorageLevel.MEMORY_AND_DISK_2(), serviceUrl, new HashMap<>(), consumerConfig, authentication); + } + + public SparkStreamingPulsarReceiver( + String serviceUrl, + Map clientConfig, + ConsumerConfigurationData consumerConfig, + Authentication authentication) { + this(StorageLevel.MEMORY_AND_DISK_2(), serviceUrl, clientConfig, consumerConfig, authentication); + } + + public SparkStreamingPulsarReceiver(StorageLevel storageLevel, + String serviceUrl, + ConsumerConfigurationData consumerConf, + Authentication authentication) { + this(storageLevel, serviceUrl, new HashMap<>(), consumerConf, authentication); } public SparkStreamingPulsarReceiver(StorageLevel storageLevel, String serviceUrl, - ConsumerConfigurationData conf, + Map clientConfig, + ConsumerConfigurationData consumerConfig, Authentication authentication) { super(storageLevel); checkNotNull(serviceUrl, "serviceUrl must not be null"); - checkNotNull(conf, "ConsumerConfigurationData must not be null"); - checkArgument(conf.getTopicNames().size() > 0, "TopicNames must be set a value."); - checkNotNull(conf.getSubscriptionName(), "SubscriptionName must not be null"); + checkNotNull(consumerConfig, "ConsumerConfigurationData must not be null"); + checkNotNull(clientConfig, "Client configuration map must not be null"); + checkArgument(consumerConfig.getTopicNames().size() > 0, "TopicNames must be set a value."); + checkNotNull(consumerConfig.getSubscriptionName(), "SubscriptionName must not be null"); this.serviceUrl = serviceUrl; this.authentication = authentication; - if (conf.getMessageListener() == null) { - conf.setMessageListener((MessageListener & Serializable) (consumer, msg) -> { + if (consumerConfig.getMessageListener() == null) { + consumerConfig.setMessageListener((MessageListener & Serializable) (consumer, msg) -> { try { store(msg.getData()); consumer.acknowledgeAsync(msg); @@ -80,13 +101,18 @@ public SparkStreamingPulsarReceiver(StorageLevel storageLevel, } }); } - this.conf = conf; + this.clientConfig = clientConfig; + this.consumerConfig = consumerConfig; } public void onStart() { try { - pulsarClient = PulsarClient.builder().serviceUrl(serviceUrl).authentication(authentication).build(); - consumer = ((PulsarClientImpl) pulsarClient).subscribeAsync(conf).join(); + ClientBuilder builder = PulsarClient.builder().serviceUrl(serviceUrl).authentication(authentication); + if (!clientConfig.isEmpty()) { + builder.loadConf(clientConfig); + } + pulsarClient = builder.build(); + consumer = ((PulsarClientImpl) pulsarClient).subscribeAsync(consumerConfig).join(); } catch (Exception e) { LOG.error("Failed to start subscription : {}", e.getMessage()); restart("Restart a consumer"); diff --git a/tests/pulsar-spark-test/src/test/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiverTest.java b/tests/pulsar-spark-test/src/test/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiverTest.java index 443514c..8f0a034 100644 --- a/tests/pulsar-spark-test/src/test/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiverTest.java +++ b/tests/pulsar-spark-test/src/test/java/org/apache/pulsar/spark/SparkStreamingPulsarReceiverTest.java @@ -147,11 +147,51 @@ public void testSharedSubscription(Supplier serviceUrl) throws Exception assertEquals(receveidCounts.size(), 2); } + @Test(expectedExceptions = NullPointerException.class, + expectedExceptionsMessageRegExp = "ConsumerConfigurationData must not be null", + dataProvider = "ServiceUrls") + public void testReceiverWhenConsumerConfigurationIsNull(Supplier serviceUrl) { + new SparkStreamingPulsarReceiver( + serviceUrl.get(), + null, + new AuthenticationDisabled()); + } + + @Test(dataProvider = "ServiceUrls") + public void testOverrideServiceUrlWithClientConfiguration(Supplier serviceUrl) { + Map testClientConfig = new HashMap<>(); + testClientConfig.put("serviceUrl",serviceUrl.get()); + + ConsumerConfigurationData testConsumerConfig = new ConsumerConfigurationData<>(); + Set set = new HashSet<>(); + set.add(TOPIC); + testConsumerConfig.setTopicNames(set); + testConsumerConfig.setSubscriptionName(SUBS); + testConsumerConfig.setSubscriptionType(SubscriptionType.Shared); + testConsumerConfig.setReceiverQueueSize(1); + + String deliberatelyWrongServiceUrl = "http://invalid.service.url:1234"; + + SparkStreamingPulsarReceiver testReceiver = new SparkStreamingPulsarReceiver( + deliberatelyWrongServiceUrl, + testClientConfig, + testConsumerConfig, + new AuthenticationDisabled()); + + testReceiver.onStart(); + waitForTransmission(); + testReceiver.onStop(); + } + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "ConsumerConfigurationData must not be null", dataProvider = "ServiceUrls") public void testReceiverWhenClientConfigurationIsNull(Supplier serviceUrl) { - new SparkStreamingPulsarReceiver(serviceUrl.get(), null, new AuthenticationDisabled()); + new SparkStreamingPulsarReceiver( + serviceUrl.get(), + null, + null, + new AuthenticationDisabled()); } private static void waitForTransmission() {