Skip to content

Commit 20e0e25

Browse files
Fix!: Allow init to be walked to track its dependencies (#5688)
1 parent 5211a57 commit 20e0e25

File tree

3 files changed

+82
-13
lines changed

3 files changed

+82
-13
lines changed

sqlmesh/utils/metaprogramming.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
352352
walk(base, base.__qualname__, is_metadata)
353353

354354
for k, v in obj.__dict__.items():
355-
if k.startswith("__"):
355+
# Skip dunder methods except __init__, which may contain user-defined logic with cross-class references
356+
if k.startswith("__") and k != "__init__":
356357
continue
357358

358359
# Traverse methods in a class to find global references
@@ -362,10 +363,14 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
362363
if callable(v):
363364
# Walk the method if it's part of the object, else it's a global function and we just store it
364365
if v.__qualname__.startswith(obj.__qualname__):
365-
for k, v in func_globals(v).items():
366-
walk(v, k, is_metadata)
367-
else:
368-
walk(v, v.__name__, is_metadata)
366+
try:
367+
for k, v in func_globals(v).items():
368+
walk(v, k, is_metadata)
369+
except (OSError, TypeError):
370+
# __init__ may come from built-ins or wrapped callables; skip it when it cannot be walked
371+
pass
372+
else:
373+
walk(v, k, is_metadata)
369374
elif callable(obj):
370375
for k, v in func_globals(obj).items():
371376
walk(v, k, is_metadata)

tests/core/test_context.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1506,14 +1506,19 @@ def test_requirements(copy_to_temp_path: t.Callable):
15061506
"dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True
15071507
).environment
15081508
requirements = {"ipywidgets", "numpy", "pandas", "test_package"}
1509+
if IS_WINDOWS:
1510+
requirements.add("pendulum")
15091511
assert environment.requirements["pandas"] == "2.2.2"
15101512
assert set(environment.requirements) == requirements
15111513

15121514
context._requirements = {"numpy": "2.1.2", "pandas": "2.2.1"}
15131515
context._excluded_requirements = {"ipywidgets", "ruamel.yaml", "ruamel.yaml.clib"}
15141516
diff = context.plan_builder("dev", skip_tests=True, skip_backfill=True).build().context_diff
15151517
assert set(diff.previous_requirements) == requirements
1516-
assert set(diff.requirements) == {"numpy", "pandas"}
1518+
reqs = {"numpy", "pandas"}
1519+
if IS_WINDOWS:
1520+
reqs.add("pendulum")
1521+
assert set(diff.requirements) == reqs
15171522

15181523

15191524
def test_deactivate_automatic_requirement_inference(copy_to_temp_path: t.Callable):

tests/utils/test_metaprogramming.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,18 @@ class DataClass:
8383
x: int
8484

8585

86+
class ReferencedClass:
87+
def __init__(self, value: int):
88+
self.value = value
89+
90+
def get_value(self) -> int:
91+
return self.value
92+
93+
8694
class MyClass:
95+
def __init__(self, x: int):
96+
self.helper = ReferencedClass(x * 2)
97+
8798
@staticmethod
8899
def foo():
89100
return KLASS_X
@@ -95,6 +106,13 @@ def bar(cls):
95106
def baz(self):
96107
return KLASS_Z
97108

109+
def use_referenced(self, value: int) -> int:
110+
ref = ReferencedClass(value)
111+
return ref.get_value()
112+
113+
def compute_with_reference(self) -> int:
114+
return self.helper.get_value() + 10
115+
98116

99117
def other_func(a: int) -> int:
100118
import sqlglot
@@ -103,7 +121,8 @@ def other_func(a: int) -> int:
103121
pd.DataFrame([{"x": 1}])
104122
to_table("y")
105123
my_lambda() # type: ignore
106-
return X + a + W
124+
obj = MyClass(a)
125+
return X + a + W + obj.compute_with_reference()
107126

108127

