Skip to content

Commit 9eb9e81

Browse files
SPARKC-693 Support for Spark 3.3 (#1351)
Co-authored-by: Jack Richard Buggins <[email protected]>
1 parent f125820 commit 9eb9e81

File tree

11 files changed

+29
-35
lines changed

11 files changed

+29
-35
lines changed

CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
3.3.0
2+
* Spark 3.3.x support (SPARKC-693)
13
3.2.0
24
* Spark 3.2.x support (SPARKC-670)
35
* Fix: Cassandra Direct Join doesn't quote keyspace and table names (SPARKC-667)

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Currently, the following branches are actively supported:
5252

5353
| Connector | Spark | Cassandra | Cassandra Java Driver | Minimum Java Version | Supported Scala Versions |
5454
| --------- | ------------- | --------- | --------------------- | -------------------- | ----------------------- |
55+
| 3.3 | 3.3 | 2.1.5*, 2.2, 3.x, 4.0 | 4.13 | 8 | 2.12 |
5556
| 3.2 | 3.2 | 2.1.5*, 2.2, 3.x, 4.0 | 4.13 | 8 | 2.12 |
5657
| 3.1 | 3.1 | 2.1.5*, 2.2, 3.x, 4.0 | 4.12 | 8 | 2.12 |
5758
| 3.0 | 3.0 | 2.1.5*, 2.2, 3.x, 4.0 | 4.12 | 8 | 2.12 |

connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ trait SparkCassandraITSpecBase
247247

248248
def getCassandraScan(plan: SparkPlan): CassandraScan = {
249249
plan.collectLeaves.collectFirst{
250-
case BatchScanExec(_, cassandraScan: CassandraScan, _) => cassandraScan
250+
case BatchScanExec(_, cassandraScan: CassandraScan, _, _) => cassandraScan
251251
}.getOrElse(throw new IllegalArgumentException("No Cassandra Scan Found"))
252252
}
253253

connector/src/it/scala/com/datastax/spark/connector/cql/sai/SaiBaseSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ trait SaiBaseSpec extends Matchers with SparkCassandraITSpecBase {
4747

4848
def findCassandraScan(plan: SparkPlan): CassandraScan = {
4949
plan match {
50-
case BatchScanExec(_, scan: CassandraScan, _) => scan
50+
case BatchScanExec(_, scan: CassandraScan, _, _) => scan
5151
case filter: FilterExec => findCassandraScan(filter.child)
5252
case project: ProjectExec => findCassandraScan(project.child)
5353
case _ => throw new NoSuchElementException("RowDataSourceScanExec was not found in the given plan")

connector/src/it/scala/com/datastax/spark/connector/sql/CassandraDataSourceSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,10 @@ class CassandraDataSourceSpec extends SparkCassandraITFlatSpecBase with DefaultC
274274
if (pushDown)
275275
withClue(s"Given Dataframe plan does not contain CassandraInJoin in its predecessors.\n${df.queryExecution.sparkPlan.toString()}") {
276276
df.queryExecution.executedPlan.collectLeaves().collectFirst{
277-
case a@BatchScanExec(_, _: CassandraInJoin, _) => a
277+
case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
278278
case b@AdaptiveSparkPlanExec(_, _, _, _, _) =>
279279
b.executedPlan.collectLeaves().collectFirst{
280-
case a@BatchScanExec(_, _: CassandraInJoin, _) => a
280+
case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
281281
}
282282
} shouldBe defined
283283
}
@@ -288,7 +288,7 @@ class CassandraDataSourceSpec extends SparkCassandraITFlatSpecBase with DefaultC
288288
private def assertOnAbsenceOfCassandraInJoin(df: DataFrame): Unit =
289289
withClue(s"Given Dataframe plan contains CassandraInJoin in its predecessors.\n${df.queryExecution.sparkPlan.toString()}") {
290290
df.queryExecution.executedPlan.collectLeaves().collectFirst{
291-
case a@BatchScanExec(_, _: CassandraInJoin, _) => a
291+
case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
292292
} shouldBe empty
293293
}
294294

connector/src/it/scala/com/datastax/spark/connector/util/CatalystUtil.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
77
object CatalystUtil {
88

99
def findCassandraScan(sparkPlan: SparkPlan): Option[CassandraScan] = {
10-
sparkPlan.collectFirst{ case BatchScanExec(_, scan: CassandraScan, _) => scan}
10+
sparkPlan.collectFirst{ case BatchScanExec(_, scan: CassandraScan, _, _) => scan}
1111
}
1212
}

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraCatalog.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class CassandraCatalog extends CatalogPlugin
190190
.asJava
191191
}
192192

193-
override def dropNamespace(namespace: Array[String]): Boolean = {
193+
override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = {
194194
checkNamespace(namespace)
195195
val keyspace = getKeyspaceMeta(connector, namespace)
196196
val dropResult = connector.withSessionDo(session => session.execute(SchemaBuilder.dropKeyspace(keyspace.getName).asCql()))

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanBuilder.scala

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ import com.datastax.spark.connector.{ColumnRef, RowCountRef, TTL, WriteTime}
1313
import org.apache.spark.SparkConf
1414
import org.apache.spark.sql.cassandra.CassandraSourceRelation.{AdditionalCassandraPushDownRulesParam, InClauseToJoinWithTableConversionThreshold}
1515
import org.apache.spark.sql.cassandra.{AnalyzedPredicates, Auto, BasicCassandraPredicatePushDown, CassandraPredicateRules, CassandraSourceRelation, DsePredicateRules, DseSearchOptimizationSetting, InClausePredicateRules, Off, On, SolrConstants, SolrPredicateRules, TimeUUIDPredicateRules}
16+
import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
1617
import org.apache.spark.sql.connector.read._
17-
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning}
18+
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning}
1819
import org.apache.spark.sql.sources.{EqualTo, Filter, In}
1920
import org.apache.spark.sql.types._
2021
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -307,7 +308,7 @@ case class CassandraScan(
307308
}
308309

309310
override def outputPartitioning(): Partitioning = {
310-
CassandraPartitioning(tableDef.partitionKey.map(_.columnName).toArray, inputPartitions.length)
311+
new CassandraPartitioning(tableDef.partitionKey.map(_.columnName).map(Expressions.identity).toArray, inputPartitions.length)
311312
}
312313

313314
override def description(): String = {
@@ -317,17 +318,7 @@ case class CassandraScan(
317318
}
318319
}
319320

320-
case class CassandraPartitioning(partitionKeys: Array[String], numPartitions: Int) extends Partitioning {
321-
322-
/*
323-
Currently we only satisfy distributions which rely on all partition key values having identical
324-
values. In the future we may be able to support some other distributions but Spark doesn't have
325-
means to support those atm 3.0
326-
*/
327-
override def satisfy(distribution: Distribution): Boolean = distribution match {
328-
case cD: ClusteredDistribution => partitionKeys.forall(cD.clusteredColumns.contains)
329-
case _ => false
330-
}
321+
class CassandraPartitioning(keys: Array[Expression], numPartitions: Int) extends KeyGroupedPartitioning(keys, numPartitions) {
331322
}
332323

333324
case class CassandraInJoin(
@@ -359,7 +350,7 @@ case class CassandraInJoin(
359350
}
360351

361352
override def outputPartitioning(): Partitioning = {
362-
CassandraPartitioning(tableDef.partitionKey.map(_.columnName).toArray, numPartitions)
353+
new CassandraPartitioning(tableDef.partitionKey.map(_.columnName).map(Expressions.identity).toArray, numPartitions)
363354
}
364355
}
365356

connector/src/main/scala/org/apache/spark/sql/cassandra/CassandraSourceRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ object CassandraSourceRelation extends Logging {
211211
oldPlan.transform {
212212
case ds@DataSourceV2Relation(_: CassandraTable, _, _, _, options) =>
213213
ds.copy(options = applyDirectJoinSetting(options, directJoinSetting))
214-
case ds@DataSourceV2ScanRelation(_: CassandraTable, scan: CassandraScan, _) =>
214+
case ds@DataSourceV2ScanRelation(_: CassandraTable, scan: CassandraScan, _, _) =>
215215
ds.copy(scan = scan.copy(consolidatedConf = applyDirectJoinSetting(scan.consolidatedConf, directJoinSetting)))
216216
}
217217
)

connector/src/main/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinStrategy.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
2525
val conf = spark.sqlContext.conf
2626

2727
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
28-
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
28+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, _, left, right, _)
2929
if hasValidDirectJoin(joinType, leftKeys, rightKeys, condition, left, right) =>
3030

3131
val (otherBranch, joinTargetBranch, buildType) = {
@@ -46,7 +46,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
4646
val cassandraScanExec = getScanExec(dataSourceOptimizedPlan).get
4747

4848
joinTargetBranch match {
49-
case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _)) =>
49+
case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _)) =>
5050
val directJoin =
5151
CassandraDirectJoinExec(
5252
leftKeys,
@@ -147,7 +147,7 @@ object CassandraDirectJoinStrategy extends Logging {
147147
*/
148148
def getScanExec(plan: SparkPlan): Option[BatchScanExec] = {
149149
plan.collectFirst {
150-
case exec @ BatchScanExec(_, _: CassandraScan, _) => exec
150+
case exec @ BatchScanExec(_, _: CassandraScan, _, _) => exec
151151
}
152152
}
153153

@@ -170,7 +170,7 @@ object CassandraDirectJoinStrategy extends Logging {
170170
def getDSV2CassandraRelation(plan: LogicalPlan): Option[DataSourceV2ScanRelation] = {
171171
val children = plan.collectLeaves()
172172
if (children.length == 1) {
173-
plan.collectLeaves().collectFirst { case ds @ DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => ds }
173+
plan.collectLeaves().collectFirst { case ds @ DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _) => ds }
174174
} else {
175175
None
176176
}
@@ -183,7 +183,7 @@ object CassandraDirectJoinStrategy extends Logging {
183183
def getCassandraTable(plan: LogicalPlan): Option[CassandraTable] = {
184184
val children = plan.collectLeaves()
185185
if (children.length == 1) {
186-
children.collectFirst { case DataSourceV2ScanRelation(DataSourceV2Relation(table: CassandraTable, _, _, _, _), _, _) => table }
186+
children.collectFirst { case DataSourceV2ScanRelation(DataSourceV2Relation(table: CassandraTable, _, _, _, _), _, _, _) => table }
187187
} else {
188188
None
189189
}
@@ -192,7 +192,7 @@ object CassandraDirectJoinStrategy extends Logging {
192192
def getCassandraScan(plan: LogicalPlan): Option[CassandraScan] = {
193193
val children = plan.collectLeaves()
194194
if (children.length == 1) {
195-
plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: DataSourceV2Relation, cs: CassandraScan, _) => cs }
195+
plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: DataSourceV2Relation, cs: CassandraScan, _, _) => cs }
196196
} else {
197197
None
198198
}
@@ -204,8 +204,8 @@ object CassandraDirectJoinStrategy extends Logging {
204204
*/
205205
def hasCassandraChild[T <: QueryPlan[T]](plan: T): Boolean = {
206206
plan.children.size == 1 && plan.children.exists {
207-
case DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => true
208-
case BatchScanExec(_, _: CassandraScan, _) => true
207+
case DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _) => true
208+
case BatchScanExec(_, _: CassandraScan, _, _) => true
209209
case _ => false
210210
}
211211
}
@@ -238,7 +238,7 @@ object CassandraDirectJoinStrategy extends Logging {
238238
originalOutput: Seq[Attribute]): SparkPlan = {
239239
val reordered = plan match {
240240
//This may be the only node in the Plan
241-
case BatchScanExec(_, _: CassandraScan, _) => directJoin
241+
case BatchScanExec(_, _: CassandraScan, _, _) => directJoin
242242
// Plan has children
243243
case normalPlan => normalPlan.transform {
244244
case penultimate if hasCassandraChild(penultimate) =>
@@ -301,7 +301,7 @@ object CassandraDirectJoinStrategy extends Logging {
301301
plan match {
302302
case PhysicalOperation(
303303
attributes, _,
304-
DataSourceV2ScanRelation(DataSourceV2Relation(cassandraTable: CassandraTable, _, _, _, _), _, _)) =>
304+
DataSourceV2ScanRelation(DataSourceV2Relation(cassandraTable: CassandraTable, _, _, _, _), _, _, _)) =>
305305

306306
val joinKeysExprId = joinKeys.collect{ case attributeReference: AttributeReference => attributeReference.exprId }
307307

@@ -341,7 +341,7 @@ object CassandraDirectJoinStrategy extends Logging {
341341
*/
342342
def containsSafePlans(plan: LogicalPlan): Boolean = {
343343
plan match {
344-
case PhysicalOperation(_, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), scan: CassandraScan, _))
344+
case PhysicalOperation(_, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), scan: CassandraScan, _, _))
345345
if getDirectJoinSetting(scan.consolidatedConf) != AlwaysOff => true
346346
case _ => false
347347
}

0 commit comments

Comments
 (0)