diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/CassandraConnector.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/CassandraConnector.scala
index 3472e683a..628285f2a 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/CassandraConnector.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/CassandraConnector.scala
@@ -3,13 +3,13 @@ package com.datastax.spark.connector.cql
 import java.io.IOException
 import java.net.InetAddress
 
-import scala.collection.JavaConversions._
-import scala.language.reflectiveCalls
-import org.apache.spark.{SparkConf, SparkContext}
 import com.datastax.driver.core._
 import com.datastax.spark.connector.cql.CassandraConnectorConf.CassandraSSLConf
-import com.datastax.spark.connector.util.SerialShutdownHooks
-import com.datastax.spark.connector.util.Logging
+import com.datastax.spark.connector.util.{Logging, SerialShutdownHooks}
+import org.apache.spark.{SparkConf, SparkContext}
+
+import scala.collection.JavaConversions._
+import scala.language.reflectiveCalls
 
 /** Provides and manages connections to Cassandra.
   *
@@ -78,10 +78,8 @@ class CassandraConnector(val conf: CassandraConnectorConf)
   def openSession() = {
     val session = sessionCache.acquire(_config)
     try {
-      val allNodes = session.getCluster.getMetadata.getAllHosts.toSet
-      val dcToUse = _config.localDC.getOrElse(LocalNodeFirstLoadBalancingPolicy.determineDataCenter(_config.hosts, allNodes))
-      val myNodes = allNodes.filter(_.getDatacenter == dcToUse).map(_.getAddress)
-      _config = _config.copy(hosts = myNodes)
+      val foundNodes = findNodes(session)
+      _config = _config.copy(hosts = foundNodes)
 
       val connectionsPerHost = _config.maxConnectionsPerExecutor.getOrElse(1)
       val poolingOptions = session.getCluster.getConfiguration.getPoolingOptions
@@ -102,6 +100,18 @@
     }
   }
 
+  private def findNodes(session: Session) = {
+    val allNodes: Set[Host] = session.getCluster.getMetadata.getAllHosts.toSet
+
+    session.getCluster.getConfiguration.getPolicies.getLoadBalancingPolicy match {
+      case policy: DataCenterAware => {
+        val dcToUse = _config.localDC.getOrElse(policy.determineDataCenter(_config.hosts, allNodes))
+        allNodes.filter(_.getDatacenter == dcToUse).map(_.getAddress)
+      }
+      case _ => allNodes.map(_.getAddress)
+    }
+  }
+
   /** Allows to use Cassandra `Session` in a safe way without
     * risk of forgetting to close it. The `Session` object obtained through this method
     * is a proxy to a shared, single `Session` associated with the cluster.
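The refactor above replaces the hard-wired call to LocalNodeFirstLoadBalancingPolicy.determineDataCenter with a dispatch on the configured load-balancing policy: any policy mixing in DataCenterAware gets to choose the data center, and every other policy falls back to using all hosts in the cluster. A minimal sketch of how a third-party policy component could opt in; the StaticDataCenterAware class is hypothetical and not part of this PR:

    import java.net.InetAddress

    import com.datastax.driver.core.Host
    import com.datastax.spark.connector.cql.DataCenterAware

    // Hypothetical example, not part of this PR: a mixin that always reports a
    // fixed data center. A LoadBalancingPolicy that also extends DataCenterAware
    // is detected by CassandraConnector.findNodes, which then keeps only the
    // hosts of the data center returned here.
    class StaticDataCenterAware(dcName: String) extends DataCenterAware {
      override def determineDataCenter(contactPoints: Set[InetAddress], allHosts: Set[Host]): String =
        dcName
    }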
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/DataCenterAware.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/DataCenterAware.scala
new file mode 100644
index 000000000..51259963c
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/DataCenterAware.scala
@@ -0,0 +1,9 @@
+package com.datastax.spark.connector.cql
+
+import java.net.InetAddress
+
+import com.datastax.driver.core.Host
+
+trait DataCenterAware {
+  def determineDataCenter(contactPoints: Set[InetAddress], allHosts: Set[Host]): String
+}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/LocalNodeFirstLoadBalancingPolicy.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/LocalNodeFirstLoadBalancingPolicy.scala
index b02b408ca..5b1505735 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/LocalNodeFirstLoadBalancingPolicy.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/cql/LocalNodeFirstLoadBalancingPolicy.scala
@@ -15,7 +15,7 @@ import scala.util.Random
  * For writes, if a statement has a routing key set, this LBP is token aware - it prefers the nodes which
  * are replicas of the computed token to the other nodes.
  */
 class LocalNodeFirstLoadBalancingPolicy(contactPoints: Set[InetAddress], localDC: Option[String] = None,
-                                        shuffleReplicas: Boolean = true) extends LoadBalancingPolicy with Logging {
+                                        shuffleReplicas: Boolean = true) extends LoadBalancingPolicy with Logging with DataCenterAware {
 
   import LocalNodeFirstLoadBalancingPolicy._
@@ -72,7 +72,7 @@ class LocalNodeFirstLoadBalancingPolicy(contactPoints: Set[InetAddress], localDC
 
     else tokenAwareQueryPlan(keyspace, statement)
   }
-  
+
   override def onAdd(host: Host) {
     // The added host might be a "better" version of a host already in the set.
     // The nodes added in the init call don't have DC and rack set.
@@ -86,6 +86,10 @@ class LocalNodeFirstLoadBalancingPolicy(contactPoints: Set[InetAddress], localDC
     logInfo(s"Removed host ${host.getAddress.getHostAddress} (${host.getDatacenter})")
   }
 
+  override def determineDataCenter(contactPoints: Set[InetAddress], allHosts: Set[Host]): String = {
+    LocalNodeFirstLoadBalancingPolicy.determineDataCenter(contactPoints, allHosts)
+  }
+
   override def close() = { }
   override def onUp(host: Host) = { }
   override def onDown(host: Host) = { }
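LocalNodeFirstLoadBalancingPolicy satisfies the new trait by delegating to the existing helper on its companion object. For comparison, a hedged sketch of an alternative DataCenterAware implementation that resolves the data center directly from cluster metadata; the "first host matching a contact point wins" rule is an assumption made for illustration only:

    import java.net.InetAddress

    import com.datastax.driver.core.Host
    import com.datastax.spark.connector.cql.DataCenterAware

    // Hedged sketch, not part of this PR: pick the data center of whichever
    // cluster host matches one of the configured contact points.
    trait ContactPointDataCenterAware extends DataCenterAware {
      override def determineDataCenter(contactPoints: Set[InetAddress], allHosts: Set[Host]): String =
        allHosts
          .find(host => contactPoints.contains(host.getAddress))
          .map(_.getDatacenter)
          .getOrElse(throw new IllegalArgumentException(
            s"No cluster host matches contact points $contactPoints"))
    }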