core/src/main/scala/org/apache/spark/Partitioner.scala (2 changes: 1 addition & 1 deletion)
@@ -142,7 +142,7 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
 /**
  * A [[org.apache.spark.Partitioner]] that partitions all records using partition value map.
  * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated
- * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition
+ * by [[org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning]], used to partition
  * the other side of a join to make sure records with same partition value are in the same
  * partition.
  */
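
As a rough illustration of the scaladoc above (the partitioner class itself is folded out of this hunk), a partitioner driven by such a value map could look like the following sketch. `ValueMapPartitioner` is a hypothetical name, not the actual Spark class:

import org.apache.spark.Partitioner

// Sketch: route each record to the partition id recorded for its partition value.
class ValueMapPartitioner(valueMap: Map[Any, Int]) extends Partitioner {
  override def numPartitions: Int = valueMap.values.max + 1
  override def getPartition(key: Any): Int = valueMap(key)
}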
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.physical

+import java.util.Objects
+
import scala.annotation.tailrec
import scala.collection.mutable

@@ -346,43 +348,113 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
}

 /**
- * Represents a partitioning where rows are split across partitions based on transforms defined
- * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
- * ascending order, after evaluated by the transforms in `expressions`, for each input partition.
- * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
- * mapping), and each row in `partitionValues` must be unique.
+ * Represents a partitioning where rows are split across partitions based on transforms defined by
+ * `expressions`.
+ *
+ * == Partition Keys ==
+ * This partitioning has two sets of partition keys:
+ *
+ * - `partitionKeys`: The current partition key for each partition, in ascending order. May
+ *   contain duplicates when first created from a data source, but becomes unique after grouping.
+ *
+ * - `originalPartitionKeys`: The original partition keys from the data source, in ascending
+ *   order. Always preserves the original values, even after grouping, and is used to track the
+ *   original distribution for optimization purposes.
+ *
+ * == Grouping State ==
+ * A KeyedPartitioning can be in one of two states:
+ *
+ * - '''Ungrouped''' (`isGrouped == false`): `partitionKeys` may contain duplicates, i.e. multiple
+ *   input partitions share the same key. This is the initial state when created from a data
+ *   source.
+ *
+ * - '''Grouped''' (`isGrouped == true`): `partitionKeys` contains only unique values, so each
+ *   partition has a distinct key. This state is reached by applying `GroupPartitionsExec`, which
+ *   coalesces partitions that share the same key.
  *
- * The `originalPartitionValues`, on the other hand, are partition values from the original input
- * splits returned by data sources. It may contain duplicated values.
+ * == Example ==
+ * Consider a data source with partition transform `[years(ts_col)]` and 4 input splits:
  *
- * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
- * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions`
- * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which
- * represents 3 input partitions with distinct partition values. All rows in each partition have
- * the same value for column `ts_col` (which is of timestamp type), after being applied by the
- * `years` transform. This is generated after combining the two splits with partition value `2`
- * into a single Spark partition.
+ * '''Before GroupPartitionsExec''' (ungrouped):
+ * {{{
+ *   expressions:           [years(ts_col)]
+ *   partitionKeys:         [0, 1, 2, 2]   // partitions 2 and 3 have the same key
+ *   originalPartitionKeys: [0, 1, 2, 2]
+ *   numPartitions:         4
+ *   isGrouped:             false
+ * }}}
  *
- * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
- * which is calculated from the original input splits.
+ * '''After GroupPartitionsExec''' (grouped):
+ * {{{
+ *   expressions:           [years(ts_col)]
+ *   partitionKeys:         [0, 1, 2]      // duplicates removed, partitions coalesced
+ *   originalPartitionKeys: [0, 1, 2, 2]   // unchanged, preserves the original distribution
+ *   numPartitions:         3
+ *   isGrouped:             true
+ * }}}
  *
- * @param expressions partition expressions for the partitioning.
- * @param numPartitions the number of partitions
- * @param partitionValues the values for the final cluster keys (that is, after applying grouping
- *                        on the input splits according to `expressions`) of the distribution,
- *                        must be in ascending order, and must NOT contain duplicated values.
- * @param originalPartitionValues the original input partition values before any grouping has been
- *                                applied, must be in ascending order, and may contain duplicated
- *                                values
+ * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`).
+ * @param partitionKeys Current partition keys, one per partition, in ascending order.
+ *                      May contain duplicates before grouping.
+ * @param originalPartitionKeys Original partition keys from the data source, in ascending order.
+ *                              Preserves the initial distribution even after grouping.
  */
-case class KeyGroupedPartitioning(
+case class KeyedPartitioning(
     expressions: Seq[Expression],
-    numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
+    partitionKeys: Seq[InternalRow],
+    originalPartitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable {
+  override val numPartitions = partitionKeys.length

+  override def children: Seq[Expression] = expressions
+  override def nullable: Boolean = false
+  override def dataType: DataType = IntegerType
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): KeyedPartitioning =
+    copy(expressions = newChildren)
+
+  @transient private lazy val dataTypes: Seq[DataType] = expressions.map(_.dataType)
+
+  @transient private lazy val comparableWrapperFactory =
+    InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
+
+  @transient private lazy val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes)
+
+  @transient lazy val isGrouped: Boolean = {
+    partitionKeys.map(comparableWrapperFactory).distinct.size == partitionKeys.size
+  }
+
+  def toGrouped: KeyedPartitioning = {
+    val groupedPartitions = partitionKeys
+      .map(comparableWrapperFactory)
+      .distinct
+      .map(_.row)
+      .sorted(rowOrdering)
+
+    KeyedPartitioning(expressions, groupedPartitions, originalPartitionKeys)
+  }
+
+  def projectAndGroup(positions: Seq[Int]): KeyedPartitioning = {
+    val projectedExpressions = positions.map(expressions)
+    val projectedDataTypes = projectedExpressions.map(_.dataType)
+    val projectedPartitionKeys = partitionKeys.map(
+      KeyedPartitioning.projectKey(_, positions, projectedDataTypes)
+    )
+    val projectedOriginalPartitionKeys = originalPartitionKeys.map(
+      KeyedPartitioning.projectKey(_, positions, projectedDataTypes)
+    )
+    val internalRowComparableWrapperFactory =
+      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes)
+    val distinctPartitionKeys = projectedPartitionKeys
+      .map(internalRowComparableWrapperFactory)
+      .distinct
+      .map(_.row)
+
+    copy(expressions = projectedExpressions, partitionKeys = distinctPartitionKeys,
+      originalPartitionKeys = projectedOriginalPartitionKeys)
+  }

   override def satisfies0(required: Distribution): Boolean = {
-    super.satisfies0(required) || {
+    super.satisfies0(required) || isGrouped && {
       required match {
         case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
           if (requireAllClusterKeys) {
@@ -395,9 +467,9 @@ case class KeyGroupedPartitioning(
 
           if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
             // check that join keys (required clustering keys)
-            // overlap with partition keys (KeyGroupedPartitioning attributes)
+            // overlap with partition keys (KeyedPartitioning attributes)
             requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
-              expressions.forall(_.collectLeaves().size == 1)
+              expressions.forall(_.collectLeaves().size == 1)
           } else {
             attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
           }
@@ -416,63 +488,37 @@ case class KeyGroupedPartitioning(
     val result = KeyGroupedShuffleSpec(this, distribution)
     if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
       // If allowing join keys to be subset of clustering keys, we should create a new
-      // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
+      // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as
       // the returned shuffle spec.
       val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
-      val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
-        partitionValues, originalPartitionValues)
+      val projectedPartitioning = projectAndGroup(joinKeyPositions)
       result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
     } else {
       result
     }
   }

-  lazy val uniquePartitionValues: Seq[InternalRow] = {
-    val internalRowComparableFactory =
-      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-        expressions.map(_.dataType))
-    partitionValues
-      .map(internalRowComparableFactory)
-      .distinct
-      .map(_.row)
-  }
-
-  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
-    copy(expressions = newChildren)
-}
+  override def equals(that: Any): Boolean = that match {
+    case k: KeyedPartitioning if this.expressions == k.expressions =>
+      def keysEqual(keys1: Seq[InternalRow], keys2: Seq[InternalRow]): Boolean = {
+        keys1.size == keys2.size && keys1.zip(keys2).forall { case (l, r) =>
+          comparableWrapperFactory(l).equals(comparableWrapperFactory(r))
+        }
+      }
 
-object KeyGroupedPartitioning {
-  def apply(
-      expressions: Seq[Expression],
-      projectionPositions: Seq[Int],
-      partitionValues: Seq[InternalRow],
-      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    val projectedExpressions = projectionPositions.map(expressions(_))
-    val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
-    val projectedOriginalPartitionValues =
-      originalPartitionValues.map(project(expressions, projectionPositions, _))
-    val internalRowComparableFactory =
-      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-        projectedExpressions.map(_.dataType))
-
-    val finalPartitionValues = projectedPartitionValues
-      .map(internalRowComparableFactory)
-      .distinct
-      .map(_.row)
+      keysEqual(partitionKeys, k.partitionKeys) &&
+        keysEqual(originalPartitionKeys, k.originalPartitionKeys)
 
-    KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
-      finalPartitionValues, projectedOriginalPartitionValues)
+    case _ => false
   }
 
-  def project(
-      expressions: Seq[Expression],
-      positions: Seq[Int],
-      input: InternalRow): InternalRow = {
-    val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
-      .toArray
-    new GenericInternalRow(projectedValues)
+  override def hashCode(): Int = {
+    Objects.hash(expressions, partitionKeys.map(comparableWrapperFactory),
+      originalPartitionKeys.map(comparableWrapperFactory))
   }
 }
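
To make the grouped/ungrouped life cycle concrete, here is a small self-contained sketch of the same semantics on plain Int keys. It is illustrative only: the real class compares InternalRows through InternalRowComparableWrapper, and the names below are made up for the example.

object GroupingSketch extends App {
  // Ungrouped state: one key per input split, duplicates allowed.
  val partitionKeys = Seq(0, 1, 2, 2)
  val isGrouped = partitionKeys.distinct.size == partitionKeys.size
  assert(!isGrouped) // splits at positions 2 and 3 share key 2

  // What toGrouped does: keep the distinct keys, in ascending order.
  val grouped = partitionKeys.distinct.sorted
  assert(grouped == Seq(0, 1, 2)) // numPartitions drops from 4 to 3

  // originalPartitionKeys stays untouched, preserving the input distribution.
  val originalPartitionKeys = partitionKeys
}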

+object KeyedPartitioning {
   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
     def isSupportedTransform(transform: TransformExpression): Boolean = {
       transform.children.size == 1 && isReference(transform.children.head)
@@ -491,6 +537,28 @@ object KeyGroupedPartitioning {
       case _ => false
     }
   }

+  def projectKey(
+      key: InternalRow,
+      positions: Seq[Int],
+      dataTypes: Seq[DataType]): InternalRow = {
+    val projectedKey = positions.zip(dataTypes).map {
+      case (position, dataType) => key.get(position, dataType)
+    }.toArray[Any]
+    new GenericInternalRow(projectedKey)
+  }
+
+  def reduceKey(
+      key: InternalRow,
+      reducers: Seq[Option[Reducer[_, _]]],
+      dataTypes: Seq[DataType]): InternalRow = {
+    val keyValues = key.toSeq(dataTypes)
+    val reducedKey = keyValues.zip(reducers).map{
+      case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
+      case (v, _) => v
+    }.toArray
+    new GenericInternalRow(reducedKey)
+  }
}
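
On plain tuples, the effect of the two helpers above looks like this (an illustrative sketch, not the InternalRow-based API; the bucket example assumes a reducer that maps 32 buckets down to 8):

// projectKey keeps only the requested positions of a composite key.
def projectKey(key: Seq[Any], positions: Seq[Int]): Seq[Any] =
  positions.map(key)

// reduceKey applies an optional reducer per key column.
def reduceKey(key: Seq[Any], reducers: Seq[Option[Any => Any]]): Seq[Any] =
  key.zip(reducers).map {
    case (v, Some(reduce)) => reduce(v)
    case (v, None) => v
  }

assert(projectKey(Seq("2024", 7), Seq(1)) == Seq(7))
val toEightBuckets: Option[Any => Any] = Some(v => v.asInstanceOf[Int] % 8)
assert(reduceKey(Seq(31, "x"), Seq(toEightBuckets, None)) == Seq(7, "x"))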

/**
@@ -827,18 +895,20 @@ case class CoalescedHashShuffleSpec(
}

 /**
- * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
+ * [[ShuffleSpec]] created by [[KeyedPartitioning]].
  *
  * @param partitioning key grouped partitioning
  * @param distribution distribution
  * @param joinKeyPositions position of join keys among cluster keys.
  *                         This is set if joining on a subset of cluster keys is allowed.
  */
 case class KeyGroupedShuffleSpec(
-    partitioning: KeyGroupedPartitioning,
+    partitioning: KeyedPartitioning,
     distribution: ClusteredDistribution,
     joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {

+  assert(partitioning.isGrouped)

   /**
    * A sequence where each element is a set of positions of the partition expression to the cluster
    * keys. For instance, if cluster keys are [a, b, b] and partition expressions are
@@ -878,7 +948,7 @@ case class KeyGroupedShuffleSpec(
             partitioning.expressions.map(_.dataType))
         distribution.clustering.length == otherDistribution.clustering.length &&
           numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
-          partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
+          partitioning.partitionKeys.zip(otherPartitioning.partitionKeys).forall {
             case (left, right) =>
               internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
           }
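
Stated plainly, the check above means two key-grouped sides are compatible only when their partition key sequences match pairwise, position by position. A sketch on plain values (hypothetical keysMatch, standing in for the wrapper-based comparison):

def keysMatch(left: Seq[Any], right: Seq[Any]): Boolean =
  left.size == right.size && left.zip(right).forall { case (l, r) => l == r }

assert(keysMatch(Seq(0, 1, 2), Seq(0, 1, 2)))
assert(!keysMatch(Seq(0, 1, 2), Seq(0, 1, 3))) // third partitions disagree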
@@ -959,21 +1029,20 @@ case class KeyGroupedShuffleSpec(
         te.copy(children = te.children.map(_ => clustering(positionSet.head)))
       case (_, positionSet) => clustering(positionSet.head)
     }
-    KeyGroupedPartitioning(newExpressions,
-      partitioning.numPartitions,
-      partitioning.partitionValues)
+    KeyedPartitioning(newExpressions, partitioning.partitionKeys,
+      partitioning.originalPartitionKeys)
   }
 }

 object KeyGroupedShuffleSpec {
-  def reducePartitionValue(
+  def reducePartitionKey(
       row: InternalRow,
       reducers: Seq[Option[Reducer[_, _]]],
       dataTypes: Seq[DataType],
       internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper
     ): InternalRowComparableWrapper = {
-    val partitionVals = row.toSeq(dataTypes)
-    val reducedRow = partitionVals.zip(reducers).map{
+    val partitionKeys = row.toSeq(dataTypes)
+    val reducedRow = partitionKeys.zip(reducers).map{
       case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
       case (v, _) => v
     }.toArray
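
For context on the Reducer parameter: a reducer collapses several partition key values into one, which is how, say, a `bucket(32, col)` side can be aligned with a `bucket(8, col)` side. A hypothetical instance, assuming the single-method `Reducer[I, O]` interface from `org.apache.spark.sql.connector.catalog.functions`:

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: maps a bucket id in [0, 32) to one in [0, 8).
// Consistent because 32 is a multiple of 8.
case class BucketReducer(divisor: Int) extends Reducer[Int, Int] {
  override def reduce(bucket: Int): Int = bucket % divisor
}

// reducePartitionKey would apply this per key column:
// a key [13] paired with Some(BucketReducer(8)) becomes [5].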
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning
import org.apache.spark.sql.connector.catalog.PartitionInternalRow
import org.apache.spark.sql.types.IntegerType

@@ -61,10 +61,10 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
     // just to mock the data types
     val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType)))
 
-    val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
-    val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
+    val leftPartitioning = KeyedPartitioning(expressions, partitions, partitions)
+    val rightPartitioning = KeyedPartitioning(expressions, partitions, partitions)
     val merged = InternalRowComparableWrapper.mergePartitions(
-      leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions)
+      leftPartitioning.partitionKeys, rightPartitioning.partitionKeys, expressions)
     assert(merged.size == bucketNum)
}

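Conceptually, the merge in the benchmark amounts to a distinct union of the two sides' partition keys. On plain values (illustrative sketch; the real code compares InternalRows via InternalRowComparableWrapper):

val leftKeys = Seq(0, 1, 2)
val rightKeys = Seq(1, 2, 3)
val merged = (leftKeys ++ rightKeys).distinct
assert(merged == Seq(0, 1, 2, 3))

In the benchmark both sides carry identical keys, which is why the merged size equals bucketNum.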