Skip to content

Commit 651ded4

Browse files
committed
experiment
1 parent 9cd4743 commit 651ded4

File tree

6 files changed

+81
-50
lines changed

6 files changed

+81
-50
lines changed

codeflash/api/aiservice.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,8 @@ def get_optimization_impact(
544544
replay_tests: str,
545545
root_dir: Path,
546546
concolic_tests: str, # noqa: ARG002
547-
) -> str:
547+
calling_fn_details: str,
548+
) -> tuple[str, str]:
548549
"""Compute the optimization impact of current Pull Request.
549550
550551
Args:
@@ -558,6 +559,7 @@ def get_optimization_impact(
558559
replay_tests: str -> replay test table
559560
root_dir: Path -> path of git directory
560561
concolic_tests: str -> concolic_tests (not used)
562+
calling_fn_details: str -> filenames and definitions of functions which call the function_to_optimize
561563
562564
Returns:
563565
-------
@@ -577,13 +579,6 @@ def get_optimization_impact(
577579
]
578580
)
579581
code_diff = f"```diff\n{diff_str}\n```"
580-
# TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos
581-
# grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service
582-
# metric 1 -> call count - how many times the function is called in the codebase
583-
# metric 2 -> loop call count - how many times the function is called in a loop in the codebase
584-
# metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function
585-
# metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity)
586-
# metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures)
587582
logger.info("!lsp|Computing Optimization Impact…")
588583
payload = {
589584
"code_diff": code_diff,
@@ -598,25 +593,26 @@ def get_optimization_impact(
598593
"benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None,
599594
"optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
600595
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
596+
"calling_fn_details": calling_fn_details,
601597
}
602598
console.rule()
603599
try:
604600
response = self.make_ai_service_request("/optimization_impact", payload=payload, timeout=600)
605601
except requests.exceptions.RequestException as e:
606602
logger.exception(f"Error generating optimization refinements: {e}")
607603
ph("cli-optimize-error-caught", {"error": str(e)})
608-
return ""
604+
return ("", str(e))
609605

610606
if response.status_code == 200:
611-
return cast("str", response.json()["impact"])
607+
return (cast("str", response.json()["impact"]), cast("str", response.json()["impact_explanation"]))
612608
try:
613609
error = cast("str", response.json()["error"])
614610
except Exception:
615611
error = response.text
616612
logger.error(f"Error generating impact candidates: {response.status_code} - {error}")
617613
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
618614
console.rule()
619-
return ""
615+
return ("", error)
620616

621617

622618
class LocalAiServiceClient(AiServiceClient):

codeflash/api/cfapi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def create_pr(
173173
coverage_message: str,
174174
replay_tests: str = "",
175175
concolic_tests: str = "",
176+
optimization_impact: str = "",
176177
) -> Response:
177178
"""Create a pull request, targeting the specified branch. (usually 'main').
178179
@@ -197,6 +198,7 @@ def create_pr(
197198
"coverage_message": coverage_message,
198199
"replayTests": replay_tests,
199200
"concolicTests": concolic_tests,
201+
"optimizationImpact": optimization_impact,
200202
}
201203
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
202204

@@ -212,6 +214,7 @@ def create_staging(
212214
replay_tests: str,
213215
concolic_tests: str,
214216
root_dir: Path,
217+
optimization_impact: str = "",
215218
) -> Response:
216219
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
217220
@@ -252,6 +255,7 @@ def create_staging(
252255
"coverage_message": coverage_message,
253256
"replayTests": replay_tests,
254257
"concolicTests": concolic_tests,
258+
"optimizationImpact": optimization_impact,
255259
}
256260

257261
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)

codeflash/code_utils/code_extractor.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# ruff: noqa: ARG002
21
from __future__ import annotations
32

43
import ast
@@ -10,15 +9,13 @@
109

1110
import jedi
1211
import libcst as cst
13-
import radon.visitors
1412
from libcst.codemod import CodemodContext
1513
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
1614
from libcst.helpers import calculate_module_and_package
17-
from radon.complexity import cc_visit
1815

1916
from codeflash.cli_cmds.console import logger
2017
from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT
21-
from codeflash.models.models import CodePosition, FunctionParent, ImpactMetrics
18+
from codeflash.models.models import CodePosition, FunctionParent
2219

2320
if TYPE_CHECKING:
2421
from libcst.helpers import ModuleNameAndPackage
@@ -38,28 +35,28 @@ def __init__(self) -> None:
3835
self.scope_depth = 0
3936
self.if_else_depth = 0
4037

41-
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
38+
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002
4239
self.scope_depth += 1
4340
return True
4441

45-
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
42+
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
4643
self.scope_depth -= 1
4744

4845
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
4946
self.scope_depth += 1
5047
return True
5148

52-
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
49+
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
5350
self.scope_depth -= 1
5451

55-
def visit_If(self, node: cst.If) -> Optional[bool]:
52+
def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002
5653
self.if_else_depth += 1
5754
return True
5855

59-
def leave_If(self, original_node: cst.If) -> None:
56+
def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002
6057
self.if_else_depth -= 1
6158

62-
def visit_Else(self, node: cst.Else) -> Optional[bool]:
59+
def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002
6360
# Else blocks are already counted as part of the if statement
6461
return True
6562

