From bdb11d0491db62661d36cb0ac11e6a1cbd2c1e23 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 27 Jun 2025 20:21:30 +0300 Subject: [PATCH 1/6] don't import from same module --- codeflash/code_utils/code_extractor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0dcc2357f..9fab16ab7 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -331,7 +331,9 @@ def add_needed_imports_from_module( RemoveImportsVisitor.remove_unused_import(dst_context, mod) for mod, obj_seq in gatherer.object_mapping.items(): for obj in obj_seq: - if f"{mod}.{obj}" in helper_functions_fqn: + if ( + f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name is mod + ): # avoid circular imports continue # Skip adding imports for helper functions already in the context AddImportsVisitor.add_needed_import(dst_context, mod, obj) RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) From b4e9ab70f459c0f7f010684731604330718fb543 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 27 Jun 2025 20:23:55 +0300 Subject: [PATCH 2/6] formatting --- codeflash/code_utils/code_extractor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 9fab16ab7..e0a6f9af8 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -332,8 +332,9 @@ def add_needed_imports_from_module( for mod, obj_seq in gatherer.object_mapping.items(): for obj in obj_seq: if ( - f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name is mod - ): # avoid circular imports + f"{mod}.{obj}" in helper_functions_fqn + or dst_context.full_module_name is mod # avoid circular imports + ): continue # Skip adding imports for helper functions already in the context AddImportsVisitor.add_needed_import(dst_context, mod, obj) RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) From 0212e01f98e17a5085db32d5becbc1fa21c8d56a Mon Sep 17 00:00:00 2001 From: mohammed Date: Sat, 28 Jun 2025 13:41:16 +0300 Subject: [PATCH 3/6] remove unused imports in one go --- codeflash/code_utils/code_extractor.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index e0a6f9af8..2b125ce7d 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -325,31 +325,36 @@ def add_needed_imports_from_module( ) ) cst.parse_module(src_module_code).visit(gatherer) + scheduled_unused_imports = [] try: for mod in gatherer.module_imports: AddImportsVisitor.add_needed_import(dst_context, mod) - RemoveImportsVisitor.remove_unused_import(dst_context, mod) + scheduled_unused_imports.append((mod, "", "")) for mod, obj_seq in gatherer.object_mapping.items(): + logger.debug(f"dst_context.full_module_name: {dst_context.full_module_name}") + logger.debug(f"mod: {mod}") + logger.debug(f"obj_seq: {obj_seq}") + logger.debug(f"helper_functions_fqn: {helper_functions_fqn}") for obj in obj_seq: if ( f"{mod}.{obj}" in helper_functions_fqn - or dst_context.full_module_name is mod # avoid circular imports + or dst_context.full_module_name == mod # avoid circular imports ): continue # Skip adding imports for helper functions already in the context AddImportsVisitor.add_needed_import(dst_context, mod, obj) - RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) + scheduled_unused_imports.append((mod, obj, "")) except Exception as e: logger.exception(f"Error adding imports to destination module code: {e}") return dst_module_code for mod, asname in gatherer.module_aliases.items(): AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) - RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname) + scheduled_unused_imports.append((mod, "", asname)) for mod, alias_pairs in gatherer.alias_mapping.items(): for alias_pair in alias_pairs: if f"{mod}.{alias_pair[0]}" in helper_functions_fqn: continue AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) - RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) + scheduled_unused_imports.append((mod, alias_pair[0], alias_pair[1])) try: parsed_module = cst.parse_module(dst_module_code) @@ -358,6 +363,9 @@ def add_needed_imports_from_module( return dst_module_code # Return the original code if there's a syntax error try: transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module) + for _import in scheduled_unused_imports: + (_module, _obj, _alias) = _import + RemoveImportsVisitor.remove_unused_import(dst_context, module=_module, obj=_obj, asname=_alias) transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module) return transformed_module.code.lstrip("\n") except Exception as e: From 2acda6a41155af4485eb1e6f51ec64e546bbdc94 Mon Sep 17 00:00:00 2001 From: mohammed Date: Sat, 28 Jun 2025 21:14:26 +0300 Subject: [PATCH 4/6] initial test files --- .../circular_deps/api_client.py | 77 +++++++++++++++++++ .../circular_deps/constants.py | 8 ++ codeflash/code_utils/code_replacer.py | 7 ++ tests/test_code_context_extractor.py | 50 ++++++++++++ 4 files changed, 142 insertions(+) create mode 100644 code_to_optimize/code_directories/circular_deps/api_client.py create mode 100644 code_to_optimize/code_directories/circular_deps/constants.py diff --git a/code_to_optimize/code_directories/circular_deps/api_client.py b/code_to_optimize/code_directories/circular_deps/api_client.py new file mode 100644 index 000000000..aa4f768f0 --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/api_client.py @@ -0,0 +1,77 @@ +from os import getenv +from typing import Optional + +from attrs import define, evolve, field + +from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL + + +@define +class GalileoApiClient(): + """A Client which has been authenticated for use on secured endpoints + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + This can also be set via the GALILEO_CONSOLE_URL environment variable + + ``api_key``: The API key to be sent with every request + This can also be set via the GALILEO_API_KEY environment variable + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + token: The token to use for authentication + prefix: The prefix to use for the Authorization header + auth_header_name: The name of the Authorization header + """ + + _base_url: Optional[str] = field(factory=lambda: GalileoApiClient.get_api_url(), kw_only=True, alias="base_url") + _api_key: Optional[str] = field(factory=lambda: getenv("GALILEO_API_KEY", None), kw_only=True, alias="api_key") + token: Optional[str] = None + + api_key_header_name: str = "Galileo-API-Key" + client_type_header_name: str = "client-type" + client_type_header_value: str = "sdk-python" + + @staticmethod + def get_console_url() -> str: + console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) + if DEFAULT_API_URL == console_url: + return DEFAULT_APP_URL + + return console_url + + def with_api_key(self, api_key: str) -> "GalileoApiClient": + """Get a new client matching this one with a new API key""" + if self._client is not None: + self._client.headers.update({self.api_key_header_name: api_key}) + if self._async_client is not None: + self._async_client.headers.update({self.api_key_header_name: api_key}) + return evolve(self, api_key=api_key) + + @staticmethod + def get_api_url(base_url: Optional[str] = None) -> str: + api_url = base_url or getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) + if api_url is None: + raise ValueError("base_url or GALILEO_CONSOLE_URL must be set") + if any(map(api_url.__contains__, ["localhost", "127.0.0.1"])): + api_url = "http://localhost:8088" + else: + api_url = api_url.replace("app.galileo.ai", "api.galileo.ai").replace("console", "api") + return api_url diff --git a/code_to_optimize/code_directories/circular_deps/constants.py b/code_to_optimize/code_directories/circular_deps/constants.py new file mode 100644 index 000000000..dc4b0638e --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/constants.py @@ -0,0 +1,8 @@ +DEFAULT_API_URL = "https://api.galileo.ai/" +DEFAULT_APP_URL = "https://app.galileo.ai/" + + +# function_names: GalileoApiClient.get_console_url +# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py +# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))} +# project_root_path: /home/mohammed/Work/galileo-python/src diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index f0964aae7..b9acffcda 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -397,6 +397,13 @@ def replace_functions_and_add_imports( preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> str: + logger.debug("start from here,...") + logger.debug(f"source_code: {source_code}") + logger.debug(f"function_names: {function_names}") + logger.debug(f"optimized_code: {optimized_code}") + logger.debug(f"module_abspath: {module_abspath}") + logger.debug(f"preexisting_objects: {preexisting_objects}") + logger.debug(f"project_root_path: {project_root_path}") return add_needed_imports_from_module( optimized_code, replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects), diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 010d3bc65..3ef07adbf 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2434,3 +2434,53 @@ def simple_method(self): assert "class SimpleClass:" in code_content assert "def simple_method(self):" in code_content assert "return 42" in code_content + + +from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +def test_replace_functions_and_add_imports(): + path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps" + optimized_code = '''from __future__ import annotations + +import urllib.parse +from os import getenv + +from attrs import define +from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL + +# Precompute constant netlocs for set membership test +_DEFAULT_APP_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc +_DEFAULT_API_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc +_NETLOC_SET = {_DEFAULT_APP_NETLOC, _DEFAULT_API_NETLOC} + +@define +class GalileoApiClient(): + + @staticmethod + def get_console_url() -> str: + # Return DEFAULT_APP_URL if the env var is not set or set to DEFAULT_API_URL + console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) + if console_url == DEFAULT_API_URL: + return DEFAULT_APP_URL + return console_url + +def _set_destination(console_url: str) -> str: + """ + Parse the console_url and return the destination for the OpenTelemetry traces. + """ + destination = (console_url or GalileoApiClient.get_console_url()).replace("console.", "api.") + parsed_url = urllib.parse.urlparse(destination) + if parsed_url.netloc in _NETLOC_SET: + return f"{DEFAULT_APP_URL}api/galileo/otel/traces" + return f"{parsed_url.scheme}://{parsed_url.netloc}/otel/traces"''' + file_abs_path = path_to_root / "api_client.py" + content = Path(file_abs_path).read_text(encoding="utf-8") + new_code = replace_functions_and_add_imports( + source_code= content, + function_names= ["GalileoApiClient.get_console_url"], + optimized_code= optimized_code, + module_abspath= file_abs_path, + preexisting_objects= {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}, + project_root_path= Path(path_to_root), + ) + print(new_code) + assert 1 == 1 \ No newline at end of file From 2e394f6b8fb44fb8d60e3f95541f6724d347f73a Mon Sep 17 00:00:00 2001 From: mohammed Date: Sun, 29 Jun 2025 00:23:36 +0300 Subject: [PATCH 5/6] tests and fix global assignments imports --- .../circular_deps/api_client.py | 64 ++----------------- .../circular_deps/optimized.py | 37 +++++++++++ .../circular_deps/pyproject.toml | 7 ++ codeflash/code_utils/code_extractor.py | 19 ++---- codeflash/code_utils/code_replacer.py | 17 ++--- tests/test_code_context_extractor.py | 51 ++++----------- 6 files changed, 73 insertions(+), 122 deletions(-) create mode 100644 code_to_optimize/code_directories/circular_deps/optimized.py create mode 100644 code_to_optimize/code_directories/circular_deps/pyproject.toml diff --git a/code_to_optimize/code_directories/circular_deps/api_client.py b/code_to_optimize/code_directories/circular_deps/api_client.py index aa4f768f0..bc93193d2 100644 --- a/code_to_optimize/code_directories/circular_deps/api_client.py +++ b/code_to_optimize/code_directories/circular_deps/api_client.py @@ -1,77 +1,25 @@ from os import getenv -from typing import Optional -from attrs import define, evolve, field +from attrs import define, evolve -from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL +from constants import DEFAULT_API_URL, DEFAULT_APP_URL @define -class GalileoApiClient(): - """A Client which has been authenticated for use on secured endpoints - - The following are accepted as keyword arguments and will be used to construct httpx Clients internally: - - ``base_url``: The base URL for the API, all requests are made to a relative path to this URL - This can also be set via the GALILEO_CONSOLE_URL environment variable - - ``api_key``: The API key to be sent with every request - This can also be set via the GALILEO_API_KEY environment variable - - ``cookies``: A dictionary of cookies to be sent with every request - - ``headers``: A dictionary of headers to be sent with every request - - ``timeout``: The maximum amount of a time a request can take. API functions will raise - httpx.TimeoutException if this is exceeded. - - ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, - but can be set to False for testing purposes. - - ``follow_redirects``: Whether or not to follow redirects. Default value is False. - - ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. - - Attributes: - raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a - status code that was not documented in the source OpenAPI document. Can also be provided as a keyword - argument to the constructor. - token: The token to use for authentication - prefix: The prefix to use for the Authorization header - auth_header_name: The name of the Authorization header - """ - - _base_url: Optional[str] = field(factory=lambda: GalileoApiClient.get_api_url(), kw_only=True, alias="base_url") - _api_key: Optional[str] = field(factory=lambda: getenv("GALILEO_API_KEY", None), kw_only=True, alias="api_key") - token: Optional[str] = None - - api_key_header_name: str = "Galileo-API-Key" +class ApiClient(): + api_key_header_name: str = "API-Key" client_type_header_name: str = "client-type" client_type_header_value: str = "sdk-python" @staticmethod def get_console_url() -> str: - console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) + console_url = getenv("CONSOLE_URL", DEFAULT_API_URL) if DEFAULT_API_URL == console_url: return DEFAULT_APP_URL return console_url - def with_api_key(self, api_key: str) -> "GalileoApiClient": + def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file. """Get a new client matching this one with a new API key""" - if self._client is not None: - self._client.headers.update({self.api_key_header_name: api_key}) - if self._async_client is not None: - self._async_client.headers.update({self.api_key_header_name: api_key}) return evolve(self, api_key=api_key) - @staticmethod - def get_api_url(base_url: Optional[str] = None) -> str: - api_url = base_url or getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) - if api_url is None: - raise ValueError("base_url or GALILEO_CONSOLE_URL must be set") - if any(map(api_url.__contains__, ["localhost", "127.0.0.1"])): - api_url = "http://localhost:8088" - else: - api_url = api_url.replace("app.galileo.ai", "api.galileo.ai").replace("console", "api") - return api_url diff --git a/code_to_optimize/code_directories/circular_deps/optimized.py b/code_to_optimize/code_directories/circular_deps/optimized.py new file mode 100644 index 000000000..2fa5d9bd0 --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/optimized.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import urllib.parse +from os import getenv + +from attrs import define +from api_client import ApiClient +from constants import DEFAULT_API_URL, DEFAULT_APP_URL + + +@define +class ApiClient(): + + @staticmethod + def get_console_url() -> str: + # Cache env lookup for speed + console_url = getenv("CONSOLE_URL") + if not console_url or console_url == DEFAULT_API_URL: + return DEFAULT_APP_URL + return console_url + +# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly +_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc +_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc + +def get_dest_url(url: str) -> str: + destination = url if url else ApiClient.get_console_url() + # Replace only if 'console.' is at the beginning to avoid partial matches + if destination.startswith("console."): + destination = "api." + destination[len("console."):] + else: + destination = destination.replace("console.", "api.", 1) + + parsed_url = urllib.parse.urlparse(destination) + if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC: + return f"{DEFAULT_APP_URL}api/traces" + return f"{parsed_url.scheme}://{parsed_url.netloc}/traces" \ No newline at end of file diff --git a/code_to_optimize/code_directories/circular_deps/pyproject.toml b/code_to_optimize/code_directories/circular_deps/pyproject.toml new file mode 100644 index 000000000..bddef0ed3 --- /dev/null +++ b/code_to_optimize/code_directories/circular_deps/pyproject.toml @@ -0,0 +1,7 @@ +[tool.codeflash] +# All paths are relative to this pyproject.toml's directory. +module-root = "." +tests-root = "tests" +test-framework = "pytest" +ignore-paths = [] +formatter-cmds = ["black $file"] diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 2b125ce7d..73a1c326f 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -325,36 +325,30 @@ def add_needed_imports_from_module( ) ) cst.parse_module(src_module_code).visit(gatherer) - scheduled_unused_imports = [] try: for mod in gatherer.module_imports: AddImportsVisitor.add_needed_import(dst_context, mod) - scheduled_unused_imports.append((mod, "", "")) + RemoveImportsVisitor.remove_unused_import(dst_context, mod) for mod, obj_seq in gatherer.object_mapping.items(): - logger.debug(f"dst_context.full_module_name: {dst_context.full_module_name}") - logger.debug(f"mod: {mod}") - logger.debug(f"obj_seq: {obj_seq}") - logger.debug(f"helper_functions_fqn: {helper_functions_fqn}") for obj in obj_seq: if ( - f"{mod}.{obj}" in helper_functions_fqn - or dst_context.full_module_name == mod # avoid circular imports + f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps ): continue # Skip adding imports for helper functions already in the context AddImportsVisitor.add_needed_import(dst_context, mod, obj) - scheduled_unused_imports.append((mod, obj, "")) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) except Exception as e: logger.exception(f"Error adding imports to destination module code: {e}") return dst_module_code for mod, asname in gatherer.module_aliases.items(): AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) - scheduled_unused_imports.append((mod, "", asname)) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname) for mod, alias_pairs in gatherer.alias_mapping.items(): for alias_pair in alias_pairs: if f"{mod}.{alias_pair[0]}" in helper_functions_fqn: continue AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) - scheduled_unused_imports.append((mod, alias_pair[0], alias_pair[1])) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) try: parsed_module = cst.parse_module(dst_module_code) @@ -363,9 +357,6 @@ def add_needed_imports_from_module( return dst_module_code # Return the original code if there's a syntax error try: transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module) - for _import in scheduled_unused_imports: - (_module, _obj, _alias) = _import - RemoveImportsVisitor.remove_unused_import(dst_context, module=_module, obj=_obj, asname=_alias) transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module) return transformed_module.code.lstrip("\n") except Exception as e: diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index b9acffcda..3c73c5919 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -397,13 +397,6 @@ def replace_functions_and_add_imports( preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> str: - logger.debug("start from here,...") - logger.debug(f"source_code: {source_code}") - logger.debug(f"function_names: {function_names}") - logger.debug(f"optimized_code: {optimized_code}") - logger.debug(f"module_abspath: {module_abspath}") - logger.debug(f"preexisting_objects: {preexisting_objects}") - logger.debug(f"project_root_path: {project_root_path}") return add_needed_imports_from_module( optimized_code, replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects), @@ -422,12 +415,16 @@ def replace_function_definitions_in_module( ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") new_code: str = replace_functions_and_add_imports( - source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path + add_global_assignments(optimized_code, source_code), + function_names, + optimized_code, + module_abspath, + preexisting_objects, + project_root_path, ) if is_zero_diff(source_code, new_code): return False - code_with_global_assignments = add_global_assignments(optimized_code, new_code) - module_abspath.write_text(code_with_global_assignments, encoding="utf8") + module_abspath.write_text(new_code, encoding="utf8") return True diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 3ef07adbf..25200cb9c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -11,6 +11,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer +from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +from codeflash.code_utils.code_extractor import add_global_assignments class HelperClass: @@ -2436,51 +2438,20 @@ def simple_method(self): assert "return 42" in code_content -from codeflash.code_utils.code_replacer import replace_functions_and_add_imports + def test_replace_functions_and_add_imports(): path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps" - optimized_code = '''from __future__ import annotations - -import urllib.parse -from os import getenv - -from attrs import define -from code_to_optimize.code_directories.circular_deps.constants import DEFAULT_API_URL, DEFAULT_APP_URL - -# Precompute constant netlocs for set membership test -_DEFAULT_APP_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc -_DEFAULT_API_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc -_NETLOC_SET = {_DEFAULT_APP_NETLOC, _DEFAULT_API_NETLOC} - -@define -class GalileoApiClient(): - - @staticmethod - def get_console_url() -> str: - # Return DEFAULT_APP_URL if the env var is not set or set to DEFAULT_API_URL - console_url = getenv("GALILEO_CONSOLE_URL", DEFAULT_API_URL) - if console_url == DEFAULT_API_URL: - return DEFAULT_APP_URL - return console_url - -def _set_destination(console_url: str) -> str: - """ - Parse the console_url and return the destination for the OpenTelemetry traces. - """ - destination = (console_url or GalileoApiClient.get_console_url()).replace("console.", "api.") - parsed_url = urllib.parse.urlparse(destination) - if parsed_url.netloc in _NETLOC_SET: - return f"{DEFAULT_APP_URL}api/galileo/otel/traces" - return f"{parsed_url.scheme}://{parsed_url.netloc}/otel/traces"''' file_abs_path = path_to_root / "api_client.py" + optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") content = Path(file_abs_path).read_text(encoding="utf-8") new_code = replace_functions_and_add_imports( - source_code= content, - function_names= ["GalileoApiClient.get_console_url"], + source_code= add_global_assignments(optimized_code, content), + function_names= ["ApiClient.get_console_url"], optimized_code= optimized_code, - module_abspath= file_abs_path, - preexisting_objects= {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}, + module_abspath= Path(file_abs_path), + preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))}, project_root_path= Path(path_to_root), ) - print(new_code) - assert 1 == 1 \ No newline at end of file + assert "import ApiClient" not in new_code, "Error: Circular dependency found" + + assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" From 190d5930cad71ca8b284941799de7828d006b4a1 Mon Sep 17 00:00:00 2001 From: mohammed Date: Sun, 29 Jun 2025 00:44:02 +0300 Subject: [PATCH 6/6] fix: tests expected code new lines --- tests/test_code_replacement.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 363dbaee4..7272163d3 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1693,8 +1693,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1712,8 +1712,7 @@ def __init__(self, name): def __call__(self, value): return "I am still old" def new_function2(value): - return cst.ensure_type(value, str) -""" + return cst.ensure_type(value, str)""" code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") @@ -1769,8 +1768,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") print("Hello world") def some_fn(): a=np.zeros(10) @@ -1846,8 +1845,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=3 print("Hello world") def some_fn(): @@ -1922,8 +1921,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=2 print("Hello world") def some_fn(): @@ -1999,8 +1998,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") a=3 print("Hello world") def some_fn(): @@ -2082,8 +2081,8 @@ def new_function2(value): print("Hello world") """ expected_code = """import numpy as np -print("Hello world") +print("Hello world") if 2<3: a=4 else: