Skip to content

Commit f900525

Browse files
authored
Merge branch 'main' into saga4/fix_optimizable_functions
2 parents 13885aa + 2b5fa6e commit f900525

File tree

9 files changed

+137
-15
lines changed

9 files changed

+137
-15
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from os import getenv
2+
3+
from attrs import define, evolve
4+
5+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
6+
7+
8+
@define
9+
class ApiClient():
10+
api_key_header_name: str = "API-Key"
11+
client_type_header_name: str = "client-type"
12+
client_type_header_value: str = "sdk-python"
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
console_url = getenv("CONSOLE_URL", DEFAULT_API_URL)
17+
if DEFAULT_API_URL == console_url:
18+
return DEFAULT_APP_URL
19+
20+
return console_url
21+
22+
def with_api_key(self, api_key: str) -> "ApiClient": # ---> here is the problem with circular dependency, this makes libcst thinks that ApiClient needs an import despite it's already in the same file.
23+
"""Get a new client matching this one with a new API key"""
24+
return evolve(self, api_key=api_key)
25+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DEFAULT_API_URL = "https://api.galileo.ai/"
2+
DEFAULT_APP_URL = "https://app.galileo.ai/"
3+
4+
5+
# function_names: GalileoApiClient.get_console_url
6+
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
7+
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
8+
# project_root_path: /home/mohammed/Work/galileo-python/src
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
import urllib.parse
4+
from os import getenv
5+
6+
from attrs import define
7+
from api_client import ApiClient
8+
from constants import DEFAULT_API_URL, DEFAULT_APP_URL
9+
10+
11+
@define
12+
class ApiClient():
13+
14+
@staticmethod
15+
def get_console_url() -> str:
16+
# Cache env lookup for speed
17+
console_url = getenv("CONSOLE_URL")
18+
if not console_url or console_url == DEFAULT_API_URL:
19+
return DEFAULT_APP_URL
20+
return console_url
21+
22+
# Pre-parse netlocs that are checked frequently to avoid parsing repeatedly
23+
_DEFAULT_APP_URL_NETLOC = urllib.parse.urlparse(DEFAULT_APP_URL).netloc
24+
_DEFAULT_API_URL_NETLOC = urllib.parse.urlparse(DEFAULT_API_URL).netloc
25+
26+
def get_dest_url(url: str) -> str:
27+
destination = url if url else ApiClient.get_console_url()
28+
# Replace only if 'console.' is at the beginning to avoid partial matches
29+
if destination.startswith("console."):
30+
destination = "api." + destination[len("console."):]
31+
else:
32+
destination = destination.replace("console.", "api.", 1)
33+
34+
parsed_url = urllib.parse.urlparse(destination)
35+
if parsed_url.netloc == _DEFAULT_APP_URL_NETLOC or parsed_url.netloc == _DEFAULT_API_URL_NETLOC:
36+
return f"{DEFAULT_APP_URL}api/traces"
37+
return f"{parsed_url.scheme}://{parsed_url.netloc}/traces"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[tool.codeflash]
2+
# All paths are relative to this pyproject.toml's directory.
3+
module-root = "."
4+
tests-root = "tests"
5+
test-framework = "pytest"
6+
ignore-paths = []
7+
formatter-cmds = ["black $file"]

codeflash/code_utils/code_extractor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ def add_needed_imports_from_module(
331331
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
332332
for mod, obj_seq in gatherer.object_mapping.items():
333333
for obj in obj_seq:
334-
if f"{mod}.{obj}" in helper_functions_fqn:
334+
if (
335+
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
336+
):
335337
continue # Skip adding imports for helper functions already in the context
336338
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
337339
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,12 +415,16 @@ def replace_function_definitions_in_module(
415415
) -> bool:
416416
source_code: str = module_abspath.read_text(encoding="utf8")
417417
new_code: str = replace_functions_and_add_imports(
418-
source_code, function_names, optimized_code, module_abspath, preexisting_objects, project_root_path
418+
add_global_assignments(optimized_code, source_code),
419+
function_names,
420+
optimized_code,
421+
module_abspath,
422+
preexisting_objects,
423+
project_root_path,
419424
)
420425
if is_zero_diff(source_code, new_code):
421426
return False
422-
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
423-
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
427+
module_abspath.write_text(new_code, encoding="utf8")
424428
return True
425429

426430

codeflash/optimization/optimizer.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
1313
from codeflash.cli_cmds.console import console, logger, progress_bar
1414
from codeflash.code_utils import env_utils
15+
from codeflash.code_utils.code_utils import cleanup_paths
1516
from codeflash.code_utils.env_utils import get_pr_number
1617
from codeflash.either import is_successful
1718
from codeflash.models.models import ValidCode
@@ -248,10 +249,10 @@ def run(self) -> None:
248249
return
249250
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
250251
return
251-
252252
if self.args.no_draft and is_pr_draft():
253253
logger.warning("PR is in draft mode, skipping optimization")
254254
return
255+
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
255256

256257
function_optimizer = None
257258
file_to_funcs_to_optimize, num_optimizable_functions = self.get_optimizable_functions()
@@ -326,9 +327,27 @@ def run(self) -> None:
326327

327328
self.cleanup_temporary_paths()
328329

329-
def cleanup_temporary_paths(self) -> None:
330-
from codeflash.code_utils.code_utils import cleanup_paths
330+
@staticmethod
331+
def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]:
332+
"""Search for all paths within the test_root that match the following patterns.
331333
334+
- 'test.*__perf_test_{0,1}.py'
335+
- 'test_.*__unit_test_{0,1}.py'
336+
- 'test_.*__perfinstrumented.py'
337+
- 'test_.*__perfonlyinstrumented.py'
338+
Returns a list of matching file paths.
339+
"""
340+
import re
341+
342+
pattern = re.compile(
343+
r"(?:test.*__perf_test_\d?\.py|test_.*__unit_test_\d?\.py|test_.*__perfinstrumented\.py|test_.*__perfonlyinstrumented\.py)$"
344+
)
345+
346+
return [
347+
file_path for file_path in test_root.rglob("*") if file_path.is_file() and pattern.match(file_path.name)
348+
]
349+
350+
def cleanup_temporary_paths(self) -> None:
332351
if self.current_function_optimizer:
333352
self.current_function_optimizer.cleanup_generated_files()
334353