@@ -86,21 +83,21 @@ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order:
8683
self.scope_depth = 0
8784
self.if_else_depth = 0
8885

89-
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
86+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
9087
self.scope_depth += 1
9188

9289
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
9390
self.scope_depth -= 1
9491
return updated_node
9592

96-
def visit_ClassDef(self, node: cst.ClassDef) -> None:
93+
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
9794
self.scope_depth += 1
9895

99-
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
96+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
10097
self.scope_depth -= 1
10198
return updated_node
10299

103-
def visit_If(self, node: cst.If) -> None:
100+
def visit_If(self, node: cst.If) -> None: # noqa: ARG002
104101
self.if_else_depth += 1
105102

106103
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
@@ -1156,8 +1153,7 @@ def get_fn_references_jedi(
11561153

11571154
def get_opt_impact_metrics(
11581155
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
1159-
) -> ImpactMetrics:
1160-
metrics = ImpactMetrics()
1156+
) -> str:
11611157
try:
11621158
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
11631159
if len(qualified_name_split) == 1:
@@ -1167,26 +1163,8 @@ def get_opt_impact_metrics(
11671163
matches = get_fn_references_jedi(
11681164
source_code, file_path, project_root, target_function, target_class
11691165
) # jedi is not perfect, it doesn't capture aliased references
1170-
cyclomatic_complexity_results = cc_visit(source_code)
1171-
match_found = False
1172-
for result in cyclomatic_complexity_results:
1173-
if match_found:
1174-
break
1175-
if isinstance(result, radon.visitors.Function) and not target_class:
1176-
if result.name == target_function:
1177-
metrics.cyclomatic_complexity = result.complexity
1178-
metrics.cyclomatic_complexity_rating = result.letter
1179-
match_found = True
1180-
elif isinstance(result, radon.visitors.Class) and target_class: # noqa: SIM102
1181-
if result.name == target_class:
1182-
for method in result.methods:
1183-
if match_found:
1184-
break
1185-
if method.name == target_function:
1186-
metrics.cyclomatic_complexity = method.complexity
1187-
metrics.cyclomatic_complexity_rating = method.letter
1188-
match_found = True
1189-
metrics.calling_fns = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
1166+
calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
11901167
except Exception as e:
1168+
calling_fns_details = ""
11911169
logger.debug(f"Investigate {e}")
1192-
return metrics
1170+
return calling_fns_details

codeflash/optimization/function_optimizer.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from codeflash.benchmarking.utils import process_benchmark_data
2525
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
2626
from codeflash.code_utils import env_utils
27+
from codeflash.code_utils.code_extractor import get_opt_impact_metrics
2728
from codeflash.code_utils.code_replacer import (
2829
add_custom_marker_to_all_tests,
2930
modify_autouse_fixture,
@@ -1461,12 +1462,36 @@ def process_review(
14611462

14621463
if raise_pr or staging_review:
14631464
data["root_dir"] = git_root_dir()
1465+
calling_fn_details = get_opt_impact_metrics(
1466+
self.function_to_optimize_source_code,
1467+
self.function_to_optimize.file_path,
1468+
self.function_to_optimize.qualified_name,
1469+
self.project_root,
1470+
self.test_cfg.tests_root,
1471+
)
14641472
opt_impact_response = ""
14651473
try:
1466-
opt_impact_response = self.aiservice_client.get_optimization_impact(**data)
1474+
opt_impact_response = self.aiservice_client.get_optimization_impact(
1475+
**data, calling_fn_details=calling_fn_details
1476+
)
14671477
except Exception as e:
14681478
logger.debug(f"optimization impact response failed, investigate {e}")
1469-
data["optimization_impact"] = opt_impact_response
1479+
data["optimization_impact"] = opt_impact_response[0]
1480+
new_explanation_with_opt_explanation = Explanation(
1481+
raw_explanation_message=f"Impact: {opt_impact_response[0]}\n Impact_explanation: {opt_impact_response[1]} END OF IMPACT EXPLANATION\n"
1482+
+ new_explanation.raw_explanation_message,
1483+
winning_behavior_test_results=explanation.winning_behavior_test_results,
1484+
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
1485+
original_runtime_ns=explanation.original_runtime_ns,
1486+
best_runtime_ns=explanation.best_runtime_ns,
1487+
function_name=explanation.function_name,
1488+
file_path=explanation.file_path,
1489+
benchmark_details=explanation.benchmark_details,
1490+
original_async_throughput=explanation.original_async_throughput,
1491+
best_async_throughput=explanation.best_async_throughput,
1492+
)
1493+
best_optimization.explanation_v2 = new_explanation_with_opt_explanation.explanation_message()
1494+
data["explanation"] = new_explanation_with_opt_explanation
14701495
if raise_pr and not staging_review:
14711496
data["git_remote"] = self.args.git_remote
14721497
check_create_pr(**data)

codeflash/result/create_pr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def check_create_pr(
277277
coverage_message=coverage_message,
278278
replay_tests=replay_tests,
279279
concolic_tests=concolic_tests,
280+
optimization_impact=optimization_impact,
280281
)
281282
if response.ok:
282283
pr_id = response.text

uv.lock

Lines changed: 27 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)