Skip to content

Commit d0b1b0b

Browse files
vladimirg-db authored and MaxGekk committed
[SPARK-50990][SQL] Refactor UpCast resolution out of the Analyzer
### What changes were proposed in this pull request?
Refactor `UpCast` resolution out of the `Analyzer`.

### Why are the changes needed?
To reuse this code in the single-pass `Resolver`.

### Does this PR introduce _any_ user-facing change?
No, just a refactoring.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
Copilot.nvim.

Closes #49669 from vladimirg-db/vladimirg-db/refactor-upcast-resolution-out.

Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent 762599c commit d0b1b0b

File tree

2 files changed

+73
-33
lines changed

2 files changed

+73
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -3727,45 +3727,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
37273727
* Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
37283728
*/
37293729
object ResolveUpCast extends Rule[LogicalPlan] {
3730-
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
3731-
val fromStr = from match {
3732-
case l: LambdaVariable => "array element"
3733-
case e => e.sql
3734-
}
3735-
throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath)
3736-
}
3737-
37383730
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
37393731
_.containsPattern(UP_CAST), ruleId) {
37403732
case p if !p.childrenResolved => p
37413733
case p if p.resolved => p
37423734

37433735
case p => p.transformExpressionsWithPruning(_.containsPattern(UP_CAST), ruleId) {
3744-
case u @ UpCast(child, _, _) if !child.resolved => u
3745-
3746-
case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
3747-
throw SparkException.internalError(
3748-
s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target")
3749-
3750-
case UpCast(child, target, walkedTypePath) if target == DecimalType
3751-
&& child.dataType.isInstanceOf[DecimalType] =>
3752-
assert(walkedTypePath.nonEmpty,
3753-
"object DecimalType should only be used inside ExpressionEncoder")
3754-
3755-
// SPARK-31750: if we want to upcast to the general decimal type, and the `child` is
3756-
// already decimal type, we can remove the `Upcast` and accept any precision/scale.
3757-
// This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.
3758-
child
3759-
3760-
case UpCast(child, target: AtomicType, _)
3761-
if conf.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
3762-
child.dataType == StringType =>
3763-
Cast(child, target.asNullable)
3764-
3765-
case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) =>
3766-
fail(child, u.dataType, walkedTypePath)
3767-
3768-
case u @ UpCast(child, _, _) => Cast(child, u.dataType)
3736+
case unresolvedUpCast @ UpCast(child, _, _) if !child.resolved =>
3737+
unresolvedUpCast
3738+
case unresolvedUpCast: UpCast =>
3739+
UpCastResolution.resolve(unresolvedUpCast)
37693740
}
37703741
}
37713742
}
Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.SparkException
21+
import org.apache.spark.sql.catalyst.SQLConfHelper
22+
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, UpCast}
23+
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
24+
import org.apache.spark.sql.errors.QueryCompilationErrors
25+
import org.apache.spark.sql.internal.SQLConf
26+
import org.apache.spark.sql.types.{AtomicType, DataType, DecimalType, StringType}
27+
28+
object UpCastResolution extends SQLConfHelper {
29+
def resolve(unresolvedUpCast: UpCast): Expression = unresolvedUpCast match {
30+
case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
31+
throw SparkException.internalError(
32+
s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target"
33+
)
34+
35+
case UpCast(child, target, walkedTypePath)
36+
if target == DecimalType
37+
&& child.dataType.isInstanceOf[DecimalType] =>
38+
assert(
39+
walkedTypePath.nonEmpty,
40+
"object DecimalType should only be used inside ExpressionEncoder"
41+
)
42+
43+
// SPARK-31750: if we want to upcast to the general decimal type, and the `child` is
44+
// already decimal type, we can remove the `Upcast` and accept any precision/scale.
45+
// This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.
46+
child
47+
48+
case UpCast(child, target: AtomicType, _)
49+
if conf.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
50+
child.dataType == StringType =>
51+
Cast(child, target.asNullable)
52+
53+
case unresolvedUpCast @ UpCast(child, _, walkedTypePath)
54+
if !Cast.canUpCast(child.dataType, unresolvedUpCast.dataType) =>
55+
fail(child, unresolvedUpCast.dataType, walkedTypePath)
56+
57+
case unresolvedUpCast @ UpCast(child, _, _) =>
58+
Cast(child, unresolvedUpCast.dataType)
59+
}
60+
61+
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
62+
val fromStr = from match {
63+
case l: LambdaVariable => "array element"
64+
case e => e.sql
65+
}
66+
67+
throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath)
68+
}
69+
}

0 commit comments

Comments
 (0)