Skip to content

Commit cada1c9

Browse files
authored
Merge pull request #4 from codeflash-ai/revert-2-function-optimizer-refactor
Revert "Refactor optimizer into FunctionOptimizer class"
2 parents 3e7f30b + f86d997 commit cada1c9

10 files changed

+1523
-1568
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 0 additions & 1176 deletions
This file was deleted.

codeflash/optimization/optimizer.py

Lines changed: 1148 additions & 30 deletions
Large diffs are not rendered by default.

tests/test_code_replacement.py

Lines changed: 55 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import dataclasses
44
import os
5+
from argparse import Namespace
56
from collections import defaultdict
67
from pathlib import Path
78

@@ -13,8 +14,7 @@
1314
)
1415
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1516
from codeflash.models.models import FunctionParent
16-
from codeflash.optimization.function_optimizer import FunctionOptimizer
17-
from codeflash.verification.verification_utils import TestConfig
17+
from codeflash.optimization.optimizer import Optimizer
1818

1919
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
2020

@@ -766,18 +766,24 @@ def main_method(self):
766766
return HelperClass(self.name).helper_method()
767767
"""
768768
file_path = Path(__file__).resolve()
769+
opt = Optimizer(
770+
Namespace(
771+
project_root=file_path.parent.resolve(),
772+
disable_telemetry=True,
773+
tests_root="tests",
774+
test_framework="pytest",
775+
pytest_cmd="pytest",
776+
experiment_id=None,
777+
test_project_root=file_path.parent.resolve(),
778+
)
779+
)
769780
func_top_optimize = FunctionToOptimize(
770781
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
771782
)
772-
test_config = TestConfig(
773-
tests_root=file_path.parent,
774-
tests_project_rootdir=file_path.parent,
775-
project_root_path=file_path.parent,
776-
test_framework="pytest",
777-
pytest_cmd="pytest",
778-
)
779-
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
780-
code_context = func_optimizer.get_code_optimization_context().unwrap()
783+
original_code = file_path.read_text()
784+
code_context = opt.get_code_optimization_context(
785+
function_to_optimize=func_top_optimize, project_root=file_path.parent, original_source_code=original_code
786+
).unwrap()
781787
assert code_context.code_to_optimize_with_helpers == get_code_output
782788

783789

@@ -1007,35 +1013,35 @@ def to_name(self) -> str:
10071013
class TestResults(BaseModel):
10081014
def __iter__(self) -> Iterator[FunctionTestInvocation]:
10091015
return iter(self.test_results)
1010-
1016+
10111017
def __len__(self) -> int:
10121018
return len(self.test_results)
1013-
1019+
10141020
def __getitem__(self, index: int) -> FunctionTestInvocation:
10151021
return self.test_results[index]
1016-
1022+
10171023
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
10181024
self.test_results[index] = value
1019-
1025+
10201026
def __delitem__(self, index: int) -> None:
10211027
del self.test_results[index]
1022-
1028+
10231029
def __contains__(self, value: FunctionTestInvocation) -> bool:
10241030
return value in self.test_results
1025-
1031+
10261032
def __bool__(self) -> bool:
10271033
return bool(self.test_results)
1028-
1034+
10291035
def __eq__(self, other: object) -> bool:
10301036
# Unordered comparison
10311037
if not isinstance(other, TestResults) or len(self) != len(other):
10321038
return False
1033-
1039+
10341040
# Increase recursion limit only if necessary
10351041
original_recursion_limit = sys.getrecursionlimit()
10361042
if original_recursion_limit < 5000:
10371043
sys.setrecursionlimit(5000)
1038-
1044+
10391045
for test_result in self:
10401046
other_test_result = other.get_by_id(test_result.id)
10411047
if other_test_result is None or not (
@@ -1048,10 +1054,10 @@ def __eq__(self, other: object) -> bool:
10481054
):
10491055
sys.setrecursionlimit(original_recursion_limit)
10501056
return False
1051-
1057+
10521058
sys.setrecursionlimit(original_recursion_limit)
10531059
return True
1054-
1060+
10551061
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
10561062
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
10571063
for test_result in self.test_results:
@@ -1099,8 +1105,8 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
10991105
)
11001106

11011107
assert (
1102-
new_code
1103-
== """from __future__ import annotations
1108+
new_code
1109+
== """from __future__ import annotations
11041110
import sys
11051111
from codeflash.verification.comparator import comparator
11061112
from enum import Enum
@@ -1239,21 +1245,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
12391245
"""Row-wise cosine similarity between two equal-width matrices."""
12401246
if len(X.data) == 0 or len(Y.data) == 0:
12411247
return np.array([])
1242-
1248+
12431249
X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
12441250
if X_np.shape[1] != Y_np.shape[1]:
12451251
raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
12461252
X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
12471253
Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)
1248-
1254+
12491255
norm_product = X_norm * Y_norm.T
12501256
norm_product[norm_product == 0] = np.inf # Prevent division by zero
12511257
dot_product = np.dot(X_np, Y_np.T)
12521258
similarity = dot_product / norm_product
1253-
1259+
12541260
# Any NaN or Inf values are set to 0.0
12551261
np.nan_to_num(similarity, copy=False)
1256-
1262+
12571263
return similarity
12581264
def cosine_similarity_top_k(
12591265
X: Matrix,
@@ -1264,15 +1270,15 @@ def cosine_similarity_top_k(
12641270
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
12651271
if len(X.data) == 0 or len(Y.data) == 0:
12661272
return [], []
1267-
1273+
12681274
score_array = cosine_similarity(X, Y)
1269-
1275+
12701276
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
12711277
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1272-
1278+
12731279
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
12741280
scores = score_array.flatten()[sorted_idxs].tolist()
1275-
1281+
12761282
return ret_idxs, scores
12771283
'''
12781284
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
@@ -1305,8 +1311,8 @@ def cosine_similarity_top_k(
13051311
project_root_path=Path(__file__).parent.parent.resolve(),
13061312
)
13071313
assert (
1308-
new_code
1309-
== '''import numpy as np
1314+
new_code
1315+
== '''import numpy as np
13101316
from pydantic.dataclasses import dataclass
13111317
from typing import List, Optional, Tuple, Union
13121318
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1337,15 +1343,15 @@ def cosine_similarity_top_k(
13371343
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
13381344
if len(X.data) == 0 or len(Y.data) == 0:
13391345
return [], []
1340-
1346+
13411347
score_array = cosine_similarity(X, Y)
1342-
1348+
13431349
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
13441350
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1345-
1351+
13461352
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
13471353
scores = score_array.flatten()[sorted_idxs].tolist()
1348-
1354+
13491355
return ret_idxs, scores
13501356
'''
13511357
)
@@ -1364,8 +1370,8 @@ def cosine_similarity_top_k(
13641370
)
13651371

13661372
assert (
1367-
new_helper_code
1368-
== '''import numpy as np
1373+
new_helper_code
1374+
== '''import numpy as np
13691375
from pydantic.dataclasses import dataclass
13701376
from typing import List, Optional, Tuple, Union
13711377
@dataclass(config=dict(arbitrary_types_allowed=True))
@@ -1375,21 +1381,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
13751381
"""Row-wise cosine similarity between two equal-width matrices."""
13761382
if len(X.data) == 0 or len(Y.data) == 0:
13771383
return np.array([])
1378-
1384+
13791385
X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
13801386
if X_np.shape[1] != Y_np.shape[1]:
13811387
raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
13821388
X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
13831389
Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)
1384-
1390+
13851391
norm_product = X_norm * Y_norm.T
13861392
norm_product[norm_product == 0] = np.inf # Prevent division by zero
13871393
dot_product = np.dot(X_np, Y_np.T)
13881394
similarity = dot_product / norm_product
1389-
1395+
13901396
# Any NaN or Inf values are set to 0.0
13911397
np.nan_to_num(similarity, copy=False)
1392-
1398+
13931399
return similarity
13941400
def cosine_similarity_top_k(
13951401
X: Matrix,
@@ -1400,15 +1406,15 @@ def cosine_similarity_top_k(
14001406
"""Row-wise cosine similarity with optional top-k and score threshold filtering."""
14011407
if len(X.data) == 0 or len(Y.data) == 0:
14021408
return [], []
1403-
1409+
14041410
score_array = cosine_similarity(X, Y)
1405-
1411+
14061412
sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
14071413
sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]
1408-
1414+
14091415
ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
14101416
scores = score_array.flatten()[sorted_idxs].tolist()
1411-
1417+
14121418
return ret_idxs, scores
14131419
'''
14141420
)
@@ -1475,7 +1481,7 @@ def test_future_aliased_imports_removal() -> None:
14751481

14761482
def test_0_diff_code_replacement():
14771483
original_code = """from __future__ import annotations
1478-
1484+
14791485
import numpy as np
14801486
def functionA():
14811487
return np.array([1, 2, 3])

0 commit comments

Comments
 (0)