diff --git a/tests/unit/codegen/sdk/python/function/test_function_move_to_file.py b/tests/unit/codegen/sdk/python/function/test_function_move_to_file.py index 68ea4234d..31dc17fa9 100644 --- a/tests/unit/codegen/sdk/python/function/test_function_move_to_file.py +++ b/tests/unit/codegen/sdk/python/function/test_function_move_to_file.py @@ -1275,3 +1275,109 @@ def external_dep2(): assert file2.content.strip() == EXPECTED_FILE_2_CONTENT.strip() assert file3.content.strip() == EXPECTED_FILE_3_CONTENT.strip() assert file4.content.strip() == EXPECTED_FILE_4_CONTENT.strip() + + +def test_move_to_file_with_dataclass_dependencies(tmpdir) -> None: + # ========== [ BEFORE ] ========== + # language=python + FILE_1_CONTENT = """ +from dataclasses import dataclass + +@dataclass +class Config: + '''Base config class''' + name: str + value: int + +def foo(): + return 1 +""" + + # language=python + FILE_2_CONTENT = """ +from dataclasses import dataclass +from file1 import Config + +@dataclass +class ExtendedConfig(Config): + '''Extended config that depends on Config''' + extra: str = "default" + +def bar(config: ExtendedConfig): + '''Function that uses the dataclass''' + return config.value + 1 +""" + + # ========== [ AFTER ] ========== + # language=python + EXPECTED_FILE_1_CONTENT = """ +from dataclasses import dataclass + +def foo(): + return 1 +""" + + # language=python + EXPECTED_FILE_1_TYPES_CONTENT = """ +from dataclasses import dataclass + + +@dataclass +class Config: + '''Base config class''' + name: str + value: int +""" + + # language=python + EXPECTED_FILE_2_CONTENT = """ +from file2.types import ExtendedConfig +from file1.types import Config +from dataclasses import dataclass + +def bar(config: ExtendedConfig): + '''Function that uses the dataclass''' + return config.value + 1 +""" + + # language=python + EXPECTED_FILE_2_TYPES_CONTENT = """ +from file1.types import Config +from dataclasses import dataclass + + +@dataclass +class ExtendedConfig(Config): + '''Extended config that depends on Config''' + extra: str = "default" +""" + + # =============================== + + with get_codebase_session( + tmpdir=tmpdir, + files={ + "file1.py": FILE_1_CONTENT, + "file2.py": FILE_2_CONTENT, + }, + ) as codebase: + file1 = codebase.get_file("file1.py") + file2 = codebase.get_file("file2.py") + + # Create types.py files + file1_types = codebase.create_file("file1/types.py", "") + file2_types = codebase.create_file("file2/types.py", "") + + # Move Config dataclass first since ExtendedConfig depends on it + config_class = file1.get_class("Config") + config_class.move_to_file(file1_types, strategy="update_all_imports", include_dependencies=True) + codebase.commit() + + # Then move ExtendedConfig + extended_config_class = file2.get_class("ExtendedConfig") + extended_config_class.move_to_file(file2_types, strategy="update_all_imports", include_dependencies=True) + + assert file1.content.strip() == EXPECTED_FILE_1_CONTENT.strip() + assert file1_types.content.strip() == EXPECTED_FILE_1_TYPES_CONTENT.strip() + assert file2.content.strip() == EXPECTED_FILE_2_CONTENT.strip() + assert file2_types.content.strip() == EXPECTED_FILE_2_TYPES_CONTENT.strip()