tests/test_code_context_extractor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1212
from codeflash.models.models import FunctionParent
1313
from codeflash.optimization.optimizer import Optimizer
14+
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
15+
from codeflash.code_utils.code_extractor import add_global_assignments
1416

1517

1618
class HelperClass:
@@ -2434,3 +2436,22 @@ def simple_method(self):
24342436
assert "class SimpleClass:" in code_content
24352437
assert "def simple_method(self):" in code_content
24362438
assert "return 42" in code_content
2439+
2440+
2441+
2442+
def test_replace_functions_and_add_imports():
2443+
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
2444+
file_abs_path = path_to_root / "api_client.py"
2445+
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
2446+
content = Path(file_abs_path).read_text(encoding="utf-8")
2447+
new_code = replace_functions_and_add_imports(
2448+
source_code= add_global_assignments(optimized_code, content),
2449+
function_names= ["ApiClient.get_console_url"],
2450+
optimized_code= optimized_code,
2451+
module_abspath= Path(file_abs_path),
2452+
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
2453+
project_root_path= Path(path_to_root),
2454+
)
2455+
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
2456+
2457+
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"

tests/test_code_replacement.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,8 +1693,8 @@ def new_function2(value):
16931693
print("Hello world")
16941694
"""
16951695
expected_code = """import numpy as np
1696-
print("Hello world")
16971696
1697+
print("Hello world")
16981698
a=2
16991699
print("Hello world")
17001700
def some_fn():
@@ -1712,8 +1712,7 @@ def __init__(self, name):
17121712
def __call__(self, value):
17131713
return "I am still old"
17141714
def new_function2(value):
1715-
return cst.ensure_type(value, str)
1716-
"""
1715+
return cst.ensure_type(value, str)"""
17171716
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
17181717
code_path.write_text(original_code, encoding="utf-8")
17191718
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
@@ -1769,8 +1768,8 @@ def new_function2(value):
17691768
print("Hello world")
17701769
"""
17711770
expected_code = """import numpy as np
1772-
print("Hello world")
17731771
1772+
print("Hello world")
17741773
print("Hello world")
17751774
def some_fn():
17761775
a=np.zeros(10)
@@ -1846,8 +1845,8 @@ def new_function2(value):
18461845
print("Hello world")
18471846
"""
18481847
expected_code = """import numpy as np
1849-
print("Hello world")
18501848
1849+
print("Hello world")
18511850
a=3
18521851
print("Hello world")
18531852
def some_fn():
@@ -1922,8 +1921,8 @@ def new_function2(value):
19221921
print("Hello world")
19231922
"""
19241923
expected_code = """import numpy as np
1925-
print("Hello world")
19261924
1925+
print("Hello world")
19271926
a=2
19281927
print("Hello world")
19291928
def some_fn():
@@ -1999,8 +1998,8 @@ def new_function2(value):
19991998
print("Hello world")
20001999
"""
20012000
expected_code = """import numpy as np
2002-
print("Hello world")
20032001
2002+
print("Hello world")
20042003
a=3
20052004
print("Hello world")
20062005
def some_fn():
@@ -2082,8 +2081,8 @@ def new_function2(value):
20822081
print("Hello world")
20832082
"""
20842083
expected_code = """import numpy as np
2085-
print("Hello world")
20862084
2085+
print("Hello world")
20872086
if 2<3:
20882087
a=4
20892088
else:

0 commit comments

Comments
 (0)