1717
1818package org .apache .spark .sql .catalyst .plans .physical
1919
20+ import java .util .Objects
21+
2022import scala .annotation .tailrec
2123import scala .collection .mutable
2224
@@ -346,43 +348,85 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
346348}
347349
348350/**
349- * Represents a partitioning where rows are split across partitions based on transforms defined
350- * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
351- * ascending order, after evaluated by the transforms in `expressions`, for each input partition.
352- * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
353- * mapping), and each row in `partitionValues` must be unique.
351+ * Represents a partitioning where rows are split across partitions based on transforms defined by
352+ * `expressions`. `partitionKeys` should contain the value of partition key(s) in ascending
353+ * after evaluated by the transforms in `expressions`, for each input partition.
354+ * `partitionKeys` might not be unique when this partitioning is returned from a data source, but
355+ * the `GroupPartitionsExec` operator can group partitions with the same key and so make
356+ * `partitionKeys` unique.
354357 *
355- * The `originalPartitionValues `, on the other hand, are partition values from the original input
358+ * The `originalPartitionKeys `, on the other hand, are partition values from the original input
356359 * splits returned by data sources. It may contain duplicated values.
357360 *
358361 * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
359- * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions`
360- * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which
361- * represents 3 input partitions with distinct partition values. All rows in each partition have
362- * the same value for column `ts_col` (which is of timestamp type), after being applied by the
363- * `years` transform. This is generated after combining the two splits with partition value `2`
364- * into a single Spark partition.
365- *
366- * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
367- * which is calculated from the original input splits.
362+ * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` in
363+ * this case is `[years(ts_col)]`, while both `partitionKeys` and `originalPartitionKeys` are
364+ * `[0, 1, 2, 2]`.
365+ * After placing a `GroupPartitionsExec` operator on top of the data source, `partitionKeys` becomes
366+ * `[0, 1, 2]` but `originalPartitionKeys` remains `[0, 1, 2, 2]`.
368367 *
369- * @param expressions partition expressions for the partitioning.
370- * @param numPartitions the number of partitions
371- * @param partitionValues the values for the final cluster keys (that is, after applying grouping
372- * on the input splits according to `expressions`) of the distribution,
373- * must be in ascending order, and must NOT contain duplicated values.
374- * @param originalPartitionValues the original input partition values before any grouping has been
375- * applied, must be in ascending order, and may contain duplicated
376- * values
368+ * @param expressions Partition expressions for the partitioning.
369+ * @param partitionKeys The keys for the partitions, must be in ascending order.
370+ * @param originalPartitionKeys The original partition keys before any grouping has been applied by
371+ * a `GroupPartitionsExec` operator, must be in ascending order.
377372 */
378- case class KeyGroupedPartitioning (
373+ case class KeyedPartitioning (
379374 expressions : Seq [Expression ],
380- numPartitions : Int ,
381- partitionValues : Seq [InternalRow ] = Seq .empty,
382- originalPartitionValues : Seq [InternalRow ] = Seq .empty) extends HashPartitioningLike {
375+ partitionKeys : Seq [InternalRow ],
376+ originalPartitionKeys : Seq [InternalRow ]) extends Expression with Partitioning with Unevaluable {
377+ override val numPartitions = partitionKeys.length
378+
379+ override def children : Seq [Expression ] = expressions
380+ override def nullable : Boolean = false
381+ override def dataType : DataType = IntegerType
382+
383+ override protected def withNewChildrenInternal (
384+ newChildren : IndexedSeq [Expression ]): KeyedPartitioning =
385+ copy(expressions = newChildren)
386+
387+ @ transient private lazy val dataTypes : Seq [DataType ] = expressions.map(_.dataType)
388+
389+ @ transient private lazy val comparableWrapperFactory =
390+ InternalRowComparableWrapper .getInternalRowComparableWrapperFactory(dataTypes)
391+
392+ @ transient private lazy val rowOrdering = RowOrdering .createNaturalAscendingOrdering(dataTypes)
393+
394+ @ transient lazy val isGrouped : Boolean = {
395+ partitionKeys.map(comparableWrapperFactory).distinct.size == partitionKeys.size
396+ }
397+
398+ def toGrouped : KeyedPartitioning = {
399+ val groupedPartitions = partitionKeys
400+ .map(comparableWrapperFactory)
401+ .distinct
402+ .map(_.row)
403+ .sorted(rowOrdering)
404+
405+ KeyedPartitioning (expressions, groupedPartitions, originalPartitionKeys)
406+ }
407+
408+ def projectAndGroup (positions : Seq [Int ]): KeyedPartitioning = {
409+ val projectedExpressions = positions.map(expressions)
410+ val projectedDataTypes = projectedExpressions.map(_.dataType)
411+ val projectedPartitionKeys = partitionKeys.map(
412+ KeyedPartitioning .projectKey(_, positions, projectedDataTypes)
413+ )
414+ val projectedOriginalPartitionKeys = originalPartitionKeys.map(
415+ KeyedPartitioning .projectKey(_, positions, projectedDataTypes)
416+ )
417+ val internalRowComparableWrapperFactory =
418+ InternalRowComparableWrapper .getInternalRowComparableWrapperFactory(projectedDataTypes)
419+ val distinctPartitionKeys = projectedPartitionKeys
420+ .map(internalRowComparableWrapperFactory)
421+ .distinct
422+ .map(_.row)
423+
424+ copy(expressions = projectedExpressions, partitionKeys = distinctPartitionKeys,
425+ originalPartitionKeys = projectedOriginalPartitionKeys)
426+ }
383427
384428 override def satisfies0 (required : Distribution ): Boolean = {
385- super .satisfies0(required) || {
429+ super .satisfies0(required) || isGrouped && {
386430 required match {
387431 case c @ ClusteredDistribution (requiredClustering, requireAllClusterKeys, _) =>
388432 if (requireAllClusterKeys) {
@@ -395,9 +439,9 @@ case class KeyGroupedPartitioning(
395439
396440 if (SQLConf .get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
397441 // check that join keys (required clustering keys)
398- // overlap with partition keys (KeyGroupedPartitioning attributes)
442+ // overlap with partition keys (KeyedPartitioning attributes)
399443 requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
400- expressions.forall(_.collectLeaves().size == 1 )
444+ expressions.forall(_.collectLeaves().size == 1 )
401445 } else {
402446 attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
403447 }
@@ -416,63 +460,37 @@ case class KeyGroupedPartitioning(
416460 val result = KeyGroupedShuffleSpec (this , distribution)
417461 if (SQLConf .get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
418462 // If allowing join keys to be subset of clustering keys, we should create a new
419- // `KeyGroupedPartitioning ` here that is grouped on the join keys instead, and use that as
463+ // `KeyedPartitioning ` here that is grouped on the join keys instead, and use that as
420464 // the returned shuffle spec.
421465 val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
422- val projectedPartitioning = KeyGroupedPartitioning (expressions, joinKeyPositions,
423- partitionValues, originalPartitionValues)
466+ val projectedPartitioning = projectAndGroup(joinKeyPositions)
424467 result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some (joinKeyPositions))
425468 } else {
426469 result
427470 }
428471 }
429472
430- lazy val uniquePartitionValues : Seq [InternalRow ] = {
431- val internalRowComparableFactory =
432- InternalRowComparableWrapper .getInternalRowComparableWrapperFactory(
433- expressions.map(_.dataType))
434- partitionValues
435- .map(internalRowComparableFactory)
436- .distinct
437- .map(_.row)
438- }
439-
440- override protected def withNewChildrenInternal (newChildren : IndexedSeq [Expression ]): Expression =
441- copy(expressions = newChildren)
442- }
473+ override def equals (that : Any ): Boolean = that match {
474+ case k : KeyedPartitioning if this .expressions == k.expressions =>
475+ def keysEqual (keys1 : Seq [InternalRow ], keys2 : Seq [InternalRow ]): Boolean = {
476+ keys1.size == keys2.size && keys1.zip(keys2).forall { case (l, r) =>
477+ comparableWrapperFactory(l).equals(comparableWrapperFactory(r))
478+ }
479+ }
443480
444- object KeyGroupedPartitioning {
445- def apply (
446- expressions : Seq [Expression ],
447- projectionPositions : Seq [Int ],
448- partitionValues : Seq [InternalRow ],
449- originalPartitionValues : Seq [InternalRow ]): KeyGroupedPartitioning = {
450- val projectedExpressions = projectionPositions.map(expressions(_))
451- val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
452- val projectedOriginalPartitionValues =
453- originalPartitionValues.map(project(expressions, projectionPositions, _))
454- val internalRowComparableFactory =
455- InternalRowComparableWrapper .getInternalRowComparableWrapperFactory(
456- projectedExpressions.map(_.dataType))
457-
458- val finalPartitionValues = projectedPartitionValues
459- .map(internalRowComparableFactory)
460- .distinct
461- .map(_.row)
481+ keysEqual(partitionKeys, k.partitionKeys) &&
482+ keysEqual(originalPartitionKeys, k.originalPartitionKeys)
462483
463- KeyGroupedPartitioning (projectedExpressions, finalPartitionValues.length,
464- finalPartitionValues, projectedOriginalPartitionValues)
484+ case _ => false
465485 }
466486
467- def project (
468- expressions : Seq [Expression ],
469- positions : Seq [Int ],
470- input : InternalRow ): InternalRow = {
471- val projectedValues : Array [Any ] = positions.map(i => input.get(i, expressions(i).dataType))
472- .toArray
473- new GenericInternalRow (projectedValues)
487+ override def hashCode (): Int = {
488+ Objects .hash(expressions, partitionKeys.map(comparableWrapperFactory),
489+ originalPartitionKeys.map(comparableWrapperFactory))
474490 }
491+ }
475492
493+ object KeyedPartitioning {
476494 def supportsExpressions (expressions : Seq [Expression ]): Boolean = {
477495 def isSupportedTransform (transform : TransformExpression ): Boolean = {
478496 transform.children.size == 1 && isReference(transform.children.head)
@@ -491,6 +509,28 @@ object KeyGroupedPartitioning {
491509 case _ => false
492510 }
493511 }
512+
513+ def projectKey (
514+ key : InternalRow ,
515+ positions : Seq [Int ],
516+ dataTypes : Seq [DataType ]): InternalRow = {
517+ val projectedKey = positions.zip(dataTypes).map {
518+ case (position, dataType) => key.get(position, dataType)
519+ }.toArray[Any ]
520+ new GenericInternalRow (projectedKey)
521+ }
522+
523+ def reduceKey (
524+ key : InternalRow ,
525+ reducers : Seq [Option [Reducer [_, _]]],
526+ dataTypes : Seq [DataType ]): InternalRow = {
527+ val keyValues = key.toSeq(dataTypes)
528+ val reducedKey = keyValues.zip(reducers).map{
529+ case (v, Some (reducer : Reducer [Any , Any ])) => reducer.reduce(v)
530+ case (v, _) => v
531+ }.toArray
532+ new GenericInternalRow (reducedKey)
533+ }
494534}
495535
496536/**
@@ -827,18 +867,20 @@ case class CoalescedHashShuffleSpec(
827867}
828868
829869/**
830- * [[ShuffleSpec ]] created by [[KeyGroupedPartitioning ]].
870+ * [[ShuffleSpec ]] created by [[KeyedPartitioning ]].
831871 *
832872 * @param partitioning key grouped partitioning
833873 * @param distribution distribution
834874 * @param joinKeyPositions position of join keys among cluster keys.
835875 * This is set if joining on a subset of cluster keys is allowed.
836876 */
837877case class KeyGroupedShuffleSpec (
838- partitioning : KeyGroupedPartitioning ,
878+ partitioning : KeyedPartitioning ,
839879 distribution : ClusteredDistribution ,
840880 joinKeyPositions : Option [Seq [Int ]] = None ) extends ShuffleSpec {
841881
882+ assert(partitioning.isGrouped)
883+
842884 /**
843885 * A sequence where each element is a set of positions of the partition expression to the cluster
844886 * keys. For instance, if cluster keys are [a, b, b] and partition expressions are
@@ -878,7 +920,7 @@ case class KeyGroupedShuffleSpec(
878920 partitioning.expressions.map(_.dataType))
879921 distribution.clustering.length == otherDistribution.clustering.length &&
880922 numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
881- partitioning.partitionValues .zip(otherPartitioning.partitionValues ).forall {
923+ partitioning.partitionKeys .zip(otherPartitioning.partitionKeys ).forall {
882924 case (left, right) =>
883925 internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
884926 }
@@ -959,21 +1001,20 @@ case class KeyGroupedShuffleSpec(
9591001 te.copy(children = te.children.map(_ => clustering(positionSet.head)))
9601002 case (_, positionSet) => clustering(positionSet.head)
9611003 }
962- KeyGroupedPartitioning (newExpressions,
963- partitioning.numPartitions,
964- partitioning.partitionValues)
1004+ KeyedPartitioning (newExpressions, partitioning.partitionKeys,
1005+ partitioning.originalPartitionKeys)
9651006 }
9661007}
9671008
9681009object KeyGroupedShuffleSpec {
969- def reducePartitionValue (
1010+ def reducePartitionKey (
9701011 row : InternalRow ,
9711012 reducers : Seq [Option [Reducer [_, _]]],
9721013 dataTypes : Seq [DataType ],
9731014 internalRowComparableWrapperFactory : InternalRow => InternalRowComparableWrapper
9741015 ): InternalRowComparableWrapper = {
975- val partitionVals = row.toSeq(dataTypes)
976- val reducedRow = partitionVals .zip(reducers).map{
1016+ val partitionKeys = row.toSeq(dataTypes)
1017+ val reducedRow = partitionKeys .zip(reducers).map{
9771018 case (v, Some (reducer : Reducer [Any , Any ])) => reducer.reduce(v)
9781019 case (v, _) => v
9791020 }.toArray
0 commit comments