Skip to content

Commit 12c35fa

Browse files
committed
fix for unittest framework
1 parent c596e12 commit 12c35fa

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

codeflash/api/aiservice.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ def log_results( # noqa: D417
306306
"optimized_runtime": optimized_runtime,
307307
"is_correct": is_correct,
308308
"codeflash_version": codeflash_version,
309+
<<<<<<< Updated upstream
310+
=======
311+
"best_optimization_id": best_optimization_id,
312+
>>>>>>> Stashed changes
309313
"optimized_line_profiler_results": optimized_line_profiler_results,
310314
}
311315
try:

codeflash/optimization/function_optimizer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ def determine_best_candidate(
361361
file_path_to_helper_classes: dict[Path, set[str]],
362362
exp_type: str,
363363
) -> BestOptimization | None:
364+
<<<<<<< Updated upstream
365+
=======
366+
# TODO remove
367+
368+
>>>>>>> Stashed changes
364369
best_optimization: BestOptimization | None = None
365370
_best_runtime_until_now = original_code_baseline.runtime
366371

@@ -598,6 +603,10 @@ def determine_best_candidate(
598603
original_runtime=original_code_baseline.runtime,
599604
optimized_runtime=optimized_runtimes,
600605
is_correct=is_correct,
606+
<<<<<<< Updated upstream
607+
=======
608+
best_optimization_id=best_optimization.candidate.optimization_id,
609+
>>>>>>> Stashed changes
601610
optimized_line_profiler_results=optimized_line_profiler_results,
602611
)
603612
return best_optimization

codeflash/verification/test_runner.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def run_line_profile_tests(
152152
test_framework: str,
153153
*,
154154
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
155-
verbose: bool = False, # noqa: ARG001
155+
verbose: bool = False,
156156
pytest_timeout: int | None = None,
157157
pytest_min_loops: int = 5, # noqa: ARG001
158158
pytest_max_loops: int = 100_000, # noqa: ARG001
@@ -200,6 +200,30 @@ def run_line_profile_tests(
200200
env=pytest_test_env,
201201
timeout=600, # TODO: Make this dynamic
202202
)
203+
elif test_framework == "unittest":
204+
test_env["CODEFLASH_LOOP_INDEX"] = "1"
205+
test_env["LINE_PROFILE"] = "1"
206+
test_files: list[str] = []
207+
for file in test_paths.test_files:
208+
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
209+
test_files.extend(
210+
[
211+
str(file.benchmarking_file_path)
212+
+ "::"
213+
+ (test.test_class + "::" if test.test_class else "")
214+
+ (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
215+
for test in file.tests_in_file
216+
]
217+
)
218+
else:
219+
test_files.append(str(file.benchmarking_file_path))
220+
test_files = list(set(test_files)) # remove multiple calls in the same test function
221+
line_profiler_output_file, results = run_unittest_tests(
222+
verbose=verbose, test_file_paths=[Path(file) for file in test_files], test_env=test_env, cwd=cwd
223+
)
224+
logger.debug(
225+
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
226+
)
203227
else:
204228
msg = f"Unsupported test framework: {test_framework}"
205229
raise ValueError(msg)

0 commit comments

Comments
 (0)