Skip to content

Commit bc99b25

Browse files
authored
Fix infinite recursion when checking module of expression (#2159)
## Changes Don't drop existing statements for globals when appending trees ### Linked issues None ### Functionality None ### Tests - [x] added unit tests Was discovered when running make solacc. Fixes 1 of the 2 issue. Other issue is fixed by #2157 Co-authored-by: Eric Vergnaud <[email protected]>
1 parent 0e0ecbd commit bc99b25

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/databricks/labs/ucx/source_code/linters/python_ast.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Attribute,
1313
Call,
1414
Const,
15+
Expr,
1516
Import,
1617
ImportFrom,
1718
Module,
@@ -220,12 +221,16 @@ def append_statements(self, tree: Tree) -> Tree:
220221
stmt.parent = self_module
221222
self_module.body.append(stmt)
222223
for name, value in tree_module.globals.items():
223-
self_module.globals[name] = value
224+
statements: list[Expr] = self_module.globals.get(name, None)
225+
if statements is None:
226+
self_module.globals[name] = list(value) # clone the source list to avoid side-effects
227+
continue
228+
statements.extend(value)
224229
# the following may seem strange but it's actually ok to use the original module as tree root
225230
return tree
226231

227232
def is_from_module(self, module_name: str):
228-
# if his is the call's root node, check it against the required module
233+
# if this is the call's root node, check it against the required module
229234
if isinstance(self._node, Name):
230235
if self._node.name == module_name:
231236
return True

tests/unit/source_code/linters/test_python_ast.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,21 @@ def test_is_from_module():
159159
Call, [("saveAsTable", Attribute), ("format", Attribute), ("write", Attribute), ("df", Name)]
160160
)[0]
161161
assert Tree(save_call).is_from_module("spark")
162+
163+
164+
def test_supports_recursive_refs_when_checking_module():
165+
source_1 = """
166+
df = spark.read.csv("hi")
167+
"""
168+
source_2 = """
169+
df = df.withColumn(stuff)
170+
"""
171+
source_3 = """
172+
df = df.withColumn(stuff2)
173+
"""
174+
main_tree = Tree.normalize_and_parse(source_1)
175+
main_tree.append_statements(Tree.normalize_and_parse(source_2))
176+
tree = Tree.normalize_and_parse(source_3)
177+
main_tree.append_statements(tree)
178+
assign = tree.locate(Assign, [])[0]
179+
assert Tree(assign.value).is_from_module("spark")

0 commit comments

Comments
 (0)