Skip to content

Commit 3e7f30b

Browse files
authored
Merge pull request #2 from codeflash-ai/function-optimizer-refactor
Refactor optimizer into FunctionOptimizer class
2 parents 06137a8 + 8e8258a commit 3e7f30b

10 files changed

+1568 -1523 lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 1176 additions & 0 deletions
Large diffs are not rendered by default.

codeflash/optimization/optimizer.py

Lines changed: 30 additions & 1148 deletions
Large diffs are not rendered by default.

tests/test_code_replacement.py

Lines changed: 49 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import dataclasses
44
import os
5-
from argparse import Namespace
65
from collections import defaultdict
76
from pathlib import Path
87

@@ -14,7 +13,8 @@
1413
)
1514
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1615
from codeflash.models.models import FunctionParent
17-
from codeflash.optimization.optimizer import Optimizer
16+
from codeflash.optimization.function_optimizer import FunctionOptimizer
17+
from codeflash.verification.verification_utils import TestConfig
1818

1919
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
2020

@@ -766,24 +766,18 @@ def main_method(self):
766766
return HelperClass(self.name).helper_method()
767767
"""
768768
file_path = Path(__file__).resolve()
769-
opt = Optimizer(
770-
Namespace(
771-
project_root=file_path.parent.resolve(),
772-
disable_telemetry=True,
773-
tests_root="tests",
774-
test_framework="pytest",
775-
pytest_cmd="pytest",
776-
experiment_id=None,
777-
test_project_root=file_path.parent.resolve(),
778-
)
779-
)
780769
func_top_optimize = FunctionToOptimize(
781770
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
782771
)
783-
original_code = file_path.read_text()
784-
code_context = opt.get_code_optimization_context(
785-
function_to_optimize=func_top_optimize, project_root=file_path.parent, original_source_code=original_code
786-
).unwrap()
772+
test_config = TestConfig(
773+
tests_root=file_path.parent,
774+
tests_project_rootdir=file_path.parent,
775+
project_root_path=file_path.parent,
776+
test_framework="pytest",
777+
pytest_cmd="pytest",
778+
)
779+
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
780+
code_context = func_optimizer.get_code_optimization_context().unwrap()
787781
assert code_context.code_to_optimize_with_helpers == get_code_output
788782

789783

@@ -1013,35 +1007,35 @@ def to_name(self) -> str:
10131007
class TestResults(BaseModel):
10141008
def __iter__(self) -> Iterator[FunctionTestInvocation]:
10151009
return iter(self.test_results)
1016-
1010+
10171011
def __len__(self) -> int:
10181012
return len(self.test_results)
1019-
1013+
10201014
def __getitem__(self, index: int) -> FunctionTestInvocation:
10211015
return self.test_results[index]
1022-
1016+
10231017
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
10241018
self.test_results[index] = value
1025-
1019+
10261020
def __delitem__(self, index: int) -> None:
10271021
del self.test_results[index]
1028-
1022+
10291023
def __contains__(self, value: FunctionTestInvocation) -> bool:
10301024
return value in self.test_results
1031-
1025+
10321026
def __bool__(self) -> bool:
10331027
return bool(self.test_results)
1034-
1028+
10351029
def __eq__(self, other: object) -> bool:
10361030
# Unordered comparison
10371031
if not isinstance(other, TestResults) or len(self) != len(other):
10381032
return False
1039-
1033+
10401034
# Increase recursion limit only if necessary
10411035
original_recursion_limit = sys.getrecursionlimit()
10421036
if original_recursion_limit < 5000:
10431037
sys.setrecursionlimit(5000)
1044-
1038+
10451039
for test_result in self:
10461040
other_test_result = other.get_by_id(test_result.id)
10471041
if other_test_result is None or not (
@@ -1054,10 +1048,10 @@ def __eq__(self, other: object) -> bool:
10541048
):
10551049
sys.setrecursionlimit(original_recursion_limit)
10561050
return False
1057-
1051+
10581052
sys.setrecursionlimit(original_recursion_limit)
10591053
return True
1060-
1054+
10611055
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
10621056
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
10631057
for test_result in self.test_results:
@@ -1105,8 +1099,8 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
11051099
)
11061100

11071101
assert (
1108-
new_code
1109-
== """from __future__ import annotations
1102+
new_code
1103+
== """from __future__ import annotations
11101104
import sys
11111105
from codeflash.verification.comparator import comparator
11121106
from enum import Enum
@@ -1245,21 +1239,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
12451239
"""Row-wise cosine similarity between two equal-width matrices."""
12461240
if len(X.data) == 0 or len(Y.data) == 0:
12471241
return np.array([])
1248-
1242+
12491243
X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
12501244
if X_np.shape[1] != Y_np.shape[1]:
12511245
raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
12521246
X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
12531247
Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)
1254-
1248+
12551249
norm_product = X_norm * Y_norm.T
12561250
norm_product[norm_product == 0] = np.inf # Prevent division by zero
12571251
dot_product = np.dot(X_np, Y_np.T)
12581252
similarity = dot_product / norm_product
1259-
1253+
12601254
# Any NaN or Inf values are set to 0.0
12611255
np.nan_to_num(similarity, copy=False)
1262-
1256+
12631257
return similarity
12641258
def cosine_similarity_top_k(
12651259
X: Matrix,
@@ -1270,15 +1264,15 @@ def cosine_similarity_top_k(
12701264
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
12711265
if len(X.data) == 0 or len(Y.data) == 0:
12721266
return [], []
1273-
1267+
12741268
score_array = cosine_similarity(X, Y)
1275-
1269+
12761270
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
12771271
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1278-
1272+
12791273
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
12801274
scores = score_array.flatten()[sorted_idxs].tolist()
1281-
1275+
12821276
return ret_idxs, scores
12831277
'''
12841278
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
@@ -1311,8 +1305,8 @@ def cosine_similarity_top_k(
13111305
project_root_path=Path(__file__).parent.parent.resolve(),
13121306
)
13131307
assert (
1314-
new_code
1315-
== '''import numpy as np
1308+
new_code
1309+
== '''import numpy as np
13161310
from pydantic.dataclasses import dataclass
13171311
from typing import List, Optional, Tuple, Union
13181312
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1343,15 +1337,15 @@ def cosine_similarity_top_k(
13431337
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
13441338
if len(X.data) == 0 or len(Y.data) == 0:
13451339
return [], []
1346-
1340+
13471341
score_array = cosine_similarity(X, Y)
1348-
1342+
13491343
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
13501344
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1351-
1345+
13521346
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
13531347
scores = score_array.flatten()[sorted_idxs].tolist()
1354-
1348+
13551349
return ret_idxs, scores
13561350
'''
13571351
)
@@ -1370,8 +1364,8 @@ def cosine_similarity_top_k(
13701364
)
13711365

13721366
assert (
1373-
new_helper_code
1374-
== '''import numpy as np
1367+
new_helper_code
1368+
== '''import numpy as np
13751369
from pydantic.dataclasses import dataclass
13761370
from typing import List, Optional, Tuple, Union
13771371
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1381,21 +1375,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
13811375
"""Row-wise cosine similarity between two equal-width matrices."""
13821376
if len(X.data) == 0 or len(Y.data) == 0:
13831377
return np.array([])
1384-
1378+
13851379
X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
13861380
if X_np.shape[1] != Y_np.shape[1]:
13871381
raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
13881382
X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
13891383
Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)
1390-
1384+
13911385
norm_product = X_norm * Y_norm.T
13921386
norm_product[norm_product == 0] = np.inf # Prevent division by zero
13931387
dot_product = np.dot(X_np, Y_np.T)
13941388
similarity = dot_product / norm_product
1395-
1389+
13961390
# Any NaN or Inf values are set to 0.0
13971391
np.nan_to_num(similarity, copy=False)
1398-
1392+
13991393
return similarity
14001394
def cosine_similarity_top_k(
14011395
X: Matrix,
@@ -1406,15 +1400,15 @@ def cosine_similarity_top_k(
14061400
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
14071401
if len(X.data) == 0 or len(Y.data) == 0:
14081402
return [], []
1409-
1403+
14101404
score_array = cosine_similarity(X, Y)
1411-
1405+
14121406
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
14131407
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1414-
1408+
14151409
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
14161410
scores = score_array.flatten()[sorted_idxs].tolist()
1417-
1411+
14181412
return ret_idxs, scores
14191413
'''
14201414
)
@@ -1481,7 +1475,7 @@ def test_future_aliased_imports_removal() -> None:
14811475

14821476
def test_0_diff_code_replacement():
14831477
original_code = """from __future__ import annotations
1484-
1478+
14851479
import numpy as np
14861480
def functionA():
14871481
return np.array([1, 2, 3])

0 commit comments

Comments (0)