Commit 3142d23

SPARKC-619 restore PrefetchingResultSetIterator
The iterator tries to prefetch the next result page. Once the current page is exhausted, the next page (hopefully already materialized at that point) becomes the current page, and a request for the following page is sent. The fetchSize argument was removed; the supplied statement should have its page size set by the caller.
1 parent 443f7e4 commit 3142d23
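
For orientation, a minimal usage sketch (not part of the commit; the session setup and table name are placeholders). As the message above notes, the caller now sets the page size on the statement rather than passing it to the iterator:

import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.oss.driver.api.core.cql.SimpleStatement
import com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator
import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

import scala.concurrent.Await
import scala.concurrent.duration.Duration

val session = CqlSession.builder().build()          // placeholder: connects to localhost:9042
val statement = SimpleStatement.newInstance("SELECT * FROM ks.kv")
  .setPageSize(1000)                                // caller sets the page size on the statement

// executeAsync returns a CompletionStage[AsyncResultSet]; bridge it to a Scala Future
val rows = asScalaFuture(session.executeAsync(statement))
  .map(rs => new PrefetchingResultSetIterator(rs))

// iterating drains each page while the next one is fetched in the background
Await.result(rows, Duration.Inf).foreach(println)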

File tree

11 files changed: +141 -49 lines


connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala

Lines changed: 4 additions & 0 deletions
@@ -183,6 +183,10 @@ trait SparkCassandraITSpecBase
 
   implicit val ec = SparkCassandraITSpecBase.ec
 
+  def await[T](unit: Future[T]): T = {
+    Await.result(unit, Duration.Inf)
+  }
+
   def awaitAll[T](units: Future[T]*): Seq[T] = {
     Await.result(Future.sequence(units), Duration.Inf)
   }
connector/src/it/scala/com/datastax/spark/connector/rdd/reader/PrefetchingResultSetIteratorSpec.scala

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+package com.datastax.spark.connector.rdd.reader
+
+import com.codahale.metrics.Timer
+import com.datastax.oss.driver.api.core.cql.SimpleStatement.newInstance
+import com.datastax.spark.connector.SparkCassandraITFlatSpecBase
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.scalatest.concurrent.Eventually.{eventually, timeout}
+import org.scalatest.time.{Seconds, Span}
+
+class PrefetchingResultSetIteratorSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {
+
+  private val table = "prefetching"
+  private val emptyTable = "empty_prefetching"
+  override lazy val conn = CassandraConnector(sparkConf)
+
+  override def beforeClass {
+    conn.withSessionDo { session =>
+      session.execute(
+        s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
+
+      session.execute(
+        s"CREATE TABLE IF NOT EXISTS $ks.$table (key INT, x INT, PRIMARY KEY (key))")
+
+      session.execute(
+        s"CREATE TABLE IF NOT EXISTS $ks.$emptyTable (key INT, x INT, PRIMARY KEY (key))")
+
+      awaitAll(
+        for (i <- 1 to 999) yield {
+          executor.executeAsync(newInstance(s"INSERT INTO $ks.$table (key, x) values ($i, $i)"))
+        }
+      )
+    }
+  }
+
+  "PrefetchingResultSetIterator" should "return all rows regardless of the page sizes" in {
+    val pageSizes = Seq(1, 2, 5, 111, 998, 999, 1000, 1001)
+    for (pageSize <- pageSizes) {
+      withClue(s"Prefetching iterator failed for the page size: $pageSize") {
+        val statement = newInstance(s"select * from $ks.$table").setPageSize(pageSize)
+        val result = executor.executeAsync(statement).map(new PrefetchingResultSetIterator(_))
+        await(result).toList should have size 999
+      }
+    }
+  }
+
+  it should "be empty for an empty table" in {
+    val statement = newInstance(s"select * from $ks.$emptyTable")
+    val result = executor.executeAsync(statement).map(new PrefetchingResultSetIterator(_))
+
+    await(result).hasNext should be(false)
+    intercept[NoSuchElementException] {
+      await(result).next()
+    }
+  }
+
+  it should "update the provided timer" in {
+    val statement = newInstance(s"select * from $ks.$table").setPageSize(200)
+    val timer = new Timer()
+    val result = executor.executeAsync(statement).map(rs => new PrefetchingResultSetIterator(rs, Option(timer)))
+    await(result).toList
+
+    eventually(timeout(Span(2, Seconds))) {
+      timer.getCount should be(4)
+    }
+  }
+}
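
A note on the last expectation: with 999 rows and a page size of 200, the query spans five pages (4 × 200 + 199). The first page arrives with the initial execution, so the iterator issues and times only the four subsequent fetches, hence the expected timer count of 4.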

connector/src/main/scala/com/datastax/bdp/util/ScalaJavaUtil.scala

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@ import java.util.concurrent.{Callable, CompletionStage}
 import java.util.function
 import java.util.function.{BiConsumer, Consumer, Predicate, Supplier}
 
-import scala.concurrent.{ExecutionContext, Future, Promise}
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, Promise}
 import scala.concurrent.duration.{Duration => ScalaDuration}
 import scala.language.implicitConversions
 
@@ -48,7 +48,7 @@ object ScalaJavaUtil {
   def asScalaFunction[T, R](f: java.util.function.Function[T, R]): T => R = x => f(x)
 
   def asScalaFuture[T](completionStage: CompletionStage[T])
-                      (implicit context: ExecutionContext): Future[T] = {
+                      (implicit context: ExecutionContextExecutor): Future[T] = {
     val promise = Promise[T]()
     completionStage.whenCompleteAsync(new BiConsumer[T, java.lang.Throwable] {
       override def accept(t: T, throwable: Throwable): Unit = {
@@ -58,7 +58,7 @@ object ScalaJavaUtil {
           promise.failure(throwable)
 
       }
-    })
+    }, context)
     promise.future
   }
 }
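
Widening the implicit from ExecutionContext to ExecutionContextExecutor is what enables the second change: whenCompleteAsync without an executor argument runs the callback on the stage's default executor, while the two-argument overload accepts a java.util.concurrent.Executor, which an ExecutionContextExecutor is. A minimal sketch of the bridge in use, assuming the connector's implicit pool is in scope:

import java.util.concurrent.CompletableFuture

import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

import scala.concurrent.Await
import scala.concurrent.duration.Duration

val stage = CompletableFuture.completedFuture(42)  // any CompletionStage[T] works
val future = asScalaFuture(stage)                  // promise completed on the implicit pool
assert(Await.result(future, Duration.Inf) == 42)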

connector/src/main/scala/com/datastax/spark/connector/cql/Scanner.scala

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class DefaultScanner (
     val rs = session.executeAsync(maybeExecutingAs(statement, readConf.executeAs))
     val scanResult = asScalaFuture(rs).map { rs =>
       val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, codecRegistry)
-      val prefetchingIterator = new PrefetchingResultSetIterator(rs, readConf.fetchSizeInRows)
+      val prefetchingIterator = new PrefetchingResultSetIterator(rs)
       val rateLimitingIterator = readConf.throughputMiBPS match {
         case Some(throughput) =>
           val rateLimiter = new RateLimiter((throughput * 1024 * 1024).toLong, 1024 * 1024)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraCoGroupedRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ class CassandraCoGroupedRDD[T](
     val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
       val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
       val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session.getContext.getCodecRegistry)
-      val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
+      val iterator = new PrefetchingResultSetIterator(rs)
       val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala

Lines changed: 5 additions & 3 deletions
@@ -124,7 +124,6 @@ class CassandraJoinRDD[L, R] (
     metricsUpdater: InputMetricsUpdater
   ): Iterator[(L, R)] = {
 
-
     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)
 
     def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
@@ -133,9 +132,12 @@
 
       import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
 
-      queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
+      val stmt = bsb.bind(left)
+        .update(_.setPageSize(readConf.fetchSizeInRows))
+        .executeAs(readConf.executeAs)
+      queryExecutor.executeAsync(stmt).onComplete {
         case Success(rs) =>
-          val resultSet = new PrefetchingResultSetIterator(rs, fetchSize)
+          val resultSet = new PrefetchingResultSetIterator(rs)
           val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
           /* This is a much less than ideal place to actually rate limit, we are buffering
           these futures this means we will most likely exceed our threshold*/
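
With the fetchSize constructor argument gone, the page size now travels on the statement itself: update(_.setPageSize(readConf.fetchSizeInRows)) applies it to the bound statement before execution, so the configured fetch size keeps its effect for joins.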

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala

Lines changed: 5 additions & 2 deletions
@@ -157,9 +157,12 @@ class CassandraLeftJoinRDD[L, R] (
       val resultFuture = SettableFuture.create[Iterator[(L, Option[R])]]
       val leftSide = Iterator.continually(left)
 
-      queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
+      val stmt = bsb.bind(left)
+        .update(_.setPageSize(readConf.fetchSizeInRows))
+        .executeAs(readConf.executeAs)
+      queryExecutor.executeAsync(stmt).onComplete {
         case Success(rs) =>
-          val resultSet = new PrefetchingResultSetIterator(rs, fetchSize)
+          val resultSet = new PrefetchingResultSetIterator(rs)
           val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
           /* This is a much less than ideal place to actually rate limit, we are buffering
           these futures this means we will most likely exceed our threshold*/

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraMergeJoinRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ class CassandraMergeJoinRDD[L,R](
     val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
       val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
       val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session)
-      val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
+      val iterator = new PrefetchingResultSetIterator(rs)
       val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)
connector/src/main/scala/com/datastax/spark/connector/rdd/reader/PrefetchingResultSetIterator.scala

Lines changed: 44 additions & 34 deletions

@@ -1,45 +1,55 @@
 package com.datastax.spark.connector.rdd.reader
 
+import java.util.concurrent.TimeUnit
+
 import com.codahale.metrics.Timer
+import com.datastax.bdp.util.ScalaJavaUtil
 import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, Row}
-import com.datastax.oss.driver.internal.core.cql.ResultSets
+import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, Future}
 
 /** Allows to efficiently iterate over a large, paged ResultSet,
   * asynchronously prefetching the next page.
-  *
+  *
+  * This iterator is NOT thread safe. Attempting to retrieve elements from many threads without synchronization
+  * may yield unspecified results.
+  *
   * @param resultSet result set obtained from the Java driver
-  * @param prefetchWindowSize if there are less than this rows available without blocking,
-  *                           initiates fetching the next page
-  * @param timer a Codahale timer to optionally gather the metrics of fetching time
+  * @param timer a Codahale timer to optionally gather the metrics of fetching time
   */
-class PrefetchingResultSetIterator(resultSet: AsyncResultSet, prefetchWindowSize: Int, timer: Option[Timer] = None)
-  extends Iterator[Row] {
-
-  private val iterator = ResultSets.newInstance(resultSet).iterator() //TODO
-
-  override def hasNext = iterator.hasNext
-
-  // TODO: implement async page fetching. Following implementation might call fetchMoreResults up to prefetchWindowSize
-  // times to fetch the same page. Is this behaviour still valid in the new driver?
-  // This class should take AsyncResultSet as constructor param (not ResultSet)
-
-  // private[this] def maybePrefetch(): Unit = {
-  //   if (!resultSet.isFullyFetched && resultSet.getAvailableWithoutFetching < prefetchWindowSize) {
-  //     val t0 = System.nanoTime()
-  //     val future: ListenableFuture[ResultSet] = resultSet.fetchMoreResults()
-  //     if (timer.isDefined)
-  //       Futures.addCallback(future, new FutureCallback[ResultSet] {
-  //         override def onSuccess(ignored: ResultSet): Unit = {
-  //           timer.get.update(System.nanoTime() - t0, TimeUnit.NANOSECONDS)
-  //         }
-  //
-  //         override def onFailure(ignored: Throwable): Unit = { }
-  //       })
-  //   }
-  // }
-
-  override def next() = {
-    // maybePrefetch()
-    iterator.next()
+class PrefetchingResultSetIterator(resultSet: AsyncResultSet, timer: Option[Timer] = None) extends Iterator[Row] {
+  private var currentIterator = resultSet.currentPage().iterator()
+  private var currentResultSet = resultSet
+  private var nextResultSet = fetchNextPage()
+
+  private def fetchNextPage(): Option[Future[AsyncResultSet]] = {
+    if (currentResultSet.hasMorePages) {
+      val t0 = System.nanoTime()
+      val next = ScalaJavaUtil.asScalaFuture(currentResultSet.fetchNextPage())
+      timer.foreach { t =>
+        next.foreach(_ => t.update(System.nanoTime() - t0, TimeUnit.NANOSECONDS))
+      }
+      Option(next)
+    } else
+      None
+  }
+
+  private def maybePrefetch(): Unit = {
+    if (!currentIterator.hasNext && currentResultSet.hasMorePages) {
+      currentResultSet = Await.result(nextResultSet.get, Duration.Inf)
+      currentIterator = currentResultSet.currentPage().iterator()
+      nextResultSet = fetchNextPage()
+    }
+  }
+
+  override def hasNext: Boolean =
+    currentIterator.hasNext || currentResultSet.hasMorePages
+
+  override def next(): Row = {
+    val row = currentIterator.next() // let's try to exhaust the current iterator first
+    maybePrefetch()
+    row
   }
 }
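
A note on the design: the constructor starts fetching the second page immediately, and next() drains the current page before swapping in the prefetched one and kicking off the fetch of the page after it, so the Await.result in maybePrefetch only blocks when consumption outruns the network. hasNext must also consult hasMorePages, because the current page's iterator can be exhausted while further pages are still pending.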

connector/src/main/scala/com/datastax/spark/connector/util/Threads.scala

Lines changed: 8 additions & 3 deletions
@@ -4,14 +4,19 @@ import java.util.concurrent.{Executors, LinkedBlockingQueue, ThreadPoolExecutor,
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService}
 
-object Threads {
+object Threads extends Logging {
 
-  implicit val BlockingIOExecutionContext = {
+  implicit val BlockingIOExecutionContext: ExecutionContextExecutorService = {
     val threadFactory = new ThreadFactoryBuilder()
       .setDaemon(true)
       .setNameFormat("spark-cassandra-connector-io" + "%d")
+      .setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler {
+        override def uncaughtException(t: Thread, e: Throwable): Unit = {
+          logWarning(s"Unhandled exception in thread ${t.getName}.", e)
+        }
+      })
       .build
     ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(threadFactory))
   }
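
The uncaught-exception handler is presumably added because the prefetching futures now complete on this pool; without it, an error escaping a callback on one of these daemon IO threads would disappear silently rather than being logged.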