109128
@contextmanager
@@ -131,7 +150,7 @@ def function_with_custom_decorator():
131150
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
132151
"""DOC STRING"""
133152
sqlglot.parse_one("1")
134-
MyClass()
153+
MyClass(47)
135154
DataClass(x=y)
136155
normalize_model_name("test" + SQLGLOT_META)
137156
fetch_data()
@@ -177,6 +196,7 @@ def test_func_globals() -> None:
177196
assert func_globals(other_func) == {
178197
"X": 1,
179198
"W": 0,
199+
"MyClass": MyClass,
180200
"my_lambda": my_lambda,
181201
"pd": pd,
182202
"to_table": to_table,
@@ -202,7 +222,7 @@ def test_normalize_source() -> None:
202222
== """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
203223
):
204224
sqlglot.parse_one('1')
205-
MyClass()
225+
MyClass(47)
206226
DataClass(x=y)
207227
normalize_model_name('test' + SQLGLOT_META)
208228
fetch_data()
@@ -223,7 +243,8 @@ def closure(z: int):
223243
pd.DataFrame([{'x': 1}])
224244
to_table('y')
225245
my_lambda()
226-
return X + a + W"""
246+
obj = MyClass(a)
247+
return X + a + W + obj.compute_with_reference()"""
227248
)
228249

229250

@@ -252,7 +273,7 @@ def test_serialize_env() -> None:
252273
payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
253274
):
254275
sqlglot.parse_one('1')
255-
MyClass()
276+
MyClass(47)
256277
DataClass(x=y)
257278
normalize_model_name('test' + SQLGLOT_META)
258279
fetch_data()
@@ -295,6 +316,9 @@ class DataClass:
295316
path="test_metaprogramming.py",
296317
payload="""class MyClass:
297318
319+
def __init__(self, x: int):
320+
self.helper = ReferencedClass(x * 2)
321+
298322
@staticmethod
299323
def foo():
300324
return KLASS_X
@@ -304,7 +328,26 @@ def bar(cls):
304328
return KLASS_Y
305329
306330
def baz(self):
307-
return KLASS_Z""",
331+
return KLASS_Z
332+
333+
def use_referenced(self, value: int):
334+
ref = ReferencedClass(value)
335+
return ref.get_value()
336+
337+
def compute_with_reference(self):
338+
return self.helper.get_value() + 10""",
339+
),
340+
"ReferencedClass": Executable(
341+
kind=ExecutableKind.DEFINITION,
342+
name="ReferencedClass",
343+
path="test_metaprogramming.py",
344+
payload="""class ReferencedClass:
345+
346+
def __init__(self, value: int):
347+
self.value = value
348+
349+
def get_value(self):
350+
return self.value""",
308351
),
309352
"dataclass": Executable(
310353
payload="from dataclasses import dataclass", kind=ExecutableKind.IMPORT
@@ -341,7 +384,8 @@ def sample_context_manager():
341384
pd.DataFrame([{'x': 1}])
342385
to_table('y')
343386
my_lambda()
344-
return X + a + W""",
387+
obj = MyClass(a)
388+
return X + a + W + obj.compute_with_reference()""",
345389
),
346390
"sample_context_manager": Executable(
347391
payload="""@contextmanager
@@ -424,6 +468,21 @@ def function_with_custom_decorator():
424468
assert all(is_metadata for (_, is_metadata) in env.values())
425469
assert serialized_env == expected_env
426470

471+
# Check that class references inside init are captured
472+
init_globals = func_globals(MyClass.__init__)
473+
assert "ReferencedClass" in init_globals
474+
475+
env = {}
476+
build_env(other_func, env=env, name="other_func_test", path=path)
477+
serialized_env = serialize_env(env, path=path)
478+
479+
assert "MyClass" in serialized_env
480+
assert "ReferencedClass" in serialized_env
481+
482+
prepared_env = prepare_env(serialized_env)
483+
result = eval("other_func_test(2)", prepared_env)
484+
assert result == 17
485+
427486

428487
def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
429488
path = Path("tests/utils")

0 commit comments

Comments
 (0)