Changes from 14 commits
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4284,6 +4284,11 @@
"The function is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
]
},
"MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST" : {
"message" : [
"The function <funcName> with DISTINCT and WITHIN GROUP (ORDER BY) is not supported for <inputType> input. Explicitly cast the input to <castType> before passing it to the function argument and ORDER BY expression."
]
},
"WITHIN_GROUP_MISSING" : {
"message" : [
"WITHIN GROUP is required for the function."
@@ -449,9 +449,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
messageParameters = Map("funcName" -> toSQLExpr(w)))

case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
if agg.isDistinct => listAgg.validateDistinctOrderCompatibility()
Contributor:
I don't like this approach, as it's unclear what happens next if we don't fail here. Does the DISTINCT execution path save the order value?

Even if we add comments here, it makes an assumption about the physical execution path that is far away from here.

I still prefer my previous proposal: we can replace the order value expression of ListAgg with a different but order-preserving expression (a certain CAST). It needs to happen before CheckAnalysis, so we can add a new analyzer rule to do it. For the new single-pass analyzer, this rewrite should happen after we fully resolve ListAgg.

Author:

@cloud-fan from what I understand, wouldn't this new analyzer rule only work for a very limited set of types? (I can think of Boolean, Date, and Binary.)

Numeric types would not work (e.g. 2 < 10 but "2" > "10"), which I feel is the main case this PR is meant to fix.
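The reviewer's point can be checked directly; a minimal Python illustration (not Spark code, just a stand-in for the cast semantics) of why a numeric-to-string cast is not order-preserving:

```python
# Numeric order is 2 < 10, but lexicographic string order is "2" > "10",
# so ORDER BY on the cast column would reorder the values.
nums = [2, 10, 1]

numeric_order = sorted(nums)                 # [1, 2, 10]
string_order = sorted(str(n) for n in nums)  # ['1', '10', '2']

# The two orderings disagree: the string sort puts 10 before 2.
assert [int(s) for s in string_order] != numeric_order
```

This is why an order-preserving rewrite rule could not cover numeric types: no plain string cast keeps their ordering.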


case w: WindowExpression =>
WindowResolution.validateResolvedWindowExpression(w)
@@ -79,7 +79,9 @@ class AggregateExpressionResolver(
* `handleOuterAggregateExpression`);
* - Validation:
* 1. [[ListAgg]] is not allowed in DISTINCT aggregates if it contains [[SortOrder]] different
* from its child;
* from its child. However, when [[SQLConf.LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER]] is
* enabled, a mismatch is tolerated if it is solely due to a [[Cast]] whose source type is
* injective when cast to string (see [[ListAgg.validateDistinctOrderCompatibility]]);
* 2. Nested aggregate functions are not allowed;
* 3. Nondeterministic expressions in the subtree of a related aggregate function are not
* allowed;
@@ -113,26 +115,25 @@
}
}

private def validateResolvedAggregateExpression(aggregateExpression: AggregateExpression): Unit =
private def validateResolvedAggregateExpression(
aggregateExpression: AggregateExpression): Unit = {
aggregateExpression match {
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
throwFunctionAndOrderExpressionMismatchError(listAgg)
if agg.isDistinct => listAgg.validateDistinctOrderCompatibility()
Contributor:

Sorry for ping-ponging, but I've just realized that this approach is not correct. The logic here should follow the rule "if we go to a non-default branch, that means we found an error and must throw". Currently, we can successfully execute validateDistinctOrderCompatibility() and skip the general check from case _ =>. The same is applicable to CheckAnalysis as well.

Not sure how to better structure the code here. It's probably okay to have a method with logic very similar to validateDistinctOrderCompatibility but returning a bool indicating whether we should throw. But it's still code duplication...

Open to suggestions :)

Author (@helioshe4, Feb 19, 2026):

Hm, I think a slightly less duplicated solution would be to move the general check from case _ to outside the match in AggregateExpressionResolver.scala so that it gets executed regardless. This is actually what's happening in CheckAnalysis (since ListAgg can't be matched to anything else and will break out of the match if no errors are thrown), so no changes should be required in CheckAnalysis.

Contributor:

Here it'll work now, but I believe it's not good in the long run. It'll be too easy for someone later to add some validation that should run for all AggregateExpressions, and it will feel natural to add it at the end without reading all of the existing cases (you can see how many there are in CheckAnalysis). And all their tests will pass, because they won't use such a specific function as listagg for testing.

A codebase like Spark has a lot of contributors, and we should make the code as error-proof as possible for future generations.

Author:

Sounds good. I've implemented it similarly to your original suggestion. We have a boolean ListAgg.hasDistinctOrderIncompatibility that indicates whether an illegal mismatch has occurred.

If it has, then we call listAgg.throwDistinctOrderError(), which chooses the correct error to throw depending on the type of mismatch.

If it hasn't, then we continue matching the next cases.

There's a bit of duplication on the orderMismatchCastSafety call, but it's matching on different results.
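The agreed structure can be sketched in a few lines; this is a Python sketch of the flow only (the names mirror hasDistinctOrderIncompatibility and throwDistinctOrderError from the comment above; none of this is the actual Spark API). The point is that the non-default branch is guarded by a pure boolean, so it only matches when an error must be thrown, and the general checks can never be skipped by accident:

```python
# Illustrative sketch, not Spark code: a pure boolean predicate decides
# whether the error branch matches, and that branch always throws.
class ListAggStub:
    def __init__(self, incompatible: bool, unsafe_cast: bool):
        # hasDistinctOrderIncompatibility: pure check, no side effects.
        self.has_distinct_order_incompatibility = incompatible
        self.unsafe_cast = unsafe_cast

    def throw_distinct_order_error(self):
        # Chooses the error depending on the kind of mismatch.
        if self.unsafe_cast:
            raise ValueError("MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST")
        raise ValueError("MISMATCH_WITH_DISTINCT_INPUT")

def validate(agg_is_distinct: bool, list_agg: ListAggStub) -> str:
    # Non-default branch: matches only when there is definitely an error.
    if agg_is_distinct and list_agg.has_distinct_order_incompatibility:
        list_agg.throw_distinct_order_error()
    # General checks (nested aggregates, nondeterminism, ...) still run
    # for every expression that falls through.
    return "general checks ran"
```

With this shape, a compatible ListAgg falls through to the general checks instead of short-circuiting the match.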

case _ =>
if (expressionResolutionContextStack.peek().hasAggregateExpressions) {
throwNestedAggregateFunction(aggregateExpression)
}

aggregateExpression.aggregateFunction.children.foreach { child =>
if (!child.deterministic) {
throwAggregateFunctionWithNondeterministicExpression(
aggregateExpression,
child
)
}
}
}

if (expressionResolutionContextStack.peek().hasAggregateExpressions) {
throwNestedAggregateFunction(aggregateExpression)
}

aggregateExpression.aggregateFunction.children.foreach { child =>
if (!child.deterministic) {
throwAggregateFunctionWithNondeterministicExpression(aggregateExpression, child)
}
}
}

/**
* If the [[AggregateExpression]] has outer references in its subtree, we need to handle it in a
* special way. The whole process is explained in the [[SubqueryScope]] scaladoc, but in short
@@ -212,14 +213,6 @@
}
}

private def throwFunctionAndOrderExpressionMismatchError(listAgg: ListAgg) = {
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
listAgg.prettyName,
listAgg.child,
listAgg.orderExpressions
)
}

private def throwNestedAggregateFunction(aggregateExpression: AggregateExpression): Nothing = {
throw new AnalysisException(
errorClass = "NESTED_AGGREGATE_FUNCTION",
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtil
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -564,6 +565,126 @@ case class ListAgg(
false
}

/**
* Validates that the ordering expression is compatible with DISTINCT deduplication.
*
* When LISTAGG(DISTINCT col) WITHIN GROUP (ORDER BY col) is used on a non-string column,
* the child is implicitly cast to string (with UTF8_BINARY collation). The DISTINCT rewrite
* (see [[RewriteDistinctAggregates]]) uses GROUP BY on both the original and cast columns,
* so the cast must preserve equality semantics: values that are GROUP BY-equal must cast to
* equal strings, and vice versa. Types like Float/Double violate this because IEEE 754
* negative zero (-0.0) and positive zero (0.0) are equal but produce different strings.
*
* This method is a no-op when the order expression matches the child (i.e.,
* [[needSaveOrderValue]] is false). Otherwise, the behavior depends on the
* [[SQLConf.LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER]] config:
* - If enabled, delegates to [[orderMismatchCastSafety]] to determine whether the
* mismatch is due to a safe cast, an unsafe cast, or not a cast at all.
* - If disabled, rejects any mismatch.
*
* @throws AnalysisException if the ordering is incompatible with DISTINCT
*/
def validateDistinctOrderCompatibility(): Unit = {
if (needSaveOrderValue) {
if (SQLConf.get.listaggAllowDistinctCastWithOrder) {
orderMismatchCastSafety match {
case CastSafetyResult.SafeCast => // safe cast, allow
case CastSafetyResult.UnsafeCast(inputType, castType) =>
throwFunctionAndOrderExpressionUnsafeCastError(inputType, castType)
case CastSafetyResult.NotACast =>
throwFunctionAndOrderExpressionMismatchError()
}
} else {
throwFunctionAndOrderExpressionMismatchError()
}
}
}

private def throwFunctionAndOrderExpressionMismatchError() = {
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
prettyName, child, orderExpressions)
}

private def throwFunctionAndOrderExpressionUnsafeCastError(
inputType: DataType, castType: DataType) = {
throw QueryCompilationErrors.functionAndOrderExpressionUnsafeCastError(
prettyName, inputType, castType)
}

/**
* Classifies the order-expression mismatch as a safe cast, unsafe cast, or not a cast.
*
* @see [[validateDistinctOrderCompatibility]] for the full invariant this enforces
*/
private def orderMismatchCastSafety: CastSafetyResult = {
if (orderExpressions.size != 1) return CastSafetyResult.NotACast
child match {
case Cast(castChild, castType, _, _)
if orderExpressions.head.child.semanticEquals(castChild) =>
if (isCastSafeForDistinct(castChild.dataType) &&
isCastTargetSafeForDistinct(castType)) {
CastSafetyResult.SafeCast
} else {
CastSafetyResult.UnsafeCast(castChild.dataType, castType)
}
case _ => CastSafetyResult.NotACast
}
}

/**
* Returns true if casting `dt` to string/binary is injective for DISTINCT deduplication.
*
* @see [[validateDistinctOrderCompatibility]]
*/
private def isCastSafeForDistinct(dt: DataType): Boolean = dt match {
case _: IntegerType | LongType | ShortType | ByteType => true
case _: DecimalType => true
case _: DateType | TimestampNTZType => true
case _: TimeType => true
case _: CalendarIntervalType => true
case _: YearMonthIntervalType => true
case _: DayTimeIntervalType => true
case BooleanType => true
case BinaryType => true
case st: StringType => st.isUTF8BinaryCollation
case _: DoubleType | FloatType => false
// During DST fall-back, two distinct UTC epochs can format to the same local time string
// because the default format omits the timezone offset. TimestampNTZType is safe because
// it involves no time zone conversion.
case _: TimestampType => false
case _ => false
}
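The two unsafe cases above (Float/Double and TimestampType) can be reproduced outside Spark; a minimal Python illustration, using Python's float and datetime semantics as stand-ins for Spark's cast-to-string behavior:

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

# Float/Double: -0.0 and 0.0 compare equal (as under GROUP BY semantics)
# but cast to different strings, so Double -> STRING is not injective.
assert -0.0 == 0.0
assert str(-0.0) != str(0.0)  # '-0.0' vs '0.0'

# TimestampType: during the DST fall-back, two distinct instants format to
# the same local time string when the offset is omitted. America/New_York
# falls back on 2024-11-03, so 01:30 EDT and 01:30 EST are both "01:30:00".
tz = ZoneInfo("America/New_York")
a = datetime(2024, 11, 3, 5, 30, tzinfo=timezone.utc)  # 01:30 EDT
b = datetime(2024, 11, 3, 6, 30, tzinfo=timezone.utc)  # 01:30 EST
assert a != b
fmt = "%Y-%m-%d %H:%M:%S"
assert a.astimezone(tz).strftime(fmt) == b.astimezone(tz).strftime(fmt)
```

In both cases the cast maps distinct (or equal) values the wrong way round, which is exactly what breaks the GROUP BY-based DISTINCT rewrite.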

/**
* Returns true if the target type's equality semantics are safe for DISTINCT deduplication
* (i.e., UTF8_BINARY collation or BinaryType).
*
* @see [[validateDistinctOrderCompatibility]]
*/
private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
case st: StringType => st.isUTF8BinaryCollation
case BinaryType => true
case _ => false
}
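The target-type restriction can also be illustrated in plain Python, with str.lower as a rough stand-in for a UTF8_LCASE-collated comparison (an analogy, not Spark's actual collation implementation):

```python
# A case-insensitive collation makes binary-distinct strings compare equal,
# so casting into such a type loses distinctions the GROUP BY relies on.
def utf8_lcase_key(s: str) -> str:
    # Stand-in for comparing under a lowercase collation.
    return s.lower()

assert "ABC" != "abc"                                   # binary: distinct
assert utf8_lcase_key("ABC") == utf8_lcase_key("abc")   # collated: equal

# DISTINCT over the collated cast keeps one value, while DISTINCT over the
# original binary strings keeps two -- the counts disagree.
assert len({"ABC", "abc"}) == 2
assert len({utf8_lcase_key(s) for s in ("ABC", "abc")}) == 1
```

This mirrors the UTF8_LCASE test cases below, where the analyzer rejects the cast with MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST.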

/**
* Result of checking whether a LISTAGG(DISTINCT) order-expression mismatch
* is caused by a cast and whether that cast is safe for deduplication.
*/
private sealed trait CastSafetyResult

private object CastSafetyResult {
/** The mismatch is not due to a cast at all. */
case object NotACast extends CastSafetyResult

/** The mismatch is due to a cast that is safe for DISTINCT. */
case object SafeCast extends CastSafetyResult

/** The mismatch is due to a cast that is unsafe for DISTINCT. */
case class UnsafeCast(
inputType: DataType,
castType: DataType) extends CastSafetyResult
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(
child = newChildren.head,
@@ -1116,6 +1116,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"orderingExpr" -> orderExpr.map(order => toSQLExpr(order.child)).mkString(", ")))
}

def functionAndOrderExpressionUnsafeCastError(
functionName: String,
inputType: DataType,
castType: DataType): Throwable = {
new AnalysisException(
errorClass = "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
messageParameters = Map(
"funcName" -> toSQLId(functionName),
"inputType" -> toSQLType(inputType),
"castType" -> toSQLType(castType)))
}

def wrongCommandForObjectTypeError(
operation: String,
requiredType: String,
@@ -6889,6 +6889,17 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)

val LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER =
buildConf("spark.sql.listagg.allowDistinctCastWithOrder.enabled")
.internal()
.doc("When true, LISTAGG(DISTINCT expr) WITHIN GROUP (ORDER BY expr) is allowed on " +
"non-string expr when the implicit cast to string preserves equality (e.g., integer, " +
"decimal, date). When false, the function argument and ORDER BY expression must have " +
"the exact same type, which requires explicit casts.")
.version("4.2.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
@@ -8115,6 +8126,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def isTimeTypeEnabled: Boolean = getConf(SQLConf.TIME_TYPE_ENABLED)

def listaggAllowDistinctCastWithOrder: Boolean = getConf(LISTAGG_ALLOW_DISTINCT_CAST_WITH_ORDER)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -118,3 +118,72 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"orderingExpr" : "\"collate(c1, utf8_binary)\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as string), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS STRING), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING COLLATE UTF8_LCASE)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"STRING COLLATE UTF8_LCASE\"",
"funcName" : "`listagg`",
"inputType" : "\"STRING\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING)) WITHIN GROUP (ORDER BY col) FROM VALUES (X'414243'), (X'616263'), (X'414243') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as string), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS STRING), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS STRING COLLATE UTF8_LCASE)) WITHIN GROUP (ORDER BY col) FROM VALUES (X'414243'), (X'616263'), (X'414243') AS t(col)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"STRING COLLATE UTF8_LCASE\"",
"funcName" : "`listagg`",
"inputType" : "\"BINARY\""
}
}


-- !query
SELECT listagg(DISTINCT CAST(col AS BINARY)) WITHIN GROUP (ORDER BY col) FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col)
-- !query analysis
Aggregate [listagg(distinct cast(col#x as binary), null, col#x ASC NULLS FIRST, 0, 0) AS listagg(DISTINCT CAST(col AS BINARY), NULL) WITHIN GROUP (ORDER BY col ASC NULLS FIRST)#x]
+- SubqueryAlias t
+- LocalRelation [col#x]


-- !query
SELECT listagg(DISTINCT CAST(col AS BINARY)) WITHIN GROUP (ORDER BY col) FROM (SELECT col COLLATE UTF8_LCASE AS col FROM VALUES ('ABC'), ('abc'), ('ABC') AS t(col))
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT_UNSAFE_CAST",
"sqlState" : "42K0K",
"messageParameters" : {
"castType" : "\"BINARY\"",
"funcName" : "`listagg`",
"inputType" : "\"STRING COLLATE UTF8_LCASE\""
}
}