Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -4284,6 +4284,11 @@
"The function is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
]
},
"MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST" : {
"message" : [
"The function <funcName> with DISTINCT and WITHIN GROUP (ORDER BY) is not supported for <inputType> input. Explicitly cast the input to <castType> before passing it to both the function argument and ORDER BY expression."
]
},
"WITHIN_GROUP_MISSING" : {
"message" : [
"WITHIN GROUP is required for the function."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,8 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
messageParameters = Map("funcName" -> toSQLExpr(w)))

case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
if agg.isDistinct && listAgg.hasDistinctOrderIncompatibility =>
listAgg.throwDistinctOrderError()

case w: WindowExpression =>
WindowResolution.validateResolvedWindowExpression(w)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ class AggregateExpressionResolver(
* `handleOuterAggregateExpression`);
* - Validation:
* 1. [[ListAgg]] is not allowed in DISTINCT aggregates if it contains [[SortOrder]] different
* from its child;
* from its child. However, when [[SQLConf.LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER]] is
* enabled, a mismatch is tolerated if it is solely due to a [[Cast]] whose source type is
* injective when cast to string (see [[ListAgg.hasDistinctOrderIncompatibility]]);
* 2. Nested aggregate functions are not allowed;
* 3. Nondeterministic expressions in the subtree of a related aggregate function are not
* allowed;
Expand Down Expand Up @@ -116,8 +118,8 @@ class AggregateExpressionResolver(
private def validateResolvedAggregateExpression(aggregateExpression: AggregateExpression): Unit =
aggregateExpression match {
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
throwFunctionAndOrderExpressionMismatchError(listAgg)
if agg.isDistinct && listAgg.hasDistinctOrderIncompatibility =>
listAgg.throwDistinctOrderError()
case _ =>
if (expressionResolutionContextStack.peek().hasAggregateExpressions) {
throwNestedAggregateFunction(aggregateExpression)
Expand Down Expand Up @@ -212,14 +214,6 @@ class AggregateExpressionResolver(
}
}

private def throwFunctionAndOrderExpressionMismatchError(listAgg: ListAgg) = {
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
listAgg.prettyName,
listAgg.child,
listAgg.orderExpressions
)
}

private def throwNestedAggregateFunction(aggregateExpression: AggregateExpression): Nothing = {
throw new AnalysisException(
errorClass = "NESTED_AGGREGATE_FUNCTION",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, Growable}
import scala.util.{Left, Right}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
Expand All @@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtil
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
Expand Down Expand Up @@ -564,6 +566,140 @@ case class ListAgg(
false
}

/**
* Returns true if the ordering expression is incompatible with DISTINCT deduplication.
*
* When LISTAGG(DISTINCT col) WITHIN GROUP (ORDER BY col) is used on a non-string column,
* the child is implicitly cast to string (with UTF8_BINARY collation). The DISTINCT rewrite
* (see [[RewriteDistinctAggregates]]) uses GROUP BY on both the original and cast columns,
* so the cast must preserve equality semantics: values that are GROUP BY-equal must cast to
* equal strings, and vice versa. Types like Float/Double violate this because IEEE 754
* negative zero (-0.0) and positive zero (0.0) are equal but produce different strings.
*
* Returns false when the order expression matches the child (i.e., [[needSaveOrderValue]]
* is false). Otherwise, the behavior depends on the
* [[SQLConf.LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER]] config:
* - If enabled, delegates to [[orderMismatchCastSafety]] to determine whether the
* mismatch is due to a safe cast, an unsafe cast, or not a cast at all.
* - If disabled, any mismatch is considered incompatible.
*
* @return true if an incompatibility exists, false if the ordering is safe
* @see [[throwDistinctOrderError]] to throw the appropriate error when this returns true
*/
def hasDistinctOrderIncompatibility: Boolean = {
needSaveOrderValue && {
if (SQLConf.get.listaggAllowDistinctCastWithOrder) {
orderMismatchCastSafety match {
case CastSafetyResult.SafeCast => false
case _ => true
}
} else {
true
}
}
}

def throwDistinctOrderError(): Nothing = {
if (SQLConf.get.listaggAllowDistinctCastWithOrder) {
orderMismatchCastSafety match {
case CastSafetyResult.UnsafeCast(inputType, castType) =>
throwFunctionAndOrderExpressionUnsafeCastError(inputType, castType)
case CastSafetyResult.NotACast =>
throwFunctionAndOrderExpressionMismatchError()
case CastSafetyResult.SafeCast =>
throw SparkException.internalError(
"ListAgg.throwDistinctOrderError should not be called when the cast is safe")
}
} else {
throwFunctionAndOrderExpressionMismatchError()
}
}

private def throwFunctionAndOrderExpressionMismatchError() = {
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
prettyName, child, orderExpressions)
}

private def throwFunctionAndOrderExpressionUnsafeCastError(
inputType: DataType, castType: DataType) = {
throw QueryCompilationErrors.functionAndOrderExpressionUnsafeCastError(
prettyName, inputType, castType)
}

/**
* Classifies the order-expression mismatch as a safe cast, unsafe cast, or not a cast.
*
* @see [[hasDistinctOrderIncompatibility]] for the full invariant this enforces
*/
private def orderMismatchCastSafety: CastSafetyResult = {
if (orderExpressions.size != 1) return CastSafetyResult.NotACast
child match {
case Cast(castChild, castType, _, _)
if orderExpressions.head.child.semanticEquals(castChild) =>
if (isCastSafeForDistinct(castChild.dataType) &&
isCastTargetSafeForDistinct(castType)) {
CastSafetyResult.SafeCast
} else {
CastSafetyResult.UnsafeCast(castChild.dataType, castType)
}
case _ => CastSafetyResult.NotACast
}
}

/**
* Returns true if casting `dt` to string/binary is injective for DISTINCT deduplication.
*
* @see [[hasDistinctOrderIncompatibility]]
*/
private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
case _: IntegerType | LongType | ShortType | ByteType => true
case _: DecimalType => true
case _: DateType | TimestampNTZType => true
case _: TimeType => true
case _: CalendarIntervalType => true
case _: YearMonthIntervalType => true
case _: DayTimeIntervalType => true
case BooleanType => true
case BinaryType => true
case st: StringType => st.isUTF8BinaryCollation
case _: DoubleType | FloatType => false
// During DST fall-back, two distinct UTC epochs can format to the same local time string
// because the default format omits the timezone offset. TimestampNTZType is safe (uses UTC).
case _: TimestampType => false
case _ => false
}

/**
* Returns true if the target type's equality semantics are safe for DISTINCT deduplication
* (i.e., UTF8_BINARY collation or BinaryType).
*
* @see [[hasDistinctOrderIncompatibility]]
*/
private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
case st: StringType => st.isUTF8BinaryCollation
case BinaryType => true
case _ => false
}

