Skip to content

Commit 53034f5

Browse files
committed
Replace KeyGroupedPartitioning with KeyedPartitioning, add new GroupPartitionsExec operator, remove old code
1 parent 3d75c6b commit 53034f5

File tree

16 files changed

+986
-882
lines changed

16 files changed

+986
-882
lines changed

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
142142
/**
143143
* A [[org.apache.spark.Partitioner]] that partitions all records using partition value map.
144144
* The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated
145-
* by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition
145+
* by [[org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning]], used to partition
146146
* the other side of a join to make sure records with same partition value are in the same
147147
* partition.
148148
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 123 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.physical
1919

20+
import java.util.Objects
21+
2022
import scala.annotation.tailrec
2123
import scala.collection.mutable
2224

@@ -346,43 +348,85 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
346348
}
347349

348350
/**
349-
* Represents a partitioning where rows are split across partitions based on transforms defined
350-
* by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
351-
* ascending order, after evaluated by the transforms in `expressions`, for each input partition.
352-
* In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
353-
* mapping), and each row in `partitionValues` must be unique.
351+
* Represents a partitioning where rows are split across partitions based on transforms defined by
352+
* `expressions`. `partitionKeys` should contain the value of partition key(s) in ascending order,
353+
* after being evaluated by the transforms in `expressions`, for each input partition.
354+
* `partitionKeys` might not be unique when this partitioning is returned from a data source, but
355+
* the `GroupPartitionsExec` operator can group partitions with the same key, thus making
356+
* `partitionKeys` unique.
354357
*
355-
* The `originalPartitionValues`, on the other hand, are partition values from the original input
358+
* The `originalPartitionKeys`, on the other hand, are partition values from the original input
356359
* splits returned by data sources. It may contain duplicated values.
357360
*
358361
* For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
359-
* input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions`
360-
* in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which
361-
* represents 3 input partitions with distinct partition values. All rows in each partition have
362-
* the same value for column `ts_col` (which is of timestamp type), after being applied by the
363-
* `years` transform. This is generated after combining the two splits with partition value `2`
364-
* into a single Spark partition.
365-
*
366-
* On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
367-
* which is calculated from the original input splits.
362+
* input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` in
363+
* this case is `[years(ts_col)]`, while both `partitionKeys` and `originalPartitionKeys` are
364+
* `[0, 1, 2, 2]`.
365+
* After placing a `GroupPartitionsExec` operator on top of the data source, `partitionKeys` becomes
366+
* `[0, 1, 2]` but `originalPartitionKeys` remains `[0, 1, 2, 2]`.
368367
*
369-
* @param expressions partition expressions for the partitioning.
370-
* @param numPartitions the number of partitions
371-
* @param partitionValues the values for the final cluster keys (that is, after applying grouping
372-
* on the input splits according to `expressions`) of the distribution,
373-
* must be in ascending order, and must NOT contain duplicated values.
374-
* @param originalPartitionValues the original input partition values before any grouping has been
375-
* applied, must be in ascending order, and may contain duplicated
376-
* values
368+
* @param expressions Partition expressions for the partitioning.
369+
* @param partitionKeys The keys for the partitions, must be in ascending order.
370+
* @param originalPartitionKeys The original partition keys before any grouping has been applied by
371+
* a `GroupPartitionsExec` operator, must be in ascending order.
377372
*/
378-
case class KeyGroupedPartitioning(
373+
case class KeyedPartitioning(
379374
expressions: Seq[Expression],
380-
numPartitions: Int,
381-
partitionValues: Seq[InternalRow] = Seq.empty,
382-
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
375+
partitionKeys: Seq[InternalRow],
376+
originalPartitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable {
377+
override val numPartitions = partitionKeys.length
378+
379+
override def children: Seq[Expression] = expressions
380+
override def nullable: Boolean = false
381+
override def dataType: DataType = IntegerType
382+
383+
override protected def withNewChildrenInternal(
384+
newChildren: IndexedSeq[Expression]): KeyedPartitioning =
385+
copy(expressions = newChildren)
386+
387+
@transient private lazy val dataTypes: Seq[DataType] = expressions.map(_.dataType)
388+
389+
@transient private lazy val comparableWrapperFactory =
390+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
391+
392+
@transient private lazy val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
393+
394+
@transient lazy val isGrouped: Boolean = {
395+
partitionKeys.map(comparableWrapperFactory).distinct.size == partitionKeys.size
396+
}
397+
398+
def toGrouped: KeyedPartitioning = {
399+
val groupedPartitions = partitionKeys
400+
.map(comparableWrapperFactory)
401+
.distinct
402+
.map(_.row)
403+
.sorted(rowOrdering)
404+
405+
KeyedPartitioning(expressions, groupedPartitions, originalPartitionKeys)
406+
}
407+
408+
def projectAndGroup(positions: Seq[Int]): KeyedPartitioning = {
409+
val projectedExpressions = positions.map(expressions)
410+
val projectedDataTypes = projectedExpressions.map(_.dataType)
411+
val projectedPartitionKeys = partitionKeys.map(
412+
KeyedPartitioning.projectKey(_, positions, projectedDataTypes)
413+
)
414+
val projectedOriginalPartitionKeys = originalPartitionKeys.map(
415+
KeyedPartitioning.projectKey(_, positions, projectedDataTypes)
416+
)
417+
val internalRowComparableWrapperFactory =
418+
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes)
419+
val distinctPartitionKeys = projectedPartitionKeys
420+
.map(internalRowComparableWrapperFactory)
421+
.distinct
422+
.map(_.row)
423+
424+
copy(expressions = projectedExpressions, partitionKeys = distinctPartitionKeys,
425+
originalPartitionKeys = projectedOriginalPartitionKeys)
426+
}
383427

384428
override def satisfies0(required: Distribution): Boolean = {
385-
super.satisfies0(required) || {
429+
super.satisfies0(required) || isGrouped && {
386430
required match {
387431
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
388432
if (requireAllClusterKeys) {
@@ -395,9 +439,9 @@ case class KeyGroupedPartitioning(
395439

396440
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
397441
// check that join keys (required clustering keys)
398-
// overlap with partition keys (KeyGroupedPartitioning attributes)
442+
// overlap with partition keys (KeyedPartitioning attributes)
399443
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
400-
expressions.forall(_.collectLeaves().size == 1)
444+
expressions.forall(_.collectLeaves().size == 1)
401445
} else {
402446
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
403447
}
@@ -416,63 +460,37 @@ case class KeyGroupedPartitioning(
416460
val result = KeyGroupedShuffleSpec(this, distribution)
417461
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
418462
// If allowing join keys to be subset of clustering keys, we should create a new
419-
// `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
463+
// `KeyedPartitioning` here that is grouped on the join keys instead, and use that as
420464
// the returned shuffle spec.
421465
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
422-
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
423-
partitionValues, originalPartitionValues)
466+
val projectedPartitioning = projectAndGroup(joinKeyPositions)
424467
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
425468
} else {
426469
result
427470
}
428471
}
429472

430-
lazy val uniquePartitionValues: Seq[InternalRow] = {
431-
val internalRowComparableFactory =
432-
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
433-
expressions.map(_.dataType))
434-
partitionValues
435-
.map(internalRowComparableFactory)
436-
.distinct
437-
.map(_.row)
438-
}
439-
440-
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
441-
copy(expressions = newChildren)
442-
}
473+
override def equals(that: Any): Boolean = that match {
474+
case k: KeyedPartitioning if this.expressions == k.expressions =>
475+
def keysEqual(keys1: Seq[InternalRow], keys2: Seq[InternalRow]): Boolean = {
476+
keys1.size == keys2.size && keys1.zip(keys2).forall { case (l, r) =>
477+
comparableWrapperFactory(l).equals(comparableWrapperFactory(r))
478+
}
479+
}
443480

444-
object KeyGroupedPartitioning {
445-
def apply(
446-
expressions: Seq[Expression],
447-
projectionPositions: Seq[Int],
448-
partitionValues: Seq[InternalRow],
449-
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
450-
val projectedExpressions = projectionPositions.map(expressions(_))
451-
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
452-
val projectedOriginalPartitionValues =
453-
originalPartitionValues.map(project(expressions, projectionPositions, _))
454-
val internalRowComparableFactory =
455-
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
456-
projectedExpressions.map(_.dataType))
457-
458-
val finalPartitionValues = projectedPartitionValues
459-
.map(internalRowComparableFactory)
460-
.distinct
461-
.map(_.row)
481+
keysEqual(partitionKeys, k.partitionKeys) &&
482+
keysEqual(originalPartitionKeys, k.originalPartitionKeys)
462483

463-
KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
464-
finalPartitionValues, projectedOriginalPartitionValues)
484+
case _ => false
465485
}
466486

467-
def project(
468-
expressions: Seq[Expression],
469-
positions: Seq[Int],
470-
input: InternalRow): InternalRow = {
471-
val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
472-
.toArray
473-
new GenericInternalRow(projectedValues)
487+
override def hashCode(): Int = {
488+
Objects.hash(expressions, partitionKeys.map(comparableWrapperFactory),
489+
originalPartitionKeys.map(comparableWrapperFactory))
474490
}
491+
}
475492

493+
object KeyedPartitioning {
476494
def supportsExpressions(expressions: Seq[Expression]): Boolean = {
477495
def isSupportedTransform(transform: TransformExpression): Boolean = {
478496
transform.children.size == 1 && isReference(transform.children.head)
@@ -491,6 +509,28 @@ object KeyGroupedPartitioning {
491509
case _ => false
492510
}
493511
}
512+
513+
def projectKey(
514+
key: InternalRow,
515+
positions: Seq[Int],
516+
dataTypes: Seq[DataType]): InternalRow = {
517+
val projectedKey = positions.zip(dataTypes).map {
518+
case (position, dataType) => key.get(position, dataType)
519+
}.toArray[Any]
520+
new GenericInternalRow(projectedKey)
521+
}
522+
523+
def reduceKey(
524+
key: InternalRow,
525+
reducers: Seq[Option[Reducer[_, _]]],
526+
dataTypes: Seq[DataType]): InternalRow = {
527+
val keyValues = key.toSeq(dataTypes)
528+
val reducedKey = keyValues.zip(reducers).map{
529+
case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
530+
case (v, _) => v
531+
}.toArray
532+
new GenericInternalRow(reducedKey)
533+
}
494534
}
495535

496536
/**
@@ -827,18 +867,20 @@ case class CoalescedHashShuffleSpec(
827867
}
828868

829869
/**
830-
* [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
870+
* [[ShuffleSpec]] created by [[KeyedPartitioning]].
831871
*
832872
* @param partitioning the keyed partitioning
833873
* @param distribution distribution
834874
* @param joinKeyPositions position of join keys among cluster keys.
835875
* This is set if joining on a subset of cluster keys is allowed.
836876
*/
837877
case class KeyGroupedShuffleSpec(
838-
partitioning: KeyGroupedPartitioning,
878+
partitioning: KeyedPartitioning,
839879
distribution: ClusteredDistribution,
840880
joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {
841881

882+
assert(partitioning.isGrouped)
883+
842884
/**
843885
* A sequence where each element is a set of positions of the partition expression to the cluster
844886
* keys. For instance, if cluster keys are [a, b, b] and partition expressions are
@@ -878,7 +920,7 @@ case class KeyGroupedShuffleSpec(
878920
partitioning.expressions.map(_.dataType))
879921
distribution.clustering.length == otherDistribution.clustering.length &&
880922
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
881-
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
923+
partitioning.partitionKeys.zip(otherPartitioning.partitionKeys).forall {
882924
case (left, right) =>
883925
internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
884926
}
@@ -959,21 +1001,20 @@ case class KeyGroupedShuffleSpec(
9591001
te.copy(children = te.children.map(_ => clustering(positionSet.head)))
9601002
case (_, positionSet) => clustering(positionSet.head)
9611003
}
962-
KeyGroupedPartitioning(newExpressions,
963-
partitioning.numPartitions,
964-
partitioning.partitionValues)
1004+
KeyedPartitioning(newExpressions, partitioning.partitionKeys,
1005+
partitioning.originalPartitionKeys)
9651006
}
9661007
}
9671008

9681009
object KeyGroupedShuffleSpec {
969-
def reducePartitionValue(
1010+
def reducePartitionKey(
9701011
row: InternalRow,
9711012
reducers: Seq[Option[Reducer[_, _]]],
9721013
dataTypes: Seq[DataType],
9731014
internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper
9741015
): InternalRowComparableWrapper = {
975-
val partitionVals = row.toSeq(dataTypes)
976-
val reducedRow = partitionVals.zip(reducers).map{
1016+
val partitionKeys = row.toSeq(dataTypes)
1017+
val reducedRow = partitionKeys.zip(reducers).map{
9771018
case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
9781019
case (v, _) => v
9791020
}.toArray

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util
1919

2020
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
2121
import org.apache.spark.sql.catalyst.expressions.Literal
22-
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
22+
import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning
2323
import org.apache.spark.sql.connector.catalog.PartitionInternalRow
2424
import org.apache.spark.sql.types.IntegerType
2525

@@ -61,10 +61,10 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
6161
// just to mock the data types
6262
val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType)))
6363

64-
val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
65-
val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
64+
val leftPartitioning = KeyedPartitioning(expressions, partitions, partitions)
65+
val rightPartitioning = KeyedPartitioning(expressions, partitions, partitions)
6666
val merged = InternalRowComparableWrapper.mergePartitions(
67-
leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions)
67+
leftPartitioning.partitionKeys, rightPartitioning.partitionKeys, expressions)
6868
assert(merged.size == bucketNum)
6969
}
7070

0 commit comments

Comments
 (0)