Skip to content

Commit ae16b0f

Browse files
committed
Update NormalizeCTEIds.scala
1 parent de65261 commit ae16b0f

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,50 +17,53 @@
1717

1818
package org.apache.spark.sql.catalyst.normalizer
1919

20-
import java.util.HashMap
2120
import java.util.concurrent.atomic.AtomicLong
2221

22+
import scala.collection.mutable
23+
2324
import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
2425
import org.apache.spark.sql.catalyst.rules.Rule
2526

2627
object NormalizeCTEIds extends Rule[LogicalPlan] {
2728
override def apply(plan: LogicalPlan): LogicalPlan = {
2829
val curId = new java.util.concurrent.atomic.AtomicLong()
29-
val cteIdToNewId = new HashMap[Long, Long]()
30+
val cteIdToNewId = mutable.Map.empty[Long, Long]
3031
applyInternal(plan, curId, cteIdToNewId)
3132
}
3233

3334
private def applyInternal(
3435
plan: LogicalPlan,
3536
curId: AtomicLong,
36-
cteIdToNewId: HashMap[Long, Long]): LogicalPlan = {
37+
cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
3738
plan transformDownWithSubqueries {
3839
case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
3940
ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
4041

4142
case withCTE @ WithCTE(plan, cteDefs) =>
4243
val newCteDefs = cteDefs.map { cteDef =>
43-
if (!cteIdToNewId.containsKey(cteDef.id)) {
44-
cteIdToNewId.put(cteDef.id, curId.getAndIncrement())
44+
if (!cteIdToNewId.contains(cteDef.id)) {
45+
cteIdToNewId(cteDef.id) = curId.getAndIncrement()
4546
}
4647
val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
47-
cteDef.copy(child = normalizedCteDef, id = cteIdToNewId.get(cteDef.id))
48+
cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
4849
}
4950
val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
5051
withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
5152
}
5253
}
5354

54-
private def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = {
55+
private def canonicalizeCTE(
56+
plan: LogicalPlan,
57+
defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
5558
plan.transformDownWithSubqueries {
5659
// For nested WithCTE, if defIndex didn't contain the cteId,
5760
// means it's not current WithCTE's ref.
58-
case ref: CTERelationRef if defIdToNewId.containsKey(ref.cteId) =>
59-
ref.copy(cteId = defIdToNewId.get(ref.cteId))
60-
case unionLoop: UnionLoop if defIdToNewId.containsKey(unionLoop.id) =>
61-
unionLoop.copy(id = defIdToNewId.get(unionLoop.id))
62-
case unionLoopRef: UnionLoopRef if defIdToNewId.containsKey(unionLoopRef.loopId) =>
63-
unionLoopRef.copy(loopId = defIdToNewId.get(unionLoopRef.loopId))
61+
case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
62+
ref.copy(cteId = defIdToNewId(ref.cteId))
63+
case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
64+
unionLoop.copy(id = defIdToNewId(unionLoop.id))
65+
case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
66+
unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
6467
}
6568
}
6669
}

0 commit comments

Comments
 (0)