Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions codeflash/api/cfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def suggest_changes(
generated_tests: str,
trace_id: str,
coverage_message: str,
replay_tests: str = "",
concolic_tests: str = "",
) -> Response:
"""Suggest changes to a pull request.

Expand All @@ -147,6 +149,8 @@ def suggest_changes(
"generatedTests": generated_tests,
"traceId": trace_id,
"coverage_message": coverage_message,
"replayTests": replay_tests,
"concolicTests": concolic_tests,
}
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)

Expand All @@ -161,6 +165,8 @@ def create_pr(
generated_tests: str,
trace_id: str,
coverage_message: str,
replay_tests: str = "",
concolic_tests: str = "",
) -> Response:
"""Create a pull request, targeting the specified branch. (usually 'main').

Expand All @@ -183,6 +189,8 @@ def create_pr(
"generatedTests": generated_tests,
"traceId": trace_id,
"coverage_message": coverage_message,
"replayTests": replay_tests,
"concolicTests": concolic_tests,
}
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)

Expand Down
4 changes: 3 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ def process_review(
if concolic_test_str:
generated_tests_str += "\n#------------------------------------------------\n" + concolic_test_str

existing_tests = existing_tests_source_for(
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
Expand Down Expand Up @@ -1237,6 +1237,8 @@ def process_review(
if self.experiment_id
else self.function_trace_id,
"coverage_message": coverage_message,
"replay_tests": replay_tests,
"concolic_tests": concolic_tests,
}

if not self.args.no_pr and not self.args.staging_review:
Expand Down
85 changes: 73 additions & 12 deletions codeflash/result/create_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ def existing_tests_source_for(
test_cfg: TestConfig,
original_runtimes_all: dict[InvocationId, list[int]],
optimized_runtimes_all: dict[InvocationId, list[int]],
) -> str:
) -> tuple[str, str, str]:
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
if not test_files:
return ""
output: str = ""
rows = []
return "", "", ""
output_existing: str = ""
output_concolic: str = ""
output_replay: str = ""
rows_existing = []
rows_concolic = []
rows_replay = []
headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
tests_root = test_cfg.tests_root
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
Expand Down Expand Up @@ -99,28 +103,79 @@ def existing_tests_source_for(
* 100
)
if greater:
rows.append(
if "__replay_test_" in str(print_filename):
rows_replay.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"{perf_gain}%⚠️",
]
)
elif "codeflash_concolic" in str(print_filename):
rows_concolic.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"{perf_gain}%⚠️",
]
)
else:
rows_existing.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"{perf_gain}%⚠️",
]
)
elif "__replay_test_" in str(print_filename):
rows_replay.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"⚠️{perf_gain}%",
f"{perf_gain}%✅",
]
)
elif "codeflash_concolic" in str(print_filename):
rows_concolic.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"{perf_gain}%✅",
]
)
else:
rows.append(
rows_existing.append(
[
f"`{print_filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"{perf_gain}%",
f"{perf_gain}%",
]
)
output += tabulate( # type: ignore[no-untyped-call]
headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
output_existing += tabulate( # type: ignore[no-untyped-call]
headers=headers, tabular_data=rows_existing, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
)
output_existing += "\n"
if len(rows_existing) == 0:
output_existing = ""
output_concolic += tabulate( # type: ignore[no-untyped-call]
headers=headers, tabular_data=rows_concolic, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
)
output_concolic += "\n"
if len(rows_concolic) == 0:
output_concolic = ""
output_replay += tabulate( # type: ignore[no-untyped-call]
headers=headers, tabular_data=rows_replay, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
)
output += "\n"
return output
output_replay += "\n"
if len(rows_replay) == 0:
output_replay = ""
return output_existing, output_replay, output_concolic


def check_create_pr(
Expand All @@ -131,6 +186,8 @@ def check_create_pr(
generated_original_test_source: str,
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str,
git_remote: Optional[str] = None,
) -> None:
pr_number: Optional[int] = env_utils.get_pr_number()
Expand Down Expand Up @@ -171,6 +228,8 @@ def check_create_pr(
generated_tests=generated_original_test_source,
trace_id=function_trace_id,
coverage_message=coverage_message,
replay_tests=replay_tests,
concolic_tests=concolic_tests,
)
if response.ok:
logger.info(f"Suggestions were successfully made to PR #{pr_number}")
Expand Down Expand Up @@ -218,6 +277,8 @@ def check_create_pr(
generated_tests=generated_original_test_source,
trace_id=function_trace_id,
coverage_message=coverage_message,
replay_tests=replay_tests,
concolic_tests=concolic_tests,
)
if response.ok:
pr_id = response.text
Expand Down
Loading
Loading