|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.normalizer |
19 | 19 |
|
20 | | -import java.util.HashMap |
21 | 20 | import java.util.concurrent.atomic.AtomicLong |
22 | 21 |
|
| 22 | +import scala.collection.mutable |
| 23 | + |
23 | 24 | import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE} |
24 | 25 | import org.apache.spark.sql.catalyst.rules.Rule |
25 | 26 |
|
26 | 27 | object NormalizeCTEIds extends Rule[LogicalPlan] { |
27 | 28 | override def apply(plan: LogicalPlan): LogicalPlan = { |
28 | 29 | val curId = new java.util.concurrent.atomic.AtomicLong() |
29 | | - val cteIdToNewId = new HashMap[Long, Long]() |
| 30 | + val cteIdToNewId = mutable.Map.empty[Long, Long] |
30 | 31 | applyInternal(plan, curId, cteIdToNewId) |
31 | 32 | } |
32 | 33 |
|
33 | 34 | private def applyInternal( |
34 | 35 | plan: LogicalPlan, |
35 | 36 | curId: AtomicLong, |
36 | | - cteIdToNewId: HashMap[Long, Long]): LogicalPlan = { |
| 37 | + cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = { |
37 | 38 | plan transformDownWithSubqueries { |
38 | 39 | case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) => |
39 | 40 | ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId)) |
40 | 41 |
|
41 | 42 | case withCTE @ WithCTE(plan, cteDefs) => |
42 | 43 | val newCteDefs = cteDefs.map { cteDef => |
43 | | - if (!cteIdToNewId.containsKey(cteDef.id)) { |
44 | | - cteIdToNewId.put(cteDef.id, curId.getAndIncrement()) |
| 44 | + if (!cteIdToNewId.contains(cteDef.id)) { |
| 45 | + cteIdToNewId(cteDef.id) = curId.getAndIncrement() |
45 | 46 | } |
46 | 47 | val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId) |
47 | | - cteDef.copy(child = normalizedCteDef, id = cteIdToNewId.get(cteDef.id)) |
| 48 | + cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id)) |
48 | 49 | } |
49 | 50 | val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId) |
50 | 51 | withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs) |
51 | 52 | } |
52 | 53 | } |
53 | 54 |
|
54 | | - private def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: HashMap[Long, Long]): LogicalPlan = { |
| 55 | + private def canonicalizeCTE( |
| 56 | + plan: LogicalPlan, |
| 57 | + defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = { |
55 | 58 | plan.transformDownWithSubqueries { |
56 | 59 | // For nested WithCTE, if defIndex didn't contain the cteId, |
57 | 60 | // means it's not current WithCTE's ref. |
58 | | - case ref: CTERelationRef if defIdToNewId.containsKey(ref.cteId) => |
59 | | - ref.copy(cteId = defIdToNewId.get(ref.cteId)) |
60 | | - case unionLoop: UnionLoop if defIdToNewId.containsKey(unionLoop.id) => |
61 | | - unionLoop.copy(id = defIdToNewId.get(unionLoop.id)) |
62 | | - case unionLoopRef: UnionLoopRef if defIdToNewId.containsKey(unionLoopRef.loopId) => |
63 | | - unionLoopRef.copy(loopId = defIdToNewId.get(unionLoopRef.loopId)) |
| 61 | + case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) => |
| 62 | + ref.copy(cteId = defIdToNewId(ref.cteId)) |
| 63 | + case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) => |
| 64 | + unionLoop.copy(id = defIdToNewId(unionLoop.id)) |
| 65 | + case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) => |
| 66 | + unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId)) |
64 | 67 | } |
65 | 68 | } |
66 | 69 | } |
0 commit comments