@@ -83,7 +83,18 @@ class DataClass:
8383 x : int
8484
8585
86+ class ReferencedClass :
87+ def __init__ (self , value : int ):
88+ self .value = value
89+
90+ def get_value (self ) -> int :
91+ return self .value
92+
93+
8694class MyClass :
95+ def __init__ (self , x : int ):
96+ self .helper = ReferencedClass (x * 2 )
97+
8798 @staticmethod
8899 def foo ():
89100 return KLASS_X
@@ -95,6 +106,13 @@ def bar(cls):
95106 def baz (self ):
96107 return KLASS_Z
97108
109+ def use_referenced (self , value : int ) -> int :
110+ ref = ReferencedClass (value )
111+ return ref .get_value ()
112+
113+ def compute_with_reference (self ) -> int :
114+ return self .helper .get_value () + 10
115+
98116
99117def other_func (a : int ) -> int :
100118 import sqlglot
@@ -103,7 +121,8 @@ def other_func(a: int) -> int:
103121 pd .DataFrame ([{"x" : 1 }])
104122 to_table ("y" )
105123 my_lambda () # type: ignore
106- return X + a + W
124+ obj = MyClass (a )
125+ return X + a + W + obj .compute_with_reference ()
107126
108127
109128@contextmanager
@@ -131,7 +150,7 @@ def function_with_custom_decorator():
131150def main_func (y : int , foo = exp .true (), * , bar = expressions .Literal .number (1 ) + 2 ) -> int :
132151 """DOC STRING"""
133152 sqlglot .parse_one ("1" )
134- MyClass ()
153+ MyClass (47 )
135154 DataClass (x = y )
136155 normalize_model_name ("test" + SQLGLOT_META )
137156 fetch_data ()
@@ -177,6 +196,7 @@ def test_func_globals() -> None:
177196 assert func_globals (other_func ) == {
178197 "X" : 1 ,
179198 "W" : 0 ,
199+ "MyClass" : MyClass ,
180200 "my_lambda" : my_lambda ,
181201 "pd" : pd ,
182202 "to_table" : to_table ,
@@ -202,7 +222,7 @@ def test_normalize_source() -> None:
202222 == """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
203223 ):
204224 sqlglot.parse_one('1')
205- MyClass()
225+ MyClass(47 )
206226 DataClass(x=y)
207227 normalize_model_name('test' + SQLGLOT_META)
208228 fetch_data()
@@ -223,7 +243,8 @@ def closure(z: int):
223243 pd.DataFrame([{'x': 1}])
224244 to_table('y')
225245 my_lambda()
226- return X + a + W"""
246+ obj = MyClass(a)
247+ return X + a + W + obj.compute_with_reference()"""
227248 )
228249
229250
@@ -252,7 +273,7 @@ def test_serialize_env() -> None:
252273 payload = """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
253274 ):
254275 sqlglot.parse_one('1')
255- MyClass()
276+ MyClass(47 )
256277 DataClass(x=y)
257278 normalize_model_name('test' + SQLGLOT_META)
258279 fetch_data()
@@ -295,6 +316,9 @@ class DataClass:
295316 path = "test_metaprogramming.py" ,
296317 payload = """class MyClass:
297318
319+ def __init__(self, x: int):
320+ self.helper = ReferencedClass(x * 2)
321+
298322 @staticmethod
299323 def foo():
300324 return KLASS_X
@@ -304,7 +328,26 @@ def bar(cls):
304328 return KLASS_Y
305329
306330 def baz(self):
307- return KLASS_Z""" ,
331+ return KLASS_Z
332+
333+ def use_referenced(self, value: int):
334+ ref = ReferencedClass(value)
335+ return ref.get_value()
336+
337+ def compute_with_reference(self):
338+ return self.helper.get_value() + 10""" ,
339+ ),
340+ "ReferencedClass" : Executable (
341+ kind = ExecutableKind .DEFINITION ,
342+ name = "ReferencedClass" ,
343+ path = "test_metaprogramming.py" ,
344+ payload = """class ReferencedClass:
345+
346+ def __init__(self, value: int):
347+ self.value = value
348+
349+ def get_value(self):
350+ return self.value""" ,
308351 ),
309352 "dataclass" : Executable (
310353 payload = "from dataclasses import dataclass" , kind = ExecutableKind .IMPORT
@@ -341,7 +384,8 @@ def sample_context_manager():
341384 pd.DataFrame([{'x': 1}])
342385 to_table('y')
343386 my_lambda()
344- return X + a + W""" ,
387+ obj = MyClass(a)
388+ return X + a + W + obj.compute_with_reference()""" ,
345389 ),
346390 "sample_context_manager" : Executable (
347391 payload = """@contextmanager
@@ -424,6 +468,21 @@ def function_with_custom_decorator():
424468 assert all (is_metadata for (_ , is_metadata ) in env .values ())
425469 assert serialized_env == expected_env
426470
471+ # Check that class references inside init are captured
472+ init_globals = func_globals (MyClass .__init__ )
473+ assert "ReferencedClass" in init_globals
474+
475+ env = {}
476+ build_env (other_func , env = env , name = "other_func_test" , path = path )
477+ serialized_env = serialize_env (env , path = path )
478+
479+ assert "MyClass" in serialized_env
480+ assert "ReferencedClass" in serialized_env
481+
482+ prepared_env = prepare_env (serialized_env )
483+ result = eval ("other_func_test(2)" , prepared_env )
484+ assert result == 17
485+
427486
428487def test_serialize_env_with_enum_import_appearing_in_two_functions () -> None :
429488 path = Path ("tests/utils" )
0 commit comments