Skip to content

Commit 7ca460e

Browse files
authored
Merge pull request #1303 from datastax/SPARKC-635-2.5
SPARKC-635 fix per cluster settings with custom props
2 parents 1aaaff3 + 84206bc commit 7ca460e

File tree

4 files changed

+66
-18
lines changed

4 files changed

+66
-18
lines changed

connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataFrameSpec.scala

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package com.datastax.spark.connector.sql
22

33
import java.io.IOException
4-
import java.util.concurrent.CompletableFuture
5-
6-
import com.datastax.oss.driver.api.core.{CqlIdentifier, DefaultProtocolVersion}
4+
import com.datastax.oss.driver.api.core.{CqlIdentifier, CqlSession, DefaultProtocolVersion}
75
import com.datastax.oss.driver.api.core.`type`.DataTypes
86
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder
97
import com.datastax.spark.connector.cluster.DefaultCluster
108
import com.datastax.spark.connector.{SparkCassandraITFlatSpecBase, _}
11-
import com.datastax.spark.connector.cql.{CassandraConnector, ClusteringColumn}
9+
import com.datastax.spark.connector.cql.{AuthConf, AuthConfFactory, CassandraConnectionFactory, CassandraConnector, CassandraConnectorConf, ClusteringColumn, DefaultConnectionFactory, NoAuthConf}
1210
import com.datastax.spark.connector.util.DriverUtil.toName
11+
import org.apache.spark.SparkConf
1312
import org.apache.spark.sql.SaveMode
1413
import org.apache.spark.sql.cassandra._
1514
import org.apache.spark.sql.functions._
@@ -382,4 +381,39 @@ class CassandraDataFrameSpec extends SparkCassandraITFlatSpecBase with DefaultCl
382381
}
383382
}
384383

384+
it should "allow to specify custom per-cluster settings" in {
385+
sparkSession.conf.set("myCluster/spark.cassandra.connection.factory",
386+
TestConnectionFactory.getClass.getName.split("\\$").last)
387+
sparkSession.conf.set("myCluster/spark.cassandra.auth.conf.factory",
388+
TestAuthFactory.getClass.getName.split("\\$").last)
389+
sparkSession.conf.set("myCluster/spark.cassandra.test.custom.property", "specialValue")
390+
391+
sparkSession
392+
.read
393+
.format("org.apache.spark.sql.cassandra")
394+
.options(Map("table" -> "tuple_test1", "keyspace" -> ks, "cluster" -> "myCluster"))
395+
.load
396+
.count()
397+
398+
withClue("Test auth factory was not used during the test") {
399+
TestAuthFactory.used shouldBe true
400+
}
401+
}
402+
}
403+
404+
object TestConnectionFactory extends CassandraConnectionFactory {
405+
override def createSession(conf: CassandraConnectorConf): CqlSession =
406+
DefaultConnectionFactory.createSession(conf)
407+
408+
override def properties: Set[String] = Set("spark.cassandra.test.custom.property")
385409
}
410+
411+
object TestAuthFactory extends AuthConfFactory {
412+
var used: Boolean = false
413+
414+
override def authConf(conf: SparkConf): AuthConf = {
415+
used = true
416+
assert(conf.get("spark.cassandra.test.custom.property") == "specialValue")
417+
NoAuthConf
418+
}
419+
}

connector/src/main/scala/com/datastax/spark/connector/cql/CassandraConnectionFactory.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,12 @@ object CassandraConnectionFactory {
224224
deprecatedSince = "DSE 6.0.0"
225225
)
226226

227-
228227
def fromSparkConf(conf: SparkConf): CassandraConnectionFactory = {
229-
conf.getOption(FactoryParam.name)
230-
.map(ReflectionUtil.findGlobalObject[CassandraConnectionFactory])
231-
.getOrElse(FactoryParam.default)
228+
fromNameOrDefault(conf.getOption(FactoryParam.name))
232229
}
233230

231+
def fromNameOrDefault(factoryName: Option[String]): CassandraConnectionFactory = {
232+
factoryName.map(ReflectionUtil.findGlobalObject[CassandraConnectionFactory])
233+
.getOrElse(FactoryParam.default)
234+
}
234235
}

connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package org.apache.spark.sql.cassandra
22

