Commit 443f7e4

SPARKC-619 use async queries in row fetching

This is a base for eager page prefetching.

1 parent c2ccc68 · commit 443f7e4

File tree

9 files changed, +107 -53 lines changed


connector/src/main/scala/com/datastax/bdp/util/ScalaJavaUtil.scala

Lines changed: 18 additions & 2 deletions

@@ -7,10 +7,11 @@
 package com.datastax.bdp.util
 
 import java.time.{Duration => JavaDuration}
-import java.util.concurrent.Callable
+import java.util.concurrent.{Callable, CompletionStage}
 import java.util.function
-import java.util.function.{Consumer, Predicate, Supplier}
+import java.util.function.{BiConsumer, Consumer, Predicate, Supplier}
 
+import scala.concurrent.{ExecutionContext, Future, Promise}
 import scala.concurrent.duration.{Duration => ScalaDuration}
 import scala.language.implicitConversions
 
@@ -45,4 +46,19 @@ object ScalaJavaUtil {
   }
 
   def asScalaFunction[T, R](f: java.util.function.Function[T, R]): T => R = x => f(x)
+
+  def asScalaFuture[T](completionStage: CompletionStage[T])
+                      (implicit context: ExecutionContext): Future[T] = {
+    val promise = Promise[T]()
+    completionStage.whenCompleteAsync(new BiConsumer[T, java.lang.Throwable] {
+      override def accept(t: T, throwable: Throwable): Unit = {
+        if (throwable == null)
+          promise.success(t)
+        else
+          promise.failure(throwable)
+
+      }
+    })
+    promise.future
+  }
 }
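
A minimal sketch of the new asScalaFuture bridge in isolation. The CompletableFuture stand-in and the object wrapper are illustrative only; any CompletionStage, including the one returned by CqlSession.executeAsync, can be adapted the same way:

    import java.util.concurrent.CompletableFuture
    import scala.concurrent.Await
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration
    import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture

    object AsScalaFutureSketch extends App {
      // Any CompletionStage works; a CompletableFuture is just a convenient stand-in.
      // The SAM-converted lambda assumes Scala 2.12 or later.
      val stage = CompletableFuture.supplyAsync(() => "first page")
      // Bridge to a Scala Future and compose with map, as the async scan paths now do.
      val upper = asScalaFuture(stage).map(_.toUpperCase)
      println(Await.result(upper, Duration.Inf)) // FIRST PAGE
    }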

connector/src/main/scala/com/datastax/spark/connector/cql/Scanner.scala

Lines changed: 22 additions & 14 deletions

@@ -1,5 +1,6 @@
 package com.datastax.spark.connector.cql
 
+import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
 import com.datastax.oss.driver.api.core.CqlSession
 import com.datastax.oss.driver.api.core.cql.{Row, Statement}
 import com.datastax.spark.connector.CassandraRowMetadata
@@ -8,6 +9,9 @@ import com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator
 import com.datastax.spark.connector.util.maybeExecutingAs
 import com.datastax.spark.connector.writer.RateLimiter
 
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await}
+
 /**
   * Object which will be used in Table Scanning Operations.
   * One Scanner will be created per Spark Partition, it will be
@@ -35,21 +39,25 @@ class DefaultScanner (
   }
 
   override def scan[StatementT <: Statement[StatementT]](statement: StatementT): ScanResult = {
-    val rs = session.execute(maybeExecutingAs(statement, readConf.executeAs))
-    val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, codecRegistry)
-    val prefetchingIterator = new PrefetchingResultSetIterator(rs, readConf.fetchSizeInRows)
-    val rateLimitingIterator = readConf.throughputMiBPS match {
-      case Some(throughput) =>
-        val rateLimiter = new RateLimiter((throughput * 1024 * 1024).toLong, 1024 * 1024)
-        prefetchingIterator.map { row =>
-          rateLimiter.maybeSleep(getRowBinarySize(row))
-          row
-        }
-      case None =>
-        prefetchingIterator
-    }
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
 
-    ScanResult(rateLimitingIterator, columnMetaData)
+    val rs = session.executeAsync(maybeExecutingAs(statement, readConf.executeAs))
+    val scanResult = asScalaFuture(rs).map { rs =>
+      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, codecRegistry)
+      val prefetchingIterator = new PrefetchingResultSetIterator(rs, readConf.fetchSizeInRows)
+      val rateLimitingIterator = readConf.throughputMiBPS match {
+        case Some(throughput) =>
+          val rateLimiter = new RateLimiter((throughput * 1024 * 1024).toLong, 1024 * 1024)
+          prefetchingIterator.map { row =>
+            rateLimiter.maybeSleep(getRowBinarySize(row))
+            row
+          }
+        case None =>
+          prefetchingIterator
+      }
+      ScanResult(rateLimitingIterator, columnMetaData)
+    }
+    Await.result(scanResult, Duration.Inf)
   }
 
   override def getSession(): CqlSession = session
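
A usage sketch of the same async scan pattern outside of Spark; the contact point defaults and the table test.kv are assumptions for illustration. The flow mirrors the new scan method: executeAsync obtains the first page, asScalaFuture moves onto the connector's blocking-IO pool, and a blocking await keeps the synchronous contract at the API boundary:

    import com.datastax.oss.driver.api.core.CqlSession
    import com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator
    import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    object AsyncScanSketch extends App {
      // Assumes a reachable Cassandra node and an existing table test.kv(k int PRIMARY KEY, v text).
      val session = CqlSession.builder().build()
      try {
        // executeAsync returns only the first page; later pages are fetched by the iterator.
        val firstPage = asScalaFuture(session.executeAsync("SELECT k, v FROM test.kv"))
        val values = firstPage.map { rs =>
          new PrefetchingResultSetIterator(rs, prefetchWindowSize = 100).map(_.getString("v")).toList
        }
        println(Await.result(values, Duration.Inf))
      } finally session.close()
    }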

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraCoGroupedRDD.scala

Lines changed: 12 additions & 7 deletions

@@ -7,9 +7,11 @@ package com.datastax.spark.connector.rdd
 
 import java.io.IOException
 
+import com.datastax.bdp.util.ScalaJavaUtil._
 import com.datastax.oss.driver.api.core.CqlSession
 import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row}
 import com.datastax.spark.connector.util._
+
 import scala.collection.JavaConversions._
 import scala.language.existentials
 import scala.reflect.ClassTag
@@ -28,6 +30,9 @@ import com.datastax.spark.connector.types.ColumnType
 import com.datastax.spark.connector.util.Quote._
 import com.datastax.spark.connector.util.{CountingIterator, MultiMergeJoinIterator, NameTools}
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
 /**
  * A RDD which pulls from provided separate CassandraTableScanRDDs which share partition keys type and
  * keyspaces. These tables will be joined on READ using a merge iterator. As long as we join
@@ -158,21 +163,21 @@ class CassandraCoGroupedRDD[T](
       s"with params ${values.mkString("[", ",", "]")}")
     val stmt = createStatement(session, fromRDD.readConf, cql, values: _*)
 
-    try {
-      val rs = session.execute(stmt)
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+    val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
       val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
-      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames,rs, session)
+      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session.getContext.getCodecRegistry)
       val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
      val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)
-    } catch {
-      case t: Throwable =>
-        throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
+    }.recover {
+      case t: Throwable => throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
     }
+    Await.result(fetchResult, Duration.Inf)
   }
 
-
   @DeveloperApi
   override def compute(split: Partition, context: TaskContext): Iterator[Seq[Seq[T]]] = {
     /** Open two sessions if Cluster Configurations are different **/

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala

Lines changed: 4 additions & 4 deletions

@@ -131,9 +131,11 @@ class CassandraJoinRDD[L, R] (
     val resultFuture = SettableFuture.create[Iterator[(L, R)]]
     val leftSide = Iterator.continually(left)
 
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
     queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
       case Success(rs) =>
-        val resultSet = new PrefetchingResultSetIterator(ResultSets.newInstance(rs), fetchSize)
+        val resultSet = new PrefetchingResultSetIterator(rs, fetchSize)
         val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
         /* This is a much less than ideal place to actually rate limit, we are buffering
         these futures this means we will most likely exceed our threshold*/
@@ -142,13 +144,11 @@ class CassandraJoinRDD[L, R] (
         resultFuture.set(leftSide.zip(rightSide))
       case Failure(throwable) =>
         resultFuture.setException(throwable)
-    }(ExecutionContext.Implicits.global) // TODO: use dedicated context, use Future down the road, remove SettableFuture
+    }
 
     resultFuture
   }
 
-
-
   val queryFutures = leftIterator.map(left => {
     requestsPerSecondRateLimiter.maybeSleep(1)
     pairWithRight(left)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala

Lines changed: 4 additions & 2 deletions

@@ -149,6 +149,7 @@ class CassandraLeftJoinRDD[L, R] (
     leftIterator: Iterator[L],
     metricsUpdater: InputMetricsUpdater
   ): Iterator[(L, Option[R])] = {
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
 
     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)
 
@@ -158,7 +159,7 @@ class CassandraLeftJoinRDD[L, R] (
 
     queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
       case Success(rs) =>
-        val resultSet = new PrefetchingResultSetIterator(ResultSets.newInstance(rs), fetchSize)
+        val resultSet = new PrefetchingResultSetIterator(rs, fetchSize)
         val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
         /* This is a much less than ideal place to actually rate limit, we are buffering
         these futures this means we will most likely exceed our threshold*/
@@ -170,10 +171,11 @@ class CassandraLeftJoinRDD[L, R] (
         resultFuture.set(leftSide.zip(rightSide))
       case Failure(throwable) =>
         resultFuture.setException(throwable)
-    }(ExecutionContext.Implicits.global) // TODO: use dedicated context, use Future instead of SettableFuture
+    }
 
     resultFuture
   }
+
   val queryFutures = leftIterator.map(left => {
     requestsPerSecondRateLimiter.maybeSleep(1)
     pairWithRight(left)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraMergeJoinRDD.scala

Lines changed: 13 additions & 9 deletions

@@ -7,26 +7,30 @@ package com.datastax.spark.connector.rdd
 
 import java.io.IOException
 
+import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
+
 import scala.collection.JavaConversions._
 import scala.language.existentials
 import scala.reflect.ClassTag
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.metrics.InputMetricsUpdater
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{Partition, SparkContext, TaskContext}
-import com.datastax.driver.core._
 import com.datastax.oss.driver.api.core.CqlSession
-import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row, Statement}
+import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row}
 import com.datastax.oss.driver.api.core.metadata.Metadata
 import com.datastax.oss.driver.api.core.metadata.token.Token
 import com.datastax.spark.connector.CassandraRowMetadata
-import com.datastax.spark.connector.cql.{CassandraConnector, ColumnDef, Schema}
+import com.datastax.spark.connector.cql.{CassandraConnector, ColumnDef}
 import com.datastax.spark.connector.rdd.partitioner.{CassandraPartition, CqlTokenRange, NodeAddresses}
 import com.datastax.spark.connector.rdd.reader.{PrefetchingResultSetIterator, RowReader}
 import com.datastax.spark.connector.types.ColumnType
 import com.datastax.spark.connector.util.Quote._
 import com.datastax.spark.connector.util.{CountingIterator, MergeJoinIterator, NameTools, schemaFromCassandra}
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
 /**
  * A RDD which pulls from two separate CassandraTableScanRDDs which share partition keys and
 * keyspaces. These tables will be joined on READ using a merge iterator. As long as we join
@@ -151,21 +155,21 @@ class CassandraMergeJoinRDD[L,R](
       s"with params ${values.mkString("[", ",", "]")}")
     val stmt = createStatement(session, fromRDD.readConf, cql, values: _*)
 
-    try {
-      val rs = session.execute(stmt)
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+    val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
      val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
       val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session)
       val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
       val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)
-    } catch {
-      case t: Throwable =>
-        throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
+    }.recover {
+      case t: Throwable => throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
     }
+    Await.result(fetchResult, Duration.Inf)
   }
 
-
   @DeveloperApi
   override def compute(split: Partition, context: TaskContext): Iterator[(Seq[L], Seq[R])] = {

connector/src/main/scala/com/datastax/spark/connector/rdd/reader/PrefetchingResultSetIterator.scala

Lines changed: 4 additions & 7 deletions

@@ -1,11 +1,8 @@
 package com.datastax.spark.connector.rdd.reader
 
-import java.util.concurrent.TimeUnit
-
 import com.codahale.metrics.Timer
-import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, ResultSet, Row}
-import com.datastax.oss.driver.internal.core.cql.MultiPageResultSet
-import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture}
+import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, Row}
+import com.datastax.oss.driver.internal.core.cql.ResultSets
 
 /** Allows to efficiently iterate over a large, paged ResultSet,
   * asynchronously prefetching the next page.
@@ -15,10 +12,10 @@ import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFut
   *                           initiates fetching the next page
   * @param timer a Codahale timer to optionally gather the metrics of fetching time
   */
-class PrefetchingResultSetIterator(resultSet: ResultSet, prefetchWindowSize: Int, timer: Option[Timer] = None)
+class PrefetchingResultSetIterator(resultSet: AsyncResultSet, prefetchWindowSize: Int, timer: Option[Timer] = None)
   extends Iterator[Row] {
 
-  private[this] val iterator = resultSet.iterator()
+  private val iterator = ResultSets.newInstance(resultSet).iterator() //TODO
 
   override def hasNext = iterator.hasNext
 

connector/src/main/scala/com/datastax/spark/connector/util/Threads.scala

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+package com.datastax.spark.connector.util
+
+import java.util.concurrent.{Executors, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
+import scala.concurrent.ExecutionContext
+
+object Threads {
+
+  implicit val BlockingIOExecutionContext = {
+    val threadFactory = new ThreadFactoryBuilder()
+      .setDaemon(true)
+      .setNameFormat("spark-cassandra-connector-io" + "%d")
+      .build
+    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(threadFactory))
+  }
+}
+
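
A small sketch of how the new implicit execution context is meant to be picked up; the Future body is a placeholder, whereas in the connector it is the callback chained onto executeAsync results:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration.Duration
    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

    object BlockingIoSketch extends App {
      // The implicit daemon, cached-thread pool from Threads backs this Future,
      // keeping potentially blocking driver work off the default ForkJoin pool.
      val pages: Future[Seq[Int]] = Future {
        Seq(1, 2, 3) // a blocking fetch would go here
      }
      println(Await.result(pages, Duration.Inf).sum)
    }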

driver/src/main/scala/com/datastax/spark/connector/CassandraRow.scala

Lines changed: 11 additions & 8 deletions

@@ -3,7 +3,7 @@ package com.datastax.spark.connector
 import com.datastax.oss.driver.api.core.CqlSession
 import com.datastax.oss.driver.api.core.`type`.codec.TypeCodec
 import com.datastax.oss.driver.api.core.`type`.codec.registry.CodecRegistry
-import com.datastax.oss.driver.api.core.cql.{ColumnDefinitions, PreparedStatement, ResultSet, Row}
+import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, ColumnDefinitions, PreparedStatement, ResultSet, Row}
 import com.datastax.spark.connector.util.DriverUtil.toName
 
 /** Represents a single row fetched from Cassandra.
@@ -129,22 +129,25 @@ case class CassandraRowMetadata(columnNames: IndexedSeq[String],
 
 object CassandraRowMetadata {
 
+  def fromResultSet(columnNames: IndexedSeq[String], rs: AsyncResultSet, session: CqlSession): CassandraRowMetadata = {
+    fromResultSet(columnNames: IndexedSeq[String], rs, session.getContext.getCodecRegistry)
+  }
 
-  def fromResultSet(columnNames: IndexedSeq[String], rs: ResultSet, session: CqlSession) :CassandraRowMetadata = {
-    fromResultSet(columnNames: IndexedSeq[String], rs: ResultSet, session.getContext.getCodecRegistry)
+  def fromResultSet(columnNames: IndexedSeq[String], rs: AsyncResultSet, registry: CodecRegistry): CassandraRowMetadata = {
+    fromColumnDefs(columnNames, rs.getColumnDefinitions, registry)
   }
 
-  def fromResultSet(columnNames: IndexedSeq[String], rs: ResultSet, registry: CodecRegistry) :CassandraRowMetadata = {
+  def fromResultSet(columnNames: IndexedSeq[String], rs: ResultSet, registry: CodecRegistry): CassandraRowMetadata = {
     fromColumnDefs(columnNames, rs.getColumnDefinitions, registry)
   }
 
-  def fromPreparedStatement(columnNames: IndexedSeq[String], statement: PreparedStatement, registry: CodecRegistry) :CassandraRowMetadata = {
+  def fromPreparedStatement(columnNames: IndexedSeq[String], statement: PreparedStatement, registry: CodecRegistry): CassandraRowMetadata = {
     fromColumnDefs(columnNames, statement.getResultSetDefinitions, registry)
   }
 
-  private def fromColumnDefs(columnNames: IndexedSeq[String], columnDefs: ColumnDefinitions, registry: CodecRegistry) = {
-    import scala.collection.JavaConversions._
-    val scalaColumnDefs = columnDefs.toList
+  private def fromColumnDefs(columnNames: IndexedSeq[String], columnDefs: ColumnDefinitions, registry: CodecRegistry): CassandraRowMetadata = {
+    import scala.collection.JavaConverters._
+    val scalaColumnDefs = columnDefs.asScala.toList
     val rsColumnNames = scalaColumnDefs.map(c => toName(c.getName))
     val codecs = scalaColumnDefs.map(col => registry.codecFor(col.getType))
       .asInstanceOf[List[TypeCodec[AnyRef]]]