Commit ae825f6

Merge pull request #1272 from datastax/SPARKC-619-2.5
SPARKC-619 restore PrefetchingResultSetIterator
2 parents: c2ccc68 + 3142d23

12 files changed: +234 -88 lines

connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala

Lines changed: 4 additions & 0 deletions
@@ -183,6 +183,10 @@ trait SparkCassandraITSpecBase

   implicit val ec = SparkCassandraITSpecBase.ec

+  def await[T](unit: Future[T]): T = {
+    Await.result(unit, Duration.Inf)
+  }
+
   def awaitAll[T](units: Future[T]*): Seq[T] = {
     Await.result(Future.sequence(units), Duration.Inf)
   }
connector/src/it/scala/com/datastax/spark/connector/rdd/reader/PrefetchingResultSetIteratorSpec.scala

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+package com.datastax.spark.connector.rdd.reader
+
+import com.codahale.metrics.Timer
+import com.datastax.oss.driver.api.core.cql.SimpleStatement.newInstance
+import com.datastax.spark.connector.SparkCassandraITFlatSpecBase
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.scalatest.concurrent.Eventually.{eventually, timeout}
+import org.scalatest.time.{Seconds, Span}
+
+class PrefetchingResultSetIteratorSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {
+
+  private val table = "prefetching"
+  private val emptyTable = "empty_prefetching"
+  override lazy val conn = CassandraConnector(sparkConf)
+
+  override def beforeClass {
+    conn.withSessionDo { session =>
+      session.execute(
+        s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
+
+      session.execute(
+        s"CREATE TABLE IF NOT EXISTS $ks.$table (key INT, x INT, PRIMARY KEY (key))")
+
+      session.execute(
+        s"CREATE TABLE IF NOT EXISTS $ks.$emptyTable (key INT, x INT, PRIMARY KEY (key))")
+
+      awaitAll(
+        for (i <- 1 to 999) yield {
+          executor.executeAsync(newInstance(s"INSERT INTO $ks.$table (key, x) values ($i, $i)"))
+        }
+      )
+    }
+  }
+
+  "PrefetchingResultSetIterator" should "return all rows regardless of the page sizes" in {
+    val pageSizes = Seq(1, 2, 5, 111, 998, 999, 1000, 1001)
+    for (pageSize <- pageSizes) {
+      withClue(s"Prefetching iterator failed for the page size: $pageSize") {
+        val statement = newInstance(s"select * from $ks.$table").setPageSize(pageSize)
+        val result = executor.executeAsync(statement).map(new PrefetchingResultSetIterator(_))
+        await(result).toList should have size 999
+      }
+    }
+  }
+
+  it should "be empty for an empty table" in {
+    val statement = newInstance(s"select * from $ks.$emptyTable")
+    val result = executor.executeAsync(statement).map(new PrefetchingResultSetIterator(_))
+
+    await(result).hasNext should be(false)
+    intercept[NoSuchElementException] {
+      await(result).next()
+    }
+  }
+
+  it should "update the provided timer" in {
+    val statement = newInstance(s"select * from $ks.$table").setPageSize(200)
+    val timer = new Timer()
+    val result = executor.executeAsync(statement).map(rs => new PrefetchingResultSetIterator(rs, Option(timer)))
+    await(result).toList
+
+    eventually(timeout(Span(2, Seconds))) {
+      timer.getCount should be(4)
+    }
+  }
+}
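The spec above exercises the restored iterator only through its public surface. For orientation, here is a minimal, hypothetical sketch (not the connector's actual source) of the shape those tests imply: wrap the first AsyncResultSet, expose it as an Iterator[Row], and start fetching the next page while the current one is consumed, with the optional Codahale Timer sampling each page fetch after the first. That is why 999 rows read with a page size of 200 are expected to yield a timer count of 4.

import java.util.concurrent.CompletionStage
import java.util.function.BiConsumer

import com.codahale.metrics.Timer
import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, Row}

// Hypothetical illustration only; the real class is
// com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator.
class PrefetchingResultSetIteratorSketch(first: AsyncResultSet, timer: Option[Timer] = None)
  extends Iterator[Row] {

  private var current: AsyncResultSet = first
  private var currentRows: java.util.Iterator[Row] = current.currentPage().iterator()
  // Kick off the next page fetch immediately so it overlaps with row consumption.
  private var nextPage: Option[CompletionStage[AsyncResultSet]] = prefetch()

  private def prefetch(): Option[CompletionStage[AsyncResultSet]] =
    if (current.hasMorePages) {
      val context = timer.map(_.time()) // one Timer sample per page fetch
      val stage = current.fetchNextPage()
      context.foreach { c =>
        stage.whenComplete(new BiConsumer[AsyncResultSet, Throwable] {
          override def accept(rs: AsyncResultSet, error: Throwable): Unit = { c.stop() }
        })
      }
      Some(stage)
    } else None

  override def hasNext: Boolean = currentRows.hasNext || nextPage.isDefined

  override def next(): Row = {
    if (!currentRows.hasNext) {
      val pending = nextPage.getOrElse(throw new NoSuchElementException("next on exhausted iterator"))
      current = pending.toCompletableFuture.get() // blocks only if the prefetch is still in flight
      currentRows = current.currentPage().iterator()
      nextPage = prefetch()
    }
    currentRows.next()
  }
}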

connector/src/main/scala/com/datastax/bdp/util/ScalaJavaUtil.scala

Lines changed: 18 additions & 2 deletions
@@ -7,10 +7,11 @@
 package com.datastax.bdp.util

 import java.time.{Duration => JavaDuration}
-import java.util.concurrent.Callable
+import java.util.concurrent.{Callable, CompletionStage}
 import java.util.function
-import java.util.function.{Consumer, Predicate, Supplier}
+import java.util.function.{BiConsumer, Consumer, Predicate, Supplier}

+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, Promise}
 import scala.concurrent.duration.{Duration => ScalaDuration}
 import scala.language.implicitConversions

@@ -45,4 +46,19 @@ object ScalaJavaUtil
   }

   def asScalaFunction[T, R](f: java.util.function.Function[T, R]): T => R = x => f(x)
+
+  def asScalaFuture[T](completionStage: CompletionStage[T])
+                      (implicit context: ExecutionContextExecutor): Future[T] = {
+    val promise = Promise[T]()
+    completionStage.whenCompleteAsync(new BiConsumer[T, java.lang.Throwable] {
+      override def accept(t: T, throwable: Throwable): Unit = {
+        if (throwable == null)
+          promise.success(t)
+        else
+          promise.failure(throwable)
+
+      }
+    }, context)
+    promise.future
+  }
 }
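The asScalaFuture helper added here is the bridge the scan paths later in this commit rely on: it adapts the Java driver's CompletionStage results from executeAsync into Scala Futures that can be mapped, recovered and awaited. A minimal standalone usage sketch follows; the session setup, query, and object name are illustrative assumptions, not part of this commit.

import java.util.concurrent.Executors

import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.oss.driver.api.core.cql.AsyncResultSet

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}

object AsScalaFutureExample {

  // A dedicated executor for driver callbacks; the connector itself imports
  // Threads.BlockingIOExecutionContext for the same purpose.
  implicit val ec: ExecutionContextExecutor =
    ExecutionContext.fromExecutor(Executors.newFixedThreadPool(2))

  def main(args: Array[String]): Unit = {
    val session = CqlSession.builder().build() // assumes a Cassandra node on localhost:9042

    // executeAsync returns a Java CompletionStage[AsyncResultSet];
    // asScalaFuture adapts it so it can be mapped and awaited like any Future.
    val firstPage: Future[AsyncResultSet] =
      asScalaFuture(session.executeAsync("SELECT release_version FROM system.local"))

    val version: Future[String] =
      firstPage.map(rs => rs.one().getString("release_version"))

    println(Await.result(version, Duration.Inf))
    session.close()
  }
}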

connector/src/main/scala/com/datastax/spark/connector/cql/Scanner.scala

Lines changed: 22 additions & 14 deletions
@@ -1,5 +1,6 @@
 package com.datastax.spark.connector.cql

+import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
 import com.datastax.oss.driver.api.core.CqlSession
 import com.datastax.oss.driver.api.core.cql.{Row, Statement}
 import com.datastax.spark.connector.CassandraRowMetadata

@@ -8,6 +9,9 @@ import com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator
 import com.datastax.spark.connector.util.maybeExecutingAs
 import com.datastax.spark.connector.writer.RateLimiter

+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await}
+
 /**
  * Object which will be used in Table Scanning Operations.
  * One Scanner will be created per Spark Partition, it will be

@@ -35,21 +39,25 @@ class DefaultScanner (
   }

   override def scan[StatementT <: Statement[StatementT]](statement: StatementT): ScanResult = {
-    val rs = session.execute(maybeExecutingAs(statement, readConf.executeAs))
-    val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, codecRegistry)
-    val prefetchingIterator = new PrefetchingResultSetIterator(rs, readConf.fetchSizeInRows)
-    val rateLimitingIterator = readConf.throughputMiBPS match {
-      case Some(throughput) =>
-        val rateLimiter = new RateLimiter((throughput * 1024 * 1024).toLong, 1024 * 1024)
-        prefetchingIterator.map { row =>
-          rateLimiter.maybeSleep(getRowBinarySize(row))
-          row
-        }
-      case None =>
-        prefetchingIterator
-    }
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

-    ScanResult(rateLimitingIterator, columnMetaData)
+    val rs = session.executeAsync(maybeExecutingAs(statement, readConf.executeAs))
+    val scanResult = asScalaFuture(rs).map { rs =>
+      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, codecRegistry)
+      val prefetchingIterator = new PrefetchingResultSetIterator(rs)
+      val rateLimitingIterator = readConf.throughputMiBPS match {
+        case Some(throughput) =>
+          val rateLimiter = new RateLimiter((throughput * 1024 * 1024).toLong, 1024 * 1024)
+          prefetchingIterator.map { row =>
+            rateLimiter.maybeSleep(getRowBinarySize(row))
+            row
+          }
+        case None =>
+          prefetchingIterator
+      }
+      ScanResult(rateLimitingIterator, columnMetaData)
+    }
+    Await.result(scanResult, Duration.Inf)
   }

   override def getSession(): CqlSession = session

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraCoGroupedRDD.scala

Lines changed: 13 additions & 8 deletions
@@ -7,9 +7,11 @@ package com.datastax.spark.connector.rdd

 import java.io.IOException

+import com.datastax.bdp.util.ScalaJavaUtil._
 import com.datastax.oss.driver.api.core.CqlSession
 import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row}
 import com.datastax.spark.connector.util._
+
 import scala.collection.JavaConversions._
 import scala.language.existentials
 import scala.reflect.ClassTag

@@ -28,6 +30,9 @@ import com.datastax.spark.connector.types.ColumnType
 import com.datastax.spark.connector.util.Quote._
 import com.datastax.spark.connector.util.{CountingIterator, MultiMergeJoinIterator, NameTools}

+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
 /**
  * A RDD which pulls from provided separate CassandraTableScanRDDs which share partition keys type and
  * keyspaces. These tables will be joined on READ using a merge iterator. As long as we join

@@ -158,21 +163,21 @@ class CassandraCoGroupedRDD[T](
       s"with params ${values.mkString("[", ",", "]")}")
     val stmt = createStatement(session, fromRDD.readConf, cql, values: _*)

-    try {
-      val rs = session.execute(stmt)
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+    val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
       val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
-      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames,rs, session)
-      val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
+      val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session.getContext.getCodecRegistry)
+      val iterator = new PrefetchingResultSetIterator(rs)
       val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)
-    } catch {
-      case t: Throwable =>
-        throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
+    }.recover {
+      case t: Throwable => throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
     }
+    Await.result(fetchResult, Duration.Inf)
   }

-
   @DeveloperApi
   override def compute(split: Partition, context: TaskContext): Iterator[Seq[Seq[T]]] = {
     /** Open two sessions if Cluster Configurations are different **/

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala

Lines changed: 8 additions & 6 deletions
@@ -124,16 +124,20 @@ class CassandraJoinRDD[L, R] (
     metricsUpdater: InputMetricsUpdater
   ): Iterator[(L, R)] = {

-
     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)

     def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
       val resultFuture = SettableFuture.create[Iterator[(L, R)]]
       val leftSide = Iterator.continually(left)

-      queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
+      import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+      val stmt = bsb.bind(left)
+        .update(_.setPageSize(readConf.fetchSizeInRows))
+        .executeAs(readConf.executeAs)
+      queryExecutor.executeAsync(stmt).onComplete {
         case Success(rs) =>
-          val resultSet = new PrefetchingResultSetIterator(ResultSets.newInstance(rs), fetchSize)
+          val resultSet = new PrefetchingResultSetIterator(rs)
           val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
           /* This is a much less than ideal place to actually rate limit, we are buffering
           these futures this means we will most likely exceed our threshold*/

@@ -142,13 +146,11 @@ class CassandraJoinRDD[L, R] (
           resultFuture.set(leftSide.zip(rightSide))
         case Failure(throwable) =>
           resultFuture.setException(throwable)
-      }(ExecutionContext.Implicits.global) // TODO: use dedicated context, use Future down the road, remove SettableFuture
+      }

       resultFuture
     }

-
-
     val queryFutures = leftIterator.map(left => {
       requestsPerSecondRateLimiter.maybeSleep(1)
       pairWithRight(left)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala

Lines changed: 8 additions & 3 deletions
@@ -149,16 +149,20 @@ class CassandraLeftJoinRDD[L, R] (
     leftIterator: Iterator[L],
     metricsUpdater: InputMetricsUpdater
   ): Iterator[(L, Option[R])] = {
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext

     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)

     def pairWithRight(left: L): SettableFuture[Iterator[(L, Option[R])]] = {
       val resultFuture = SettableFuture.create[Iterator[(L, Option[R])]]
       val leftSide = Iterator.continually(left)

-      queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
+      val stmt = bsb.bind(left)
+        .update(_.setPageSize(readConf.fetchSizeInRows))
+        .executeAs(readConf.executeAs)
+      queryExecutor.executeAsync(stmt).onComplete {
         case Success(rs) =>
-          val resultSet = new PrefetchingResultSetIterator(ResultSets.newInstance(rs), fetchSize)
+          val resultSet = new PrefetchingResultSetIterator(rs)
           val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
           /* This is a much less than ideal place to actually rate limit, we are buffering
           these futures this means we will most likely exceed our threshold*/

@@ -170,10 +174,11 @@ class CassandraLeftJoinRDD[L, R] (
           resultFuture.set(leftSide.zip(rightSide))
         case Failure(throwable) =>
           resultFuture.setException(throwable)
-      }(ExecutionContext.Implicits.global) // TODO: use dedicated context, use Future instead of SettableFuture
+      }

       resultFuture
     }
+
     val queryFutures = leftIterator.map(left => {
       requestsPerSecondRateLimiter.maybeSleep(1)
       pairWithRight(left)

connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraMergeJoinRDD.scala

Lines changed: 14 additions & 10 deletions
@@ -7,26 +7,30 @@ package com.datastax.spark.connector.rdd

 import java.io.IOException

+import com.datastax.bdp.util.ScalaJavaUtil.asScalaFuture
+
 import scala.collection.JavaConversions._
 import scala.language.existentials
 import scala.reflect.ClassTag
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.metrics.InputMetricsUpdater
 import org.apache.spark.rdd.RDD
 import org.apache.spark.{Partition, SparkContext, TaskContext}
-import com.datastax.driver.core._
 import com.datastax.oss.driver.api.core.CqlSession
-import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row, Statement}
+import com.datastax.oss.driver.api.core.cql.{BoundStatement, Row}
 import com.datastax.oss.driver.api.core.metadata.Metadata
 import com.datastax.oss.driver.api.core.metadata.token.Token
 import com.datastax.spark.connector.CassandraRowMetadata
-import com.datastax.spark.connector.cql.{CassandraConnector, ColumnDef, Schema}
+import com.datastax.spark.connector.cql.{CassandraConnector, ColumnDef}
 import com.datastax.spark.connector.rdd.partitioner.{CassandraPartition, CqlTokenRange, NodeAddresses}
 import com.datastax.spark.connector.rdd.reader.{PrefetchingResultSetIterator, RowReader}
 import com.datastax.spark.connector.types.ColumnType
 import com.datastax.spark.connector.util.Quote._
 import com.datastax.spark.connector.util.{CountingIterator, MergeJoinIterator, NameTools, schemaFromCassandra}

+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
 /**
  * A RDD which pulls from two separate CassandraTableScanRDDs which share partition keys and
  * keyspaces. These tables will be joined on READ using a merge iterator. As long as we join

@@ -151,21 +155,21 @@ class CassandraMergeJoinRDD[L,R](
       s"with params ${values.mkString("[", ",", "]")}")
     val stmt = createStatement(session, fromRDD.readConf, cql, values: _*)

-    try {
-      val rs = session.execute(stmt)
+    import com.datastax.spark.connector.util.Threads.BlockingIOExecutionContext
+
+    val fetchResult = asScalaFuture(session.executeAsync(stmt)).map { rs =>
       val columnNames = fromRDD.selectedColumnRefs.map(_.selectedAs).toIndexedSeq ++ Seq(TokenColumn)
       val columnMetaData = CassandraRowMetadata.fromResultSet(columnNames, rs, session)
-      val iterator = new PrefetchingResultSetIterator(rs, fromRDD.readConf.fetchSizeInRows)
+      val iterator = new PrefetchingResultSetIterator(rs)
       val iteratorWithMetrics = iterator.map(inputMetricsUpdater.updateMetrics)
       logDebug(s"Row iterator for range $range obtained successfully.")
       (columnMetaData, iteratorWithMetrics)
-    } catch {
-      case t: Throwable =>
-        throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
+    }.recover {
+      case t: Throwable => throw new IOException(s"Exception during execution of $cql: ${t.getMessage}", t)
     }
+    Await.result(fetchResult, Duration.Inf)
   }

-
   @DeveloperApi
   override def compute(split: Partition, context: TaskContext): Iterator[(Seq[L], Seq[R])] = {