[SPARK-55501][SQL] Fix listagg distinct + within group order by bug #54297
helioshe4 wants to merge 16 commits into apache:master
Conversation
mikhailnik-db
left a comment
Thank you for working on this PR! It will be a useful change for users. I'm a little concerned about the correctness of the solution for each type. Maybe we can whitelist the types we consider safe to cast.
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
  if agg.isDistinct && listAgg.needSaveOrderValue =>
    // Allow when the mismatch is only because child was cast
    val mismatchDueToCast = listAgg.orderExpressions.size == 1 &&
      (listAgg.child match {
        case Cast(castChild, _, _, _) =>
          listAgg.orderExpressions.head.child.semanticEquals(castChild)
        case _ => false
      })
    if (!mismatchDueToCast) {
      throwFunctionAndOrderExpressionMismatchError(listAgg)
    }
For context: the purpose of this check is to prevent problems with the distinct-rewrite framework. Simply put, all aggregate functions with distinct arguments are rewritten by moving those arguments into a GROUP BY.
You can imagine it as
SELECT agg(distinct col)
FROM table
~
SELECT agg(col)
FROM (
SELECT col FROM table
GROUP BY col
)
listagg with distinct treats the argument and the order expression as keys.
SELECT listagg(distinct col) WITHIN GROUP (ORDER BY col')
FROM table
~
SELECT listagg(col) WITHIN GROUP (ORDER BY col')
FROM (SELECT col, col'
FROM table
GROUP BY col, col'
)
Before this change, there was a simple invariant: if col is semantically equal to col', then GROUP BY col, col' is equivalent to GROUP BY col, which is the behavior a user expects.
Now we want to relax this check, assuming that for the column of any type GROUP BY CAST(col AS STRING), col ~ GROUP BY CAST(col AS STRING) ~ GROUP BY col (and the same with CAST(col AS BINARY)). I do not have any counterexamples. It seems a reasonable assumption, but it should be double-checked.
@helioshe4, I'm afraid the only way to prove correctness is to go through all existing types and check the logic of the cast to string and binary. My three main concerns:
- There could be some normalisation, or the absence of it when needed; e.g., floating-point numbers usually have two encodings for zero: 0 and -0. They are equal, but will we normalize them when casting to string or binary?
- Loss of precision or some other information, e.g. when converting timestamps or floating-point numbers.
- Collations. They were created to control the equality relation of strings. Can we change them by casting?
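The first concern can be sketched outside Spark. A minimal Python illustration (not Spark code) of why a string cast may skip the normalization that GROUP BY equality relies on:

```python
# Language-agnostic illustration (in Python) of the floating-point zero
# concern: 0.0 and -0.0 compare equal as doubles, but their string forms
# differ, so a string-cast key could split one group into two.
pos, neg = 0.0, -0.0

equal_as_doubles = (pos == neg)            # True: IEEE 754 signed zeros are equal
equal_as_strings = (str(pos) == str(neg))  # False: "0.0" vs "-0.0"

print(equal_as_doubles, equal_as_strings)
```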
Now we want to relax this check, assuming that for the column of any type GROUP BY CAST(col AS STRING), col ~ GROUP BY CAST(col AS STRING) ~ GROUP BY col (and the same with CAST(col AS BINARY)). I do not have any counterexamples. It seems a reasonable assumption, but it should be double-checked.
@cloud-fan @MaxGekk, maybe you know some counterexamples?
@mikhailnik-db thanks for the detailed explanation!
To address your concerns
- You're right about the normalization of fp numbers.
The implicit cast to STRING doesn't preserve GROUP BY equality for float/double types, because ListAgg.child is not normalized before casting to string (producing both "0.0" and "-0.0"), while GROUP BY keys are normalized.
The root cause is that the implicit cast is applied before DISTINCT deduplication rather than after, so DISTINCT operates on string values (where "-0.0" != "0.0") instead of double values (where -0.0 = 0.0). I feel a cleaner solution would be to normalize col before casting to string (to keep operations between the order expression col and the child col consistent), but this may cause other side effects or go against user expectations.
For now, I've added a whitelist of types where the cast preserves equality semantics, and a specific error message for unsafe types (float, double) explaining why they're rejected.
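The normalize-before-cast idea mentioned above could look roughly like the following Python sketch (an assumption about the approach, not code from this PR):

```python
# Sketch of the normalize-before-cast idea: mapping -0.0 to +0.0 before
# stringifying makes the string cast agree with GROUP BY's normalized
# equality. The helper name is hypothetical.
def normalize_zero(x: float) -> float:
    # IEEE 754 with default rounding: -0.0 + 0.0 == +0.0
    return x + 0.0

print(str(normalize_zero(-0.0)))  # "0.0", same as str(normalize_zero(0.0))
```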
- No loss of precision.
For DecimalType, toPlainString is used (preserving scale/precision), and toString() is used for the other numeric types. Date and interval types are all converted precisely, with no loss.
- Yes, explicit casting with collation could be an issue if the child col's collation isn't the same as the ORDER BY col's collation.
Implicit casting doesn't change collation, but we block explicit casting with collation. I'm taking a conservative approach where we only allow explicit casts FROM StringType with UTF8_BINARY and TO StringType with UTF8_BINARY.
I've updated my PR comment to explain the logic/safety of casting.
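A rough Python analogy (not Spark code) for why collation matters here: a case-insensitive collation like UTF8_LCASE changes which strings count as equal, while a binary collation compares byte-for-byte.

```python
# Hypothetical comparison helpers mimicking the two collation behaviors.
def utf8_binary_eq(a: str, b: str) -> bool:
    return a == b                  # byte-for-byte, like UTF8_BINARY

def utf8_lcase_eq(a: str, b: str) -> bool:
    return a.lower() == b.lower()  # case-insensitive, like UTF8_LCASE

print(utf8_binary_eq("Apple", "apple"))  # distinct under binary collation
print(utf8_lcase_eq("Apple", "apple"))   # equal under case-insensitive collation
```

Casting between collations therefore changes the equality relation itself, which is exactly what the conservative UTF8_BINARY-only restriction avoids.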
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
if agg.isDistinct && listAgg.needSaveOrderValue =>
  // Allow when the mismatch is only because child was cast
  val mismatchDueToCast = listAgg.orderExpressions.size == 1 &&
    (listAgg.child match {
      case Cast(castChild, _, _, _) =>
        listAgg.orderExpressions.head.child.semanticEquals(castChild)
      case _ => false
    })
  if (!mismatchDueToCast) {
    throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
      listAgg.prettyName, listAgg.child, listAgg.orderExpressions)
  }
nit: now there is more logic, so it's worth abstracting it into a method on ListAgg
Refactored as a member function of ListAgg, which returns three possible results to indicate the nature of the column mismatch (1. safe cast, 2. unsafe cast, 3. mismatch not due to casting).
Actually, I think it'd be even better to extract everything, including listAgg.needSaveOrderValue, into a method like validateOrderingForDistinctFunction that throws when needed.
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct =>
listAgg.validateOrderingForDistinctFunction()
Moreover, the logic in this PR is not trivial, so it makes sense to have a feature flag guarding the changes. It will be convenient to put the if (flag) branching inside a ListAgg method.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
case _: DateType | TimestampType | TimestampNTZType => true
case _: TimeType => true
Just to double-check: is a timezone stored in any of these types, and if yes, how is it represented in the string after the cast?
Good point. A Timestamp is internally represented as microseconds since the epoch in UTC, and also carries the timezone information (set at the session level). So when converting to string, the string displays the time according to the local (session) timezone, but the timezone itself is not part of the string.
I believe this causes issues around the daylight-savings fallback (which I validated with a test). E.g., if two timestamps are recorded, one 30 min before and one 30 min after the DST fallback occurs, their string representations would be the same (since the later one is shifted back one hour), but their GROUP BY key values would be different.
If this is the case, we should remove TimestampType entirely from the whitelist, but TimestampNTZType should be safe because it's timezone-agnostic.
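The DST-fallback collision described above can be reproduced outside Spark; a minimal Python sketch using zoneinfo (the specific timestamps are illustrative assumptions, not taken from the PR's test):

```python
# Two distinct UTC instants, 30 min before and 30 min after the
# 2024-11-03 US DST fallback, render to the same local wall-clock string.
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

tz = ZoneInfo("America/New_York")
before = datetime(2024, 11, 3, 5, 30, tzinfo=timezone.utc)  # 01:30 EDT (UTC-4)
after = datetime(2024, 11, 3, 6, 30, tzinfo=timezone.utc)   # 01:30 EST (UTC-5)

s_before = before.astimezone(tz).strftime("%Y-%m-%d %H:%M:%S")
s_after = after.astimezone(tz).strftime("%Y-%m-%d %H:%M:%S")

print(before != after)      # different instants, so different GROUP BY keys
print(s_before == s_after)  # identical local strings after the cast
```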
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
 * @see [[orderMismatchCastSafety]]
 */
private def isCastTargetSafeForDistinct(dt: DataType): Boolean = dt match {
  case st: StringType => st.supportsBinaryEquality
implicit casting doesn't change collation, but we block explicit casting with collation.
I think, at this point, we cannot say whether the child's cast was explicit or implicit. So, if we do this check for both, is it true that an implicit cast always uses UTF8_BINARY as the default collation? Because otherwise, we could accidentally block some implicit casts like int -> string(UTF8_LCASE_COLLATION_ID).
Yes, UTF8_BINARY is the default (and only possible) collation for implicit casts.
In the implicitCast function in TypeCoercion.scala:
(Note the case on L234 is matched instead of L233, since ListAgg defines its inputType as StringTypeWithCollation, which inherits from AbstractStringType.)
And st.defaultConcreteType is a StringType which has collation UTF8_BINARY_COLLATION_ID:
spark/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
Lines 112 to 113 in 0ab4107
/**
 * Determines whether the order mismatch between [[child]] and [[orderExpressions]] is due to
 * a cast, and if so, whether that cast is safe for DISTINCT deduplication.
I think the general theory here is: if ordering key is col and the input expression is transform(col), we don't need to save order-value, if the transformation can preserve the equality.
So a cleaner solution is to add an optimizer rule to match ListAgg, and replace its ordering key with the input expression, if the transformation preserves the equality.
We can still use the current cast check in this PR to determine equality preserving transformations, and leave a TODO to detect more such cases.
I think the general theory here is: if ordering key is col and the input expression is transform(col), we don't need to save order-value, if the transformation can preserve the equality.
So a cleaner solution is to add an optimizer rule to match ListAgg, and replace its ordering key with the input expression, if the transformation preserves the equality.
It won't work out of the box, because even if the transformation preserves equality, it does not necessarily preserve ordering. E.g., int -> string changes the order from numeric to lexicographic.
We can do the opposite: save both col and transform(col), and apply the transformation on the fly during execution.
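The equality-preserving-but-order-breaking point can be shown with a tiny Python illustration (not Spark code):

```python
# int -> string keeps values distinct (equality preserved) but replaces
# numeric order with lexicographic order.
nums = [2, 10, 1]

print(sorted(nums))                    # [1, 2, 10]   numeric order
print(sorted(str(n) for n in nums))    # ['1', '10', '2']  lexicographic order
```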
mikhailnik-db
left a comment
LGTM after resolving comments
case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
  if agg.isDistinct => listAgg.validateDistinctOrderCompatibility()
Sorry for ping-ponging, but I've just realized that this approach is not correct. The logic here should follow the rule "if we go to a non-default branch, that means we found an error and must throw". Currently, we can successfully execute validateDistinctOrderCompatibility() and skip the general check from case _ =>. The same applies to CheckAnalysis as well.
Not sure how to better structure the code here. It's probably okay to have a method with logic very similar to validateDistinctOrderCompatibility but returning a bool indicating whether we should throw. But that's still code duplication...
Open to suggestions :)
Hm, I think a slightly less duplicated solution would be to move the general check from case _ to outside the match in AggregateExpressionResolver.scala, so that it gets executed regardless. This is actually what happens in CheckAnalysis (since ListAgg can't be matched to anything else and will fall out of the match if no errors are thrown), so no changes should be required in CheckAnalysis.
Here it'll work now, but I believe it's not good in the long run. It would be too easy for someone to later add some validation that should run for all AggregateExpressions, and it will be very natural to add it at the end without reading all the existing cases (you can see how many there are in CheckAnalysis). And all their tests would pass, because they wouldn't use such a specific function as listagg for testing.
A codebase like Spark has a lot of contributors, and we should make the code as error-proof as possible for future generations.
Sounds good. I've implemented it similarly to your original suggestion. There is a boolean ListAgg.hasDistinctOrderIncompatibility that indicates whether an illegal mismatch has occurred.
If it has, we call listAgg.throwDistinctOrderError(), which chooses the correct error to throw depending on the type of mismatch.
If it hasn't, we continue matching the next cases.
There's a bit of duplication in the orderMismatchCastSafety call, but it matches on different results.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
if agg.isDistinct => listAgg.validateDistinctOrderCompatibility()
I don't like this approach as it's unclear what happens next if we don't fail here. Does the DISTINCT execution path save the order value?
Even if we add comments here, it's making an assumption of the physical execution path that is far away from here.
I still prefer my previous proposal: we can replace the order value expression of ListAgg to a different but order-preserving expression (certain CAST). It needs to happen before CheckAnalysis, so we can add a new analyzer rule to do it. For the new single-pass analyzer, this rewrite should happen after we fully resolve ListAgg.
@cloud-fan, from what I understand, wouldn't this new analyzer rule only work for very limited types? (I can think of boolean, date, and binary.)
Numeric types would not work (e.g., 2 < 10 but "2" > "10"), which I feel would be the main case fixed by this PR.
What changes were proposed in this pull request?
There is a bug in the listagg expression when using DISTINCT and ORDER BY together, when the ORDER BY column is non-string/binary. ListAgg.child gets cast to a string type, and CheckAnalysis/the resolver then evaluates the child column as not semantically equal to the ORDER BY column. This fails the query (since the analyzer believes the listagg child column is not enough to determine the order and could produce non-deterministic results). This is not the expected behaviour: the ORDER BY column is deterministic, since it is equivalent to the child column before casting.
The fix I'm proposing is to loosen the restriction in the Analyzer/Resolver check. We allow the listagg query to execute with DISTINCT + ORDER BY even if the child column is not semantically equal to the ORDER BY column; we only need to ensure that the child column without the cast is semantically equal to the ORDER BY column and that the cast is safe.
We follow this criterion to determine whether a DataType can be safely cast to StringType (no datatype is implicitly cast to BinaryType, so we can ignore it):
For two values a, b of DataType T: a and b fall into the same GROUP BY group if and only if CAST(a AS STRING) = CAST(b AS STRING).
We only consider the datatypes that can be cast to string (e.g. we ignore complex datatypes like Array, Struct, Map).
The only two DataTypes that don't pass this criterion are DoubleType and FloatType: under GROUP BY, 0.0 = -0.0, but CAST(0.0 AS STRING) = "0.0" while CAST(-0.0 AS STRING) = "-0.0". This is because Double/Float values are normalized before GROUP BY, but not before casting.
Other numeric types are cast using .toString() or toPlainString(), which preserve precision/scale. Datetime/interval types are converted with no loss.
Why are the changes needed?
It's a bug, as explained above.
Does this PR introduce any user-facing change?
Yes. Previous behaviour resulted in error:
Example query:
throws
[INVALID_WITHIN_GROUP_EXPRESSION.MISMATCH_WITH_DISTINCT_INPUT] Invalid function listagg with WITHIN GROUP. The function is invoked with DISTINCT and WITHIN GROUP but expressions "col" and "col" do not match. The WITHIN GROUP ordering expression must be picked from the function inputs. SQLSTATE: 42K0K
I'm proposing that this query (and similar ones) now pass with the result
1, 3, 99, 100
It is a user-facing change compared to the released Spark versions.
How was this patch tested?
Unit tests added to DataFrameAggregateSuite.
Was this patch authored or co-authored using generative AI tooling?
Co-authored.
Generated-by: Claude v2.1.39