33
import java.net.InetAddress
44
import java.util.{Locale, UUID}
5-
65
import scala.collection.mutable.ListBuffer
76
import scala.util.Try
87
import org.apache.hadoop.hive.conf.HiveConf
@@ -17,7 +16,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
1716
import org.apache.spark.sql.sources._
1817
import org.apache.spark.sql.types._
1918
import org.apache.spark.unsafe.types.UTF8String
20-
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf, ColumnDef, Schema, TableDef}
19+
import com.datastax.spark.connector.cql.{CassandraConnectionFactory, CassandraConnector, CassandraConnectorConf, ColumnDef, Schema, TableDef}
2120
import com.datastax.spark.connector.rdd.partitioner.DataSizeEstimates
2221
import com.datastax.spark.connector.rdd.partitioner.dht.TokenFactory.forSystemLocalPartitioner
2322
import com.datastax.spark.connector.rdd.{CassandraJoinRDD, CassandraRDD, CassandraTableScanRDD, ReadConf}
@@ -741,19 +740,31 @@ object CassandraSourceRelation extends Logging {
741740
tableRef: TableRef,
742741
tableConf: Map[String, String]) : SparkConf = {
743742

744-
//Default settings
743+
// Default settings
745744
val conf = sparkConf.clone()
746745
val cluster = tableRef.cluster.getOrElse(defaultClusterName)
747746
val ks = tableRef.keyspace
748-
val AllSCCConfNames = (ConfigParameter.names ++ DeprecatedConfigParameter.names)
749-
//Keyspace/Cluster level settings
750-
for (prop <- AllSCCConfNames) {
751-
val value = Seq(
747+
748+
def consolidate(prop: String): Option[String] = {
749+
Seq(
752750
tableConf.get(prop.toLowerCase(Locale.ROOT)), //tableConf is actually a caseInsensitive map so lower case keys must be used
753751
sqlConf.get(s"$cluster:$ks/$prop"),
754752
sqlConf.get(s"$cluster/$prop"),
755753
sqlConf.get(s"default/$prop"),
756754
sqlConf.get(prop)).flatten.headOption
755+
}
756+
757+
// Custom connection factories may have a set of custom supported properties
758+
val factoryName = consolidate(CassandraConnectionFactory.FactoryParam.name)
759+
val customConnectionFactoryProperties = CassandraConnectionFactory.fromNameOrDefault(factoryName).properties
760+
761+
val AllSCCConfNames = ConfigParameter.names ++
762+
DeprecatedConfigParameter.names ++
763+
customConnectionFactoryProperties
764+
765+
//Keyspace/Cluster level settings
766+
for (prop <- AllSCCConfNames) {
767+
val value = consolidate(prop)
757768
value.foreach(conf.set(prop, _))
758769
}
759770
//Set all user properties

driver/src/test/scala/com/datastax/spark/connector/cql/MultiplexingSchemaListenerTest.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ package com.datastax.spark.connector.cql
77

88
import com.datastax.oss.driver.api.core.`type`.UserDefinedType
99
import com.datastax.oss.driver.api.core.metadata.schema.{AggregateMetadata, FunctionMetadata, KeyspaceMetadata, SchemaChangeListener, TableMetadata, ViewMetadata}
10+
import org.scalatest.concurrent.Eventually
1011
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
1112
import org.scalatestplus.mockito.MockitoSugar
1213

1314
import scala.collection.mutable
1415
import scala.concurrent.Future
1516
import scala.util.Random
1617

17-
class MultiplexingSchemaListenerTest extends FlatSpec with Matchers with MockitoSugar with BeforeAndAfterEach {
18+
class MultiplexingSchemaListenerTest extends FlatSpec with Matchers with MockitoSugar with BeforeAndAfterEach with Eventually {
1819
val r = new Random()
1920

2021
import scala.concurrent.ExecutionContext.Implicits.global
@@ -194,7 +195,6 @@ class MultiplexingSchemaListenerTest extends FlatSpec with Matchers with Mockito
194195
}
195196

196197
it should "allow listeners to be added while triggering events" in {
197-
var listeners = 0
198198
for (it <- 1 to 200) {
199199
Future (listener.addListener(new IncrementingSchemaListener()))
200200
Future (triggerAllEvents(5))
@@ -203,7 +203,9 @@ class MultiplexingSchemaListenerTest extends FlatSpec with Matchers with Mockito
203203
if (it % 10 == 0)
204204
Future (listener.clearListeners())
205205
}
206-
actionsDone.values.sum should be > 0
206+
eventually {
207+
actionsDone.values.sum should be > 0
208+
}
207209
}
208210

209211
}

0 commit comments

Comments (0)