/**
* Result of checking whether a LISTAGG(DISTINCT) order-expression mismatch
* is caused by a cast and whether that cast is safe for deduplication.
*/
private sealed trait CastSafetyResult

private object CastSafetyResult {
/** The mismatch is not due to a cast at all. */
case object NotACast extends CastSafetyResult

/** The mismatch is due to a cast that is safe for DISTINCT. */
case object SafeCast extends CastSafetyResult

/** The mismatch is due to a cast that is unsafe for DISTINCT. */
case class UnsafeCast(
inputType: DataType,
castType: DataType) extends CastSafetyResult
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(
child = newChildren.head,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"orderingExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", ")))
}

def functionAndOrderExpressionUnsafeCastError(
functionName: String,
inputType: DataType,
castType: DataType): Throwable = {
new AnalysisException(
errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
messageParameters = Map(
"funcName" -> toSQLId(functionName),
"inputType" -> toSQLType(inputType),
"castType" -> toSQLType(castType)))
}

def wrongCommandForObjectTypeError(
operation: String,
requiredType: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6889,6 +6889,17 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)

val LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER =
buildConf("spark.sql.listagg.allowDistinctCastWithOrder.enabled")
.internal()
.doc("When true, LISTAGG(DISTINCT expr) WITHIN GROUP (ORDER BY expr) is allowed on " +
"non-string expr when the implicit cast to string preserves equality (e.g., integer, " +
"decimal, date). When false, the function argument and ORDER BY expression must have " +
"the exact same type, which requires explicit casts.")
.version("4.2.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
Expand Down Expand Up @@ -8115,6 +8126,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def isTimeTypeEnabled: Boolean = getConf(SQLConf.TIME_TYPE_ENABLED)

def listaggAllowDistinctCastWithOrder: Boolean = getConf(LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,72 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"orderingExpr" : "\"collate(c1, utf8_binary)\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as string), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS STRING), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING COLLATE UTF8_LCASE)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"STRING COLLATE UTF8_LCASE\"",
"funcName" : "`listagg`",
"inputType" : "\"STRING\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING)) WITHIN GROUP (ORDER BY col) FROM VALUES (X'414243'), (X'616263'), (X'414243') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as string), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS STRING), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING COLLATE UTF8_LCASE)) WITHIN GROUP (ORDER BY col) FROM VALUES (X'414243'), (X'616263'), (X'414243') AS t(col)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"STRING COLLATE UTF8_LCASE\"",
"funcName" : "`listagg`",
"inputType" : "\"BINARY\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS BINARY)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as binary), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS BINARY), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS BINARY)) WITHIN GROUP (ORDER BY col) FROM (SELECT col COLLATE UTF8_LCASE AS col FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col))
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"BINARY\"",
"funcName" : "`listagg`",
"inputType" : "\"STRING COLLATE UTF8_LCASE\""
}
}
Loading