diff --git a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
index 9629e6534..2a29f2c30 100644
--- a/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
+++ b/spark-cassandra-connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
@@ -505,6 +505,12 @@ class RDDSpec extends SparkCassandraITFlatSpecBase {
     results should have length keys.count(_ >= 5)
   }
 
+  it should " support functional where clauses" in {
+    val someCass = sc.parallelize(keys).map(x => new KVRow(x)).joinWithCassandraTable(ks, tableName).where("group = ?", (k : KVRow) => Seq(k.key * 100))
+    val results = someCass.collect.map(_._2)
+    results should have length keys.size
+  }
+
   it should " throw an exception if using a where on a column that is specified by the join" in {
     val exc = intercept[IllegalArgumentException] {
       val someCass = sc.parallelize(keys).map(x => (x, x * 100L))
diff --git a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDJavaFunctions.java b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDJavaFunctions.java
index 3fe5472d4..5d3478c5f 100644
--- a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDJavaFunctions.java
+++ b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDJavaFunctions.java
@@ -103,6 +103,7 @@ public CassandraJavaPairRDD joinWithCassandraTable(
         Option clusteringOrder = Option.empty();
         Option limit = Option.empty();
         CqlWhereClause whereClause = CqlWhereClause.empty();
+        FCqlWhereClause fwhereClause = FCqlWhereClause.empty();
         ReadConf readConf = ReadConf.fromSparkConf(rdd.conf());
 
         CassandraJoinRDD joinRDD = new CassandraJoinRDD<>(
@@ -113,6 +114,7 @@ public CassandraJavaPairRDD joinWithCassandraTable(
                 selectedColumns,
                 joinColumns,
                 whereClause,
+                fwhereClause,
                 limit,
                 clusteringOrder,
                 readConf,
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/AbstractCassandraJoin.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/AbstractCassandraJoin.scala
index d83ea06b9..09036de5b 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/AbstractCassandraJoin.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/AbstractCassandraJoin.scala
@@ -23,6 +23,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
 
   val left: RDD[L]
   val joinColumns: ColumnSelector
+  val fwhere : FCqlWhereClause[L]
   val manualRowWriter: Option[RowWriter[L]]
 
   implicit val rowWriterFactory: RowWriterFactory[L]
@@ -99,7 +100,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
   //We need to make sure we get selectedColumnRefs before serialization so that our RowReader is
   //built
   lazy val singleKeyCqlQuery: (String) = {
-    val whereClauses = where.predicates.flatMap(CqlWhereParser.parse)
+    val whereClauses = where.predicates.flatMap(CqlWhereParser.parse) ++ fwhere.predicates.flatMap(CqlWhereParser.parse)
     val joinColumns = joinColumnNames.map(_.columnName)
     val joinColumnPredicates = whereClauses.collect {
       case EqPredicate(c, _) if joinColumns.contains(c) => c
@@ -121,7 +122,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
     val joinWhere = joinColumnNames.map(_.columnName).map(name => s"${quote(name)} = :$name")
     val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
     val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
-    val filter = (where.predicates ++ joinWhere).mkString(" AND ")
+    val filter = (where.predicates ++ fwhere.predicates ++ joinWhere).mkString(" AND ")
     val quotedKeyspaceName = quote(keyspaceName)
     val quotedTableName = quote(tableName)
     val query =
@@ -135,7 +136,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
   private def boundStatementBuilder(session: Session): BoundStatementBuilder[L] = {
     val protocolVersion = session.getCluster.getConfiguration.getProtocolOptions.getProtocolVersion
     val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(consistencyLevel)
-    new BoundStatementBuilder[L](rowWriter, stmt, where.values, protocolVersion = protocolVersion)
+    new BoundStatementBuilder[L](rowWriter, stmt, where.values, fwhere, protocolVersion = protocolVersion)
   }
 
   /**
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
index d5fa91f5a..7589ba176 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
@@ -10,6 +10,16 @@ import org.apache.spark.rdd.RDD
 
 import scala.reflect.ClassTag
 
+
+case class FCqlWhereClause[L](predicates: Seq[String], values: L => Seq[Any]) {
+  def apply(v1: L): CqlWhereClause = CqlWhereClause(predicates,values(v1))
+  def and(other: FCqlWhereClause[L]) = FCqlWhereClause(predicates ++ other.predicates, (l: L) => values(l) ++ other.values(l))
+}
+object FCqlWhereClause{
+  def empty[L] : FCqlWhereClause[L] = FCqlWhereClause[L](Nil,(l: L) => Nil)
+}
+
+
 /**
  * An [[org.apache.spark.rdd.RDD RDD]] that will do a selecting join between `left` RDD and the specified
  * Cassandra Table This will perform individual selects to retrieve the rows from Cassandra and will take
@@ -27,6 +37,7 @@ class CassandraJoinRDD[L, R] private[connector](
     val columnNames: ColumnSelector = AllColumns,
     val joinColumns: ColumnSelector = PartitionKeyColumns,
     val where: CqlWhereClause = CqlWhereClause.empty,
+    val fwhere : FCqlWhereClause[L] = FCqlWhereClause.empty[L],
    val limit: Option[Long] = None,
     val clusteringOrder: Option[ClusteringOrder] = None,
     val readConf: ReadConf = ReadConf(),
@@ -50,7 +61,7 @@ class CassandraJoinRDD[L, R] private[connector](
     case None => rowReaderFactory.rowReader(tableDef, columnNames.selectFrom(tableDef))
   }
 
-  override protected def copy(
+  protected def copy(
     columnNames: ColumnSelector = columnNames,
     where: CqlWhereClause = where,
     limit: Option[Long] = limit,
@@ -67,12 +78,36 @@ class CassandraJoinRDD[L, R] private[connector](
       columnNames = columnNames,
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
+      limit = limit,
+      clusteringOrder = clusteringOrder,
+      readConf = readConf
+    )
+  }
+
+  // I was not able to do a proper copy because of the inheritance.
+  def setFWhere(
+    fwhere : FCqlWhereClause[L]
+  ): Self = {
+
+    new CassandraJoinRDD[L, R](
+      left = left,
+      keyspaceName = keyspaceName,
+      tableName = tableName,
+      connector = connector,
+      columnNames = columnNames,
+      joinColumns = joinColumns,
+      where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
     )
   }
 
+  def where(f : FCqlWhereClause[L]) : Self = setFWhere(fwhere = fwhere and f)
+  def where(clause : String, f : L => Seq[Any]) : Self = where(FCqlWhereClause(Seq(clause),f))
+
   override def cassandraCount(): Long = {
     columnNames match {
       case SomeColumns(_) =>
@@ -89,6 +124,7 @@ class CassandraJoinRDD[L, R] private[connector](
       columnNames = SomeColumns(RowCountRef),
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
@@ -106,13 +142,14 @@ class CassandraJoinRDD[L, R] private[connector](
       columnNames = columnNames,
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
     )
   }
 
-  private[rdd] def fetchIterator(
+  override private[rdd] def fetchIterator(
     session: Session,
     bsb: BoundStatementBuilder[L],
     leftIterator: Iterator[L]
@@ -121,7 +158,6 @@ class CassandraJoinRDD[L, R] private[connector](
     val rateLimiter = new RateLimiter(
       readConf.throughputJoinQueryPerSec, readConf.throughputJoinQueryPerSec
     )
-
     def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
       val resultFuture = SettableFuture.create[Iterator[(L, R)]]
       val leftSide = Iterator.continually(left)
@@ -141,6 +177,7 @@ class CassandraJoinRDD[L, R] private[connector](
       resultFuture
     }
     val queryFutures = leftIterator.map(left => {
+      rateLimiter.maybeSleep(1)
       pairWithRight(left)
     }).toList
 
@@ -162,6 +199,7 @@ class CassandraJoinRDD[L, R] private[connector](
       columnNames,
       joinColumns,
       where,
+      fwhere,
       limit,
       clusteringOrder,
       readConf,
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
index 63db6b91a..2887ff5e8 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
@@ -27,6 +27,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
     val columnNames: ColumnSelector = AllColumns,
     val joinColumns: ColumnSelector = PartitionKeyColumns,
     val where: CqlWhereClause = CqlWhereClause.empty,
+    val fwhere : FCqlWhereClause[L] = FCqlWhereClause.empty[L],
     val limit: Option[Long] = None,
     val clusteringOrder: Option[ClusteringOrder] = None,
     val readConf: ReadConf = ReadConf(),
@@ -67,12 +68,36 @@ class CassandraLeftJoinRDD[L, R] private[connector](
       columnNames = columnNames,
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
     )
   }
 
+  // I was not able to do a proper copy because of the inheritance.
+  def setFWhere(
+    fwhere : FCqlWhereClause[L]
+  ): Self = {
+
+    new CassandraLeftJoinRDD[L, R](
+      left = left,
+      keyspaceName = keyspaceName,
+      tableName = tableName,
+      connector = connector,
+      columnNames = columnNames,
+      joinColumns = joinColumns,
+      where = where,
+      fwhere = fwhere,
+      limit = limit,
+      clusteringOrder = clusteringOrder,
+      readConf = readConf
+    )
+  }
+
+  def where(f : FCqlWhereClause[L]) : Self = setFWhere(fwhere = fwhere and f)
+  def where(clause : String, f : L => Seq[Any]) : Self = where(FCqlWhereClause(Seq(clause),f))
+
   override def cassandraCount(): Long = {
     columnNames match {
       case SomeColumns(_) =>
@@ -89,6 +114,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
       columnNames = SomeColumns(RowCountRef),
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
@@ -106,6 +132,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
       columnNames = columnNames,
       joinColumns = joinColumns,
       where = where,
+      fwhere = fwhere,
       limit = limit,
       clusteringOrder = clusteringOrder,
       readConf = readConf
@@ -127,6 +154,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
       columnNames,
       joinColumns,
       where,
+      fwhere,
       limit,
       clusteringOrder,
       readConf,
@@ -144,7 +172,6 @@ class CassandraLeftJoinRDD[L, R] private[connector](
     val rateLimiter = new RateLimiter(
       readConf.throughputJoinQueryPerSec, readConf.throughputJoinQueryPerSec
     )
-
     def pairWithRight(left: L): SettableFuture[Iterator[(L, Option[R])]] = {
       val resultFuture = SettableFuture.create[Iterator[(L, Option[R])]]
       val leftSide = Iterator.continually(left)
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BoundStatementBuilder.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BoundStatementBuilder.scala
index 77ad803b2..0d171a8cb 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BoundStatementBuilder.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BoundStatementBuilder.scala
@@ -1,6 +1,7 @@
 package com.datastax.spark.connector.writer
 
 import com.datastax.driver.core._
+import com.datastax.spark.connector.rdd.FCqlWhereClause
 import com.datastax.spark.connector.types.{ColumnType, Unset}
 import com.datastax.spark.connector.util.{CodecRegistryUtil, Logging}
 
@@ -13,6 +14,7 @@ private[connector] class BoundStatementBuilder[T](
     val rowWriter: RowWriter[T],
     val preparedStmt: PreparedStatement,
     val prefixVals: Seq[Any] = Seq.empty,
+    val dependentValues : FCqlWhereClause[T] = FCqlWhereClause.empty[T],
     val ignoreNulls: Boolean = false,
     val protocolVersion: ProtocolVersion) extends Logging {
 
@@ -91,11 +93,21 @@ private[connector] class BoundStatementBuilder[T](
     prefixConverter = ColumnType.converterToCassandra(prefixType)
   } yield prefixConverter.convert(prefixVal)
 
+  private def variablesConverted(row : T): Seq[AnyRef] = {
+    val values = dependentValues.values(row)
+    for {
+      index <- 0 until values.length
+      value = values(index)
+      valueType = preparedStmt.getVariables.getType(prefixVals.length + index)
+      valueConverter = ColumnType.converterToCassandra(valueType)
+    } yield valueConverter.convert(value)
+  }
+
   /** Creates `BoundStatement` from the given data item */
   def bind(row: T): RichBoundStatement = {
     val boundStatement = new RichBoundStatement(preparedStmt)
-    boundStatement.bind(prefixConverted: _*)
-
+    val variables = prefixConverted ++ variablesConverted(row)
+    boundStatement.bind(variables: _*)
     rowWriter.readColumnValues(row, buffer)
     var bytesCount = 0
     for (i <- 0 until columnNames.size) {
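
Usage sketch (not part of the diff): the functional where clause introduced above lets each left-side row bind its own values into an extra predicate at join time. This mirrors the new RDDSpec test; `sc`, `keys`, `ks`, `tableName` and `KVRow` are the names used in that test, not new API, and the surrounding Spark setup is assumed.

  // Join each KVRow against the Cassandra table, adding a per-row "group = ?" predicate
  // whose bound value is computed from the row itself (key * 100).
  val joined = sc.parallelize(keys)
    .map(x => new KVRow(x))
    .joinWithCassandraTable(ks, tableName)
    .where("group = ?", (k: KVRow) => Seq(k.key * 100))
  val rightRows = joined.collect.map(_._2)

Under the hood, where(clause, f) wraps the clause in FCqlWhereClause(Seq(clause), f) and combines it with any existing functional predicates via setFWhere, so the per-row values are converted and bound after the static where.values prefix in BoundStatementBuilder.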