Skip to content

Commit 37ead02

Browse files
committed
[SPARK-54827][SQL] Add helper function TreeNode.containsTag
### What changes were proposed in this pull request?
Add helper function `TreeNode.containsTag`.

### Why are the changes needed?
In many places we don't care about the tag value; we only need to check whether a tag exists. This new function helps simplify the code a bit, e.g. `getTagValue(Cast.BY_TABLE_INSERTION).isDefined` -> `containsTag(Cast.BY_TABLE_INSERTION)`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #53587 from zhengruifeng/containsTag.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent f25b381 commit 37ead02

File tree

21 files changed

+27
-25
lines changed

21 files changed

+27
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ object AliasResolution {
7070
private def extractOnly(e: Expression): Boolean = e match {
7171
case _: ExtractValue => e.children.forall(extractOnly)
7272
case _: Literal => true
73-
case attr: Attribute if attr.getTagValue(ResolverTag.SINGLE_PASS_IS_LCA).isEmpty => true
73+
case attr: Attribute if !attr.containsTag(ResolverTag.SINGLE_PASS_IS_LCA) => true
7474
case _ => false
7575
}
7676
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -268,7 +268,7 @@ object ApplyDefaultCollationToStringType extends Rule[LogicalPlan] {
268268
newType => columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))
269269

270270
case cast: Cast if hasDefaultStringType(cast.dataType) &&
271-
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
271+
cast.containsTag(Cast.USER_SPECIFIED_CAST) =>
272272
newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))
273273

274274
case Literal(value, dt) if hasDefaultStringType(dt) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -898,7 +898,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
898898
"invalidExprSqls" -> invalidExprSqls.mkString(", ")))
899899

900900
case j @ LateralJoin(_, right, _, _)
901-
if j.getTagValue(LateralJoin.BY_TABLE_ARGUMENT).isEmpty =>
901+
if !j.containsTag(LateralJoin.BY_TABLE_ARGUMENT) =>
902902
right.plan.foreach {
903903
case Generate(pyudtf: PythonUDTF, _, _, _, _, _)
904904
if pyudtf.evalType == PythonEvalType.SQL_ARROW_UDTF =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ object CollationTypeCoercion extends SQLConfHelper {
3636
private val COLLATION_CONTEXT_TAG = new TreeNodeTag[DataType]("collationContext")
3737

3838
private def hasCollationContextTag(expr: Expression): Boolean = {
39-
expr.getTagValue(COLLATION_CONTEXT_TAG).isDefined
39+
expr.containsTag(COLLATION_CONTEXT_TAG)
4040
}
4141

4242
def apply(expression: Expression): Expression = expression match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -140,8 +140,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
140140
}
141141
matched(ordinal)
142142

143-
case u @ UnresolvedAttribute(nameParts)
144-
if u.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty =>
143+
case u @ UnresolvedAttribute(nameParts) if !u.containsTag(LogicalPlan.PLAN_ID_TAG) =>
145144
// UnresolvedAttribute with PLAN_ID_TAG should be resolved in resolveDataFrameColumn
146145
val result = withPosition(u) {
147146
resolveColumnByName(nameParts)
@@ -451,7 +450,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
451450
u: UnresolvedAttribute,
452451
q: LogicalPlan,
453452
includeLastResort: Boolean = false): Option[Expression] = {
454-
assert(u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty,
453+
assert(u.containsTag(LogicalPlan.PLAN_ID_TAG),
455454
s"UnresolvedAttribute $u should have a Plan Id tag")
456455

457456
resolveDataFrameColumn(u, q.children).map { r =>
@@ -524,7 +523,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
524523
val planId = planIdOpt.get
525524
logDebug(s"Extract plan_id $planId from $u")
526525

527-
val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
526+
val isMetadataAccess = u.containsTag(LogicalPlan.IS_METADATA_COL)
528527

529528
val (resolved, matched) = resolveDataFrameColumnByPlanId(
530529
u, planId, isMetadataAccess, q, 0)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
3636
// df.drop(col("non-existing-column"))
3737
val dropped = d.dropList.flatMap {
3838
case u: UnresolvedAttribute =>
39-
if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
39+
if (u.containsTag(LogicalPlan.PLAN_ID_TAG)) {
4040
// Plan Id comes from Spark Connect,
4141
// Here we ignore the `UnresolvedAttribute` if its Plan Id can be found
4242
// but column not found.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,8 +92,7 @@ object TypeCoercionValidation extends QueryErrorsBase {
9292
var issueFixedIfAnsiOff = true
9393
getAllExpressions(nonAnsiPlan).foreach(_.foreachUp {
9494
case e: Expression
95-
if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).isDefined &&
96-
e.checkInputDataTypes().isFailure =>
95+
if e.containsTag(DATA_TYPE_MISMATCH_ERROR) && e.checkInputDataTypes().isFailure =>
9796
e.checkInputDataTypes() match {
9897
case TypeCheckResult.TypeCheckFailure(_) | _: TypeCheckResult.DataTypeMismatch =>
9998
issueFixedIfAnsiOff = false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ object DefaultCollationTypeCoercion {
9292
* we should change all its occurrences to [[StringType]] with default collation.
9393
*/
9494
private def shouldApplyCollationToCast(cast: Cast): Boolean = {
95-
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined &&
95+
cast.containsTag(Cast.USER_SPECIFIED_CAST) &&
9696
hasDefaultStringType(cast.dataType)
9797
}
9898
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ class JoinResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
208208
scopes.current.hiddenOutput.filter(_.qualifiedAccessOnly)
209209

210210
val newProjectList =
211-
if (unresolvedJoin.getTagValue(ResolverTag.TOP_LEVEL_OPERATOR).isEmpty) {
211+
if (!unresolvedJoin.containsTag(ResolverTag.TOP_LEVEL_OPERATOR)) {
212212
newOutputList ++ qualifiedAccessOnlyColumnsFromHiddenOutput
213213
} else {
214214
newOutputList

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -328,7 +328,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper {
328328

329329
private def checkUnresolvedAttribute(unresolvedAttribute: UnresolvedAttribute) =
330330
!ResolverGuard.UNSUPPORTED_ATTRIBUTE_NAMES.contains(unresolvedAttribute.nameParts.head) &&
331-
!unresolvedAttribute.getTagValue(LogicalPlan.PLAN_ID_TAG).isDefined
331+
!unresolvedAttribute.containsTag(LogicalPlan.PLAN_ID_TAG)
332332

333333
private def checkUnresolvedPredicate(unresolvedPredicate: Predicate) = unresolvedPredicate match {
334334
case inSubquery: InSubquery =>

0 commit comments

Comments
 (0)