Commit 8213f3f

Merge pull request #1279 from datastax/SPARKC-613-2.5
SPARKC-613 fix direct join for spark sql
2 parents: 7b36b6e + 729c1dd

3 files changed (+28, -17 lines)

connector/src/it/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinSpec.scala

Lines changed: 12 additions & 1 deletion
@@ -2,7 +2,6 @@ package org.apache.spark.sql.cassandra.execution
 
 import java.sql.Timestamp
 import java.time.Instant
-import java.util.concurrent.CompletableFuture
 
 import com.datastax.oss.driver.api.core.DefaultProtocolVersion._
 import com.datastax.spark.connector.cluster.DefaultCluster
@@ -634,6 +633,18 @@ class CassandraDirectJoinSpec extends SparkCassandraITFlatSpecBase with DefaultC
     quotes.count should be (1)
   }
 
+  /** SPARKC-613 */
+  it should "use direct join for sql join" in compareDirectOnDirectOff { spark =>
+    import spark.implicits._
+
+    val toJoin = spark.range(1, 5).map(_.intValue).withColumnRenamed("value", "id")
+
+    toJoin.createOrReplaceTempView("tojoin")
+    spark.read.cassandraFormat("kv", ks).load().createOrReplaceTempView(s"cassdata")
+
+    spark.sql("select * from tojoin tj inner join cassdata cd on tj.id = cd.k")
+  }
+
   private def compareDirectOnDirectOff(test: ((SparkSession) => DataFrame)) = {
     val sparkJoinOn = sparkSession.cloneSession()
     sparkJoinOn.conf.set(DirectJoinSettingParam.name, "on")
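
The new test relies on the suite's compareDirectOnDirectOff helper. For a standalone check of the same scenario, a sketch along the following lines should work, assuming a reachable Cassandra cluster, the connector on the classpath, and a table ks.kv whose partition key is k (the table, keyspace, and column names come from the test fixtures, not from this commit):

// Minimal sketch (not part of the commit): reproduce the SQL direct-join
// scenario from the new test in a standalone job.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.cassandra._  // brings in cassandraFormat

object DirectJoinSqlCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("direct-join-sql-check")
      .config("spark.cassandra.connection.host", "127.0.0.1")  // assumed local cluster
      // Extensions class as documented for connector 2.5+; adjust if your setup differs.
      .config("spark.sql.extensions", "com.datastax.spark.connector.CassandraSparkExtensions")
      .getOrCreate()
    import spark.implicits._

    // Small local dataset to drive the join, as in the test.
    spark.range(1, 5).map(_.intValue).withColumnRenamed("value", "id")
      .createOrReplaceTempView("tojoin")

    // Cassandra table exposed as a temp view so it can be referenced from SQL.
    spark.read.cassandraFormat("kv", "ks").load().createOrReplaceTempView("cassdata")

    val joined = spark.sql("select * from tojoin tj inner join cassdata cd on tj.id = cd.k")

    // With the fix, the physical plan should contain a "Cassandra Direct Join" node.
    joined.explain()
    joined.show()

    spark.stop()
  }
}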

connector/src/main/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinExec.scala

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,7 @@ import com.datastax.spark.connector.rdd.{CassandraJoinRDD, CassandraLeftJoinRDD,
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.cassandra.execution.unsafe.{UnsafeRowReaderFactory, UnsafeRowWriterFactory}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BindReferences, EqualTo, Expression, GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BindReferences, EqualTo, ExprId, Expression, GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildSide}
@@ -21,7 +21,7 @@ case class CassandraDirectJoinExec(
     cassandraSide: BuildSide,
     condition: Option[Expression],
     child: SparkPlan,
-    aliasMap: Map[String, Attribute],
+    aliasMap: Map[String, ExprId],
     cassandraScan: CassandraTableScanRDD[_],
     cassandraPlan: DataSourceScanExec) extends UnaryExecNode {
 
@@ -41,7 +41,7 @@
   val primaryKeys = cassandraScan.tableDef.primaryKey.map(_.columnName)
   val cassandraSchema = cassandraPlan.schema
 
-  val attributeToCassandra = aliasMap.map(_.swap)
+  val exprIdToCassandra = aliasMap.map(_.swap)
 
   val leftJoinCouplets =
     if (cassandraSide == BuildLeft) leftKeys.zip(rightKeys) else rightKeys.zip(leftKeys)
@@ -54,15 +54,15 @@
    */
  val (pkJoinCoulplets, otherJoinCouplets) = leftJoinCouplets.partition {
    case (cassandraAttribute: Attribute, _) =>
-     attributeToCassandra.get(cassandraAttribute) match {
+     exprIdToCassandra.get(cassandraAttribute.exprId) match {
        case Some(name) if primaryKeys.contains(name) => true
        case _ => false
-         }
+     }
    case _ => false
  }
 
  val (joinColumns, joinExpressions) = pkJoinCoulplets.map { case (cAttr: Attribute, otherCol: Expression) =>
-   (ColumnName(attributeToCassandra(cAttr)), BindReferences.bindReference(otherCol, keySource.output))
+   (ColumnName(exprIdToCassandra(cAttr.exprId)), BindReferences.bindReference(otherCol, keySource.output))
  }.unzip
 
  /**
@@ -210,7 +210,7 @@
    val selectString = selectedColumns.mkString("Reading (", ", ", ")")
 
    val joinString = pkJoinCoulplets
-     .map{ case (colref: Attribute, exp) => s"${attributeToCassandra(colref)} = ${exp}"}
+     .map{ case (colref: Attribute, exp) => s"${exprIdToCassandra(colref.exprId)} = ${exp}"}
      .mkString(", ")
 
    s"Cassandra Direct Join [${joinString}] $keyspace.$table - $selectString${pushedWhere} "

connector/src/main/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinStrategy.scala

Lines changed: 9 additions & 9 deletions
@@ -298,12 +298,14 @@ object CassandraDirectJoinStrategy extends Logging {
   def allPartitionKeysAreJoined(plan: LogicalPlan, joinKeys: Seq[Expression]): Boolean =
     plan match {
       case PhysicalOperation(
-          attributes, _,
-          LogicalRelation(cassandraSource: CassandraSourceRelation, _, _, _)) =>
+        attributes, _,
+        LogicalRelation(cassandraSource: CassandraSourceRelation, _, _, _)) =>
+
+        val joinKeysExprId = joinKeys.collect { case attributeReference: AttributeReference => attributeReference.exprId }
 
         val joinKeyAliases =
           aliasMap(attributes)
-            .filter{ case (_, value) => joinKeys.contains(value) }
+            .filter { case (_, value) => joinKeysExprId.contains(value) }
         val partitionKeyNames = cassandraSource.tableDef.partitionKey.map(_.columnName)
         val allKeysPresent = partitionKeyNames.forall(joinKeyAliases.contains)
 
@@ -312,15 +314,15 @@
         }
 
         allKeysPresent
-        case _ => false
-      }
+      case _ => false
+    }
 
   /**
    * Map Source Names to Attributes
    */
   def aliasMap(aliases: Seq[NamedExpression]) = aliases.map {
-    case a @ Alias(child: AttributeReference, _) => child.name -> a.toAttribute
-    case namedExpression: NamedExpression => namedExpression.name -> namedExpression.toAttribute
+    case a @ Alias(child: AttributeReference, _) => child.name -> a.exprId
+    case attributeReference: AttributeReference => attributeReference.name -> attributeReference.exprId
   }.toMap
 
   /**
@@ -335,6 +337,4 @@
     }
   }
 
-
-
 }
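
In the strategy, aliasMap now maps each source column name to an ExprId rather than to an Attribute, and allPartitionKeysAreJoined compares join keys by their ExprIds. A rough sketch of the resulting shape, again with simplified stand-ins for the Catalyst types rather than the connector's actual code:

// Illustration only: the alias-map shape after SPARKC-613, with simplified
// stand-ins for NamedExpression / AttributeReference / Alias.
object AliasMapSketch {
  final case class ExprId(id: Long)
  sealed trait NamedExpression { def name: String; def exprId: ExprId }
  final case class AttributeReference(name: String, exprId: ExprId) extends NamedExpression
  final case class Alias(child: AttributeReference, name: String, exprId: ExprId) extends NamedExpression

  // Source column name -> ExprId of the expression that exposes it.
  def aliasMap(aliases: Seq[NamedExpression]): Map[String, ExprId] = aliases.map {
    case a @ Alias(child, _, _)  => child.name -> a.exprId
    case ref: AttributeReference => ref.name   -> ref.exprId
  }.toMap

  // Partition keys count as joined when every partition-key column name
  // resolves, through the alias map, to one of the join keys' ExprIds.
  def allPartitionKeysAreJoined(partitionKeys: Seq[String],
                                projection: Seq[NamedExpression],
                                joinKeyExprIds: Set[ExprId]): Boolean = {
    val joined = aliasMap(projection).filter { case (_, id) => joinKeyExprIds.contains(id) }
    partitionKeys.forall(joined.contains)
  }

  def main(args: Array[String]): Unit = {
    val k = AttributeReference("k", ExprId(1))
    val aliasedK = Alias(k, "key", ExprId(2))
    println(allPartitionKeysAreJoined(Seq("k"), Seq(aliasedK), Set(ExprId(2)))) // true
    println(allPartitionKeysAreJoined(Seq("k"), Seq(aliasedK), Set(ExprId(9)))) // false
  }
}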
