diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index a4172afbe..20cf47b5f 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -12,7 +12,8 @@
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
 from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
-from codeflash.models.models import OptimizedCandidate
+from codeflash.models.ExperimentMetadata import ExperimentMetadata
+from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.version import __version__ as codeflash_version
@@ -21,6 +22,7 @@
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
     from codeflash.models.ExperimentMetadata import ExperimentMetadata
+    from codeflash.models.models import AIServiceRefinerRequest


 class AiServiceClient:
@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
         return "https://app.codeflash.ai"

     def make_ai_service_request(
-        self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
+        self,
+        endpoint: str,
+        method: str = "POST",
+        payload: dict[str, Any] | list[dict[str, Any]] | None = None,
+        timeout: float | None = None,
     ) -> requests.Response:
         """Make an API request to the given endpoint on the AI service.

@@ -98,11 +104,7 @@ def optimize_python_code(  # noqa: D417
         """
         start_time = time.perf_counter()

-        try:
-            git_repo_owner, git_repo_name = get_repo_owner_and_name()
-        except Exception as e:
-            logger.warning(f"Could not determine repo owner and name: {e}")
-            git_repo_owner, git_repo_name = None, None
+        git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()

         payload = {
             "source_code": source_code,
@@ -219,6 +221,63 @@ def optimize_python_code_line_profiler(  # noqa: D417
             console.rule()
             return []

+    def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
+        """Refine the given optimization candidates by making a request to the Django endpoint.
+
+        Args:
+            request: A list of optimization candidate details for refinement.
+
+        Returns:
+        -------
+            - list[OptimizedCandidate]: A list of refined optimization candidates.
+
+        """
+        payload = [
+            {
+                "optimization_id": opt.optimization_id,
+                "original_source_code": opt.original_source_code,
+                "read_only_dependency_code": opt.read_only_dependency_code,
+                "original_line_profiler_results": opt.original_line_profiler_results,
+                "original_code_runtime": opt.original_code_runtime,
+                "optimized_source_code": opt.optimized_source_code,
+                "optimized_explanation": opt.optimized_explanation,
+                "optimized_line_profiler_results": opt.optimized_line_profiler_results,
+                "optimized_code_runtime": opt.optimized_code_runtime,
+                "speedup": opt.speedup,
+                "trace_id": opt.trace_id,
+            }
+            for opt in request
+        ]
+        logger.info(f"Refining {len(request)} optimizations…")
+        console.rule()
+        try:
+            response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating optimization refinements: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return []
+
+        if response.status_code == 200:
+            refined_optimizations = response.json()["refinements"]
+            logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
+            console.rule()
+            return [
+                OptimizedCandidate(
+                    source_code=opt["source_code"],
+                    explanation=opt["explanation"],
+                    optimization_id=opt["optimization_id"][:-4] + "refi",
+                )
+                for opt in refined_optimizations
+            ]
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return []
+
     def log_results(  # noqa: D417
         self,
         function_trace_id: str,
@@ -226,6 +285,8 @@
         original_runtime: float | None,
         optimized_runtime: dict[str, float | None] | None,
         is_correct: dict[str, bool] | None,
+        optimized_line_profiler_results: dict[str, str] | None,
+        metadata: dict[str, Any] | None,
     ) -> None:
         """Log features to the database.

@@ -236,6 +297,8 @@
         - original_runtime (Optional[Dict[str, float]]): The original runtime.
         - optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
         - is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
+        - optimized_line_profiler_results: line_profiler results for each candidate, keyed by optimization_id.
+        - metadata: additional metadata, including the best optimization id.

         """
         payload = {
@@ -245,6 +308,8 @@
             "optimized_runtime": optimized_runtime,
             "is_correct": is_correct,
             "codeflash_version": codeflash_version,
+            "optimized_line_profiler_results": optimized_line_profiler_results,
+            "metadata": metadata,
         }
         try:
             self.make_ai_service_request("/log_features", payload=payload, timeout=5)
@@ -331,3 +396,12 @@ class LocalAiServiceClient(AiServiceClient):
     def get_aiservice_base_url(self) -> str:
         """Get the base URL for the local AI service."""
         return "http://localhost:8000"
+
+
+def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]:
+    try:
+        git_repo_owner, git_repo_name = get_repo_owner_and_name()
+    except Exception as e:
+        logger.warning(f"Could not determine repo owner and name: {e}")
+        git_repo_owner, git_repo_name = None, None
+    return git_repo_owner, git_repo_name
diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py
index 59c75a00b..dfd79a76b 100644
--- a/codeflash/code_utils/code_utils.py
+++ b/codeflash/code_utils/code_utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import ast
+import difflib
 import os
 import re
 import shutil
@@ -19,6 +20,50 @@
 ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)


+def diff_length(a: str, b: str) -> int:
+    """Compute the length (in characters) of the unified diff between two strings.
+
+    Args:
+        a (str): Original string.
+        b (str): Modified string.
+
+    Returns:
+        int: Total number of characters in the diff.
+
+    """
+    # Split input strings into lines for line-by-line diff
+    a_lines = a.splitlines(keepends=True)
+    b_lines = b.splitlines(keepends=True)
+
+    # Compute unified diff
+    diff_lines = list(difflib.unified_diff(a_lines, b_lines, lineterm=""))
+
+    # Join all lines with newline to calculate total diff length
+    diff_text = "\n".join(diff_lines)
+
+    return len(diff_text)
+
+
+def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
+    """Create a dictionary from a list of ints, mapping the original index to its rank.
+
+    This version uses a more compact, "Pythonic" implementation.
+
+    Args:
+        int_array: A list of integers.
+
+    Returns:
+        A dictionary where keys are original indices and values are the
+        rank of the element in ascending order.
+
+    """
+    # Sort the indices of the array based on their corresponding values
+    sorted_indices = sorted(range(len(int_array)), key=lambda i: int_array[i])
+
+    # Create a dictionary mapping the original index to its rank (its position in the sorted list)
+    return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
+
+
 @contextmanager
 def custom_addopts() -> None:
     pyproject_file = find_pyproject_toml()
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index e96d12423..369fd51bd 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -27,6 +27,22 @@
 from codeflash.code_utils.env_utils import is_end_to_end
 from codeflash.verification.comparator import comparator

+
+@dataclass(frozen=True)
+class AIServiceRefinerRequest:
+    optimization_id: str
+    original_source_code: str
+    read_only_dependency_code: str
+    original_code_runtime: str
+    optimized_source_code: str
+    optimized_explanation: str
+    optimized_code_runtime: str
+    speedup: str
+    trace_id: str
+    original_line_profiler_results: str
+    optimized_line_profiler_results: str
+
+
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
 # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
 # of the module is foo.eggs.
@@ -76,11 +92,13 @@ def __hash__(self) -> int:
 class BestOptimization(BaseModel):
     candidate: OptimizedCandidate
     helper_functions: list[FunctionSource]
+    code_context: CodeOptimizationContext
     runtime: int
     replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
     winning_behavior_test_results: TestResults
     winning_benchmarking_test_results: TestResults
     winning_replay_benchmarking_test_results: Optional[TestResults] = None
+    line_profiler_test_results: dict


 @dataclass(frozen=True)
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index ef7b215bc..1844819cc 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -18,7 +18,7 @@
 from rich.syntax import Syntax
 from rich.tree import Tree

-from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
+from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
 from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success
 from codeflash.benchmarking.utils import process_benchmark_data
 from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
@@ -31,6 +31,8 @@
 from codeflash.code_utils.code_utils import (
     ImportErrorPattern,
     cleanup_paths,
+    create_rank_dictionary_compact,
+    diff_length,
     file_name_from_test_module_name,
     get_run_tmp_file,
     has_any_async_functions,
@@ -146,6 +148,9 @@
         self.generate_and_instrument_tests_results: (
             tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet] | None
         ) = None
+        self.valid_optimizations: list[BestOptimization] = (
+            list()  # TODO: Figure out the dataclass type for this  # noqa: C408
+        )

     def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
         should_run_experiment = self.experiment_id is not None
@@ -357,11 +362,12 @@
         exp_type: str,
     ) -> BestOptimization | None:
         best_optimization: BestOptimization | None = None
-        best_runtime_until_now = original_code_baseline.runtime
+        _best_runtime_until_now = original_code_baseline.runtime

         speedup_ratios: dict[str, float | None] = {}
         optimized_runtimes: dict[str, float | None] = {}
         is_correct = {}
+        optimized_line_profiler_results: dict[str, str] = {}

         logger.info(
             f"Determining best optimization candidate (out of {len(candidates)}) for "
@@ -369,6 +375,7 @@
         )
         console.rule()
         candidates = deque(candidates)
+        refinement_done = False
         # Start a new thread for AI service request, start loop in main thread
         # check if aiservice request is complete, when it is complete, append result to the candidates list
         with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
@@ -390,8 +397,11 @@
                 candidate_index = 0
                 original_len = len(candidates)
                 while candidates:
-                    done = True if future_line_profile_results is None else future_line_profile_results.done()
-                    if done and (future_line_profile_results is not None):
+                    candidate_index += 1
+                    line_profiler_done = (
+                        True if future_line_profile_results is None else future_line_profile_results.done()
+                    )
+                    if line_profiler_done and (future_line_profile_results is not None):
                         line_profile_results = future_line_profile_results.result()
                         candidates.extend(line_profile_results)
                         original_len += len(line_profile_results)
@@ -400,7 +410,6 @@
                         )
                         future_line_profile_results = None
                     candidate = candidates.popleft()
-                    candidate_index += 1
                     get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
                     get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
                     logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -451,9 +460,11 @@
                     tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
                     benchmark_tree = None
                     if speedup_critic(
-                        candidate_result, original_code_baseline.runtime, best_runtime_until_now
+                        candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
                     ) and quantity_of_tests_critic(candidate_result):
-                        tree.add("This candidate is faster than the previous best candidate. 🚀")
+                        tree.add(
+                            "This candidate is faster than the original code. 🚀"
+                        )  # TODO: Change this description
                         tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
                         tree.add(
                             f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
@@ -462,6 +473,14 @@
                         )
                         tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                         tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
+                        line_profile_test_results = self.line_profiler_step(
+                            code_context=code_context,
+                            original_helper_code=original_helper_code,
+                            candidate_index=candidate_index,
+                        )
+                        optimized_line_profiler_results[candidate.optimization_id] = line_profile_test_results[
+                            "str_out"
+                        ]
                         replay_perf_gain = {}
                         if self.args.benchmark:
                             test_results_by_benchmark = (
@@ -487,13 +506,15 @@
                         best_optimization = BestOptimization(
                             candidate=candidate,
                             helper_functions=code_context.helper_functions,
+                            code_context=code_context,
                             runtime=best_test_runtime,
+                            line_profiler_test_results=line_profile_test_results,
                             winning_behavior_test_results=candidate_result.behavior_test_results,
                             replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
                             winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
                             winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
                         )
-                        best_runtime_until_now = best_test_runtime
+                        self.valid_optimizations.append(best_optimization)
                     else:
                         tree.add(
                             f"Summed runtime: {humanize_runtime(best_test_runtime)} "
@@ -510,8 +531,9 @@
                     self.write_code_and_helpers(
                         self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                     )
+
                     if (not len(candidates)) and (
-                        not done
+                        not line_profiler_done
                     ):  # all original candidates processed but lp results haven't been processed
                         concurrent.futures.wait([future_line_profile_results])
                         line_profile_results = future_line_profile_results.result()
@@ -521,6 +543,26 @@
                             f"Added results from line profiler to candidates, total candidates now: {original_len}"
                         )
                         future_line_profile_results = None
+
+                    if len(candidates) == 0 and len(self.valid_optimizations) > 0 and not refinement_done:
+                        # TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
+                        # are found. This way we can hide the time waiting for the LLM results.
+                        trace_id = self.function_trace_id
+                        if trace_id.endswith(("EXP0", "EXP1")):
+                            trace_id = trace_id[:-4] + exp_type
+                        # refinement_response is a list of OptimizedCandidate (optimization_id, source code, explanation)
+                        refinement_response = self.refine_optimizations(
+                            valid_optimizations=self.valid_optimizations,
+                            original_code_baseline=original_code_baseline,
+                            code_context=code_context,
+                            trace_id=trace_id,
+                            ai_service_client=ai_service_client,
+                            executor=executor,
+                        )
+                        candidates.extend(refinement_response)
+                        print("Added candidates from refinement")
+                        original_len += len(refinement_response)
+                        refinement_done = True
             except KeyboardInterrupt as e:
                 self.write_code_and_helpers(
                     self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
@@ -528,15 +570,62 @@
                 logger.exception(f"Optimization interrupted: {e}")
                 raise

+        if not len(self.valid_optimizations):
+            return None
+        # need to figure out the best candidate here before we return best_optimization
+        diff_lens_list = []  # character level diff
+        runtimes_list = []
+        for valid_opt in self.valid_optimizations:
+            diff_lens_list.append(
+                diff_length(valid_opt.candidate.source_code, code_context.read_writable_code)
+            )  # char level diff
+            runtimes_list.append(valid_opt.runtime)
+        diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
+        runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
+        # TODO: better way to resolve conflicts with same min ranking
+        overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()}  # noqa: SIM118
+        min_key = min(overall_ranking, key=overall_ranking.get)
+        best_optimization = self.valid_optimizations[min_key]
         ai_service_client.log_results(
             function_trace_id=self.function_trace_id[:-4] + exp_type
             if self.experiment_id
             else self.function_trace_id,
             speedup_ratio=speedup_ratios,
             original_runtime=original_code_baseline.runtime,
             optimized_runtime=optimized_runtimes,
             is_correct=is_correct,
+            optimized_line_profiler_results=optimized_line_profiler_results,
+            metadata={"best_optimization_id": best_optimization.candidate.optimization_id},
         )
         return best_optimization

+    def refine_optimizations(
+        self,
+        valid_optimizations: list[BestOptimization],
+        original_code_baseline: OriginalCodeBaseline,
+        code_context: CodeOptimizationContext,
+        trace_id: str,
+        ai_service_client: AiServiceClient,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ) -> list[OptimizedCandidate]:
+        request = [
+            AIServiceRefinerRequest(
+                optimization_id=opt.candidate.optimization_id,
+                original_source_code=code_context.read_writable_code,
+                read_only_dependency_code=code_context.read_only_context_code,
+                original_code_runtime=humanize_runtime(original_code_baseline.runtime),
+                optimized_source_code=opt.candidate.source_code,
+                optimized_explanation=opt.candidate.explanation,
+                optimized_code_runtime=humanize_runtime(opt.runtime),
+                speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
+                trace_id=trace_id,
+                original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
+                optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
+            )
+            for opt in valid_optimizations
+        ]  # TODO: multiple workers for this?
+        future_refinement_results = executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
+        concurrent.futures.wait([future_refinement_results])
+        return future_refinement_results.result()
+
     def log_successful_optimization(
         self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
     ) -> None:
@@ -1074,16 +1163,8 @@
         assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}  # noqa: RUF018

         success = True
-        test_env = os.environ.copy()
-        test_env["CODEFLASH_TEST_ITERATION"] = "0"
-        test_env["CODEFLASH_TRACER_DISABLE"] = "1"
-        test_env["CODEFLASH_LOOP_INDEX"] = "0"
-        if "PYTHONPATH" not in test_env:
-            test_env["PYTHONPATH"] = str(self.args.project_root)
-        else:
-            test_env["PYTHONPATH"] += os.pathsep + str(self.args.project_root)
+        test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)

-        coverage_results = None
         # Instrument codeflash capture
         try:
             instrument_codeflash_capture(
@@ -1112,28 +1193,10 @@
             if not coverage_critic(coverage_results, self.args.test_framework):
                 return Failure("The threshold for test coverage was not met.")
             if test_framework == "pytest":
-                try:
-                    line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
-                    line_profile_results, _ = self.run_and_parse_tests(
-                        testing_type=TestingMode.LINE_PROFILE,
-                        test_env=test_env,
-                        test_files=self.test_files,
-                        optimization_iteration=0,
-                        testing_time=TOTAL_LOOPING_TIME,
-                        enable_coverage=False,
-                        code_context=code_context,
-                        line_profiler_output_file=line_profiler_output_file,
-                    )
-                finally:
-                    # Remove codeflash capture
-                    self.write_code_and_helpers(
-                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
-                    )
-                if line_profile_results["str_out"] == "":
-                    logger.warning(
-                        f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
-                    )
-                    console.rule()
+                line_profile_results = self.line_profiler_step(
+                    code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
+                )
+                console.rule()
             benchmarking_results, _ = self.run_and_parse_tests(
                 testing_type=TestingMode.PERFORMANCE,
                 test_env=test_env,
@@ -1229,14 +1292,11 @@
         assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}  # noqa: RUF018

         with progress_bar("Testing optimization candidate"):
-            test_env = os.environ.copy()
-            test_env["CODEFLASH_LOOP_INDEX"] = "0"
-            test_env["CODEFLASH_TEST_ITERATION"] = str(optimization_candidate_index)
-            test_env["CODEFLASH_TRACER_DISABLE"] = "1"
-            if "PYTHONPATH" not in test_env:
-                test_env["PYTHONPATH"] = str(self.project_root)
-            else:
-                test_env["PYTHONPATH"] += os.pathsep + str(self.project_root)
+            test_env = self.get_test_env(
+                codeflash_loop_index=0,
+                codeflash_test_iteration=optimization_candidate_index,
+                codeflash_tracer_disable=1,
+            )

             get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
             # Instrument codeflash capture
@@ -1470,3 +1530,45 @@
             paths_to_cleanup.append(test_file.benchmarking_file_path)

         cleanup_paths(paths_to_cleanup)
+
+    def get_test_env(
+        self, codeflash_loop_index: int, codeflash_test_iteration: int, codeflash_tracer_disable: int = 1
+    ) -> dict:
+        test_env = os.environ.copy()
+        test_env["CODEFLASH_TEST_ITERATION"] = str(codeflash_test_iteration)
+        test_env["CODEFLASH_TRACER_DISABLE"] = str(codeflash_tracer_disable)
+        test_env["CODEFLASH_LOOP_INDEX"] = str(codeflash_loop_index)
+        if "PYTHONPATH" not in test_env:
+            test_env["PYTHONPATH"] = str(self.args.project_root)
+        else:
+            test_env["PYTHONPATH"] += os.pathsep + str(self.args.project_root)
+        return test_env
+
+    def line_profiler_step(
+        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
+    ) -> dict:
+        try:
+            test_env = self.get_test_env(
+                codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
+            )
+            line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
+            line_profile_results, _ = self.run_and_parse_tests(
+                testing_type=TestingMode.LINE_PROFILE,
+                test_env=test_env,
+                test_files=self.test_files,
+                optimization_iteration=0,
+                testing_time=TOTAL_LOOPING_TIME,
+                enable_coverage=False,
+                code_context=code_context,
+                line_profiler_output_file=line_profiler_output_file,
+            )
+        finally:
+            # Remove codeflash capture
+            self.write_code_and_helpers(
+                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+            )
+        if line_profile_results["str_out"] == "":
+            logger.warning(
+                f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
+            )
+        return line_profile_results
diff --git a/codeflash/result/critic.py b/codeflash/result/critic.py
index 323482323..fa4a68b82 100644
--- a/codeflash/result/critic.py
+++ b/codeflash/result/critic.py
@@ -28,7 +28,7 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
 def speedup_critic(
     candidate_result: OptimizedCandidateResult,
     original_code_runtime: int,
-    best_runtime_until_now: int,
+    best_runtime_until_now: int | None,
     disable_gh_action_noise: Optional[bool] = None,
 ) -> bool:
     """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
@@ -47,6 +47,9 @@ def speedup_critic( perf_gain = performance_gain( original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime ) + if best_runtime_until_now is None: + # collect all optimizations with this + return bool(perf_gain > noise_floor) return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now) diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index f0d1c01a5..85e347641 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -152,7 +152,7 @@ def run_line_profile_tests( test_framework: str, *, pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME, - verbose: bool = False, # noqa: ARG001 + verbose: bool = False, pytest_timeout: int | None = None, pytest_min_loops: int = 5, # noqa: ARG001 pytest_max_loops: int = 100_000, # noqa: ARG001 @@ -200,6 +200,30 @@ def run_line_profile_tests( env=pytest_test_env, timeout=600, # TODO: Make this dynamic ) + elif test_framework == "unittest": + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["LINE_PROFILE"] = "1" + test_files: list[str] = [] + for file in test_paths.test_files: + if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file: + test_files.extend( + [ + str(file.benchmarking_file_path) + + "::" + + (test.test_class + "::" if test.test_class else "") + + (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function) + for test in file.tests_in_file + ] + ) + else: + test_files.append(str(file.benchmarking_file_path)) + test_files = list(set(test_files)) # remove multiple calls in the same test function + line_profiler_output_file, results = run_unittest_tests( + verbose=verbose, test_file_paths=[Path(file) for file in test_files], test_env=test_env, cwd=cwd + ) + logger.debug( + f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}""" + ) else: msg = f"Unsupported test framework: {test_framework}" raise ValueError(msg)
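
Illustrative sketch (not part of the patch above): determine_best_candidate() now keeps every correct, faster-than-noise candidate in self.valid_optimizations and picks the winner by summing two ranks per candidate, the rank of its character-level diff size against the original read-writable code (diff_length) and the rank of its measured runtime, then taking the candidate with the smallest combined rank. The snippet below shows that selection logic in isolation, using the two helpers this patch adds to codeflash/code_utils/code_utils.py; the candidate sources and runtimes are made-up example data, not taken from the patch.

    from codeflash.code_utils.code_utils import create_rank_dictionary_compact, diff_length

    original_code = "def total(n):\n    s = 0\n    for i in range(n):\n        s += i\n    return s\n"
    # (source_code, runtime_ns) pairs standing in for BestOptimization entries.
    candidates = [
        ("def total(n):\n    return sum(range(n))\n", 4_200),
        ("def total(n):\n    return n * (n - 1) // 2\n", 900),
    ]

    # Rank by how small each candidate's diff against the original is, and by how fast it runs.
    diff_ranks = create_rank_dictionary_compact([diff_length(src, original_code) for src, _ in candidates])
    runtime_ranks = create_rank_dictionary_compact([rt for _, rt in candidates])

    # Lower combined rank wins; ties currently resolve to min()'s first key, per the TODO in the patch.
    overall = {i: diff_ranks[i] + runtime_ranks[i] for i in diff_ranks}
    best_index = min(overall, key=overall.get)
    print(best_index, overall)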