Skip to content

Commit 52dbe88

Browse files
committed
bugfix
1 parent 3e106cb commit 52dbe88

File tree

6 files changed

+53
-6
lines changed

6 files changed

+53
-6
lines changed

codeflash/api/aiservice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def generate_regression_tests( # noqa: D417
405405
test_timeout: int,
406406
trace_id: str,
407407
test_index: int,
408-
) -> tuple[str, str, str] | None:
408+
) -> tuple[str, str, str, str | None, str | None] | None:
409409
"""Generate regression tests for the given function by making a request to the Django endpoint.
410410
411411
Parameters
@@ -456,6 +456,8 @@ def generate_regression_tests( # noqa: D417
456456
response_json["generated_tests"],
457457
response_json["instrumented_behavior_tests"],
458458
response_json["instrumented_perf_tests"],
459+
response_json.get("instrumented_behavior_source"),
460+
response_json.get("instrumented_perf_source"),
459461
)
460462
try:
461463
error = response.json()["error"]

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,3 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
160160
return return_value
161161

162162
return async_wrapper
163-

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def find_and_update_line_node(
7777
call_node = node
7878
if isinstance(node.func, ast.Name):
7979
function_name = node.func.id
80+
81+
if self.function_object.is_async:
82+
return [test_node]
83+
8084
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
8185
node.args = [
8286
ast.Name(id=function_name, ctx=ast.Load()),
@@ -98,6 +102,9 @@ def find_and_update_line_node(
98102
if isinstance(node.func, ast.Attribute):
99103
function_to_test = node.func.attr
100104
if function_to_test == self.function_object.function_name:
105+
if self.function_object.is_async:
106+
return [test_node]
107+
101108
function_name = ast.unparse(node.func)
102109
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
103110
node.args = [
@@ -382,7 +389,11 @@ def inject_profiling_into_existing_test(
382389
)
383390
if test_framework == "unittest":
384391
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
385-
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
392+
additional_functions = []
393+
if not func.is_async:
394+
additional_functions.append(create_wrapper_function(mode))
395+
396+
tree.body = [*new_imports, *additional_functions, *tree.body]
386397
return True, isort.code(ast.unparse(tree), float_to_top=True)
387398

388399

codeflash/models/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,8 @@ class GeneratedTests(BaseModel):
282282
instrumented_perf_test_source: str
283283
behavior_file_path: Path
284284
perf_file_path: Path
285+
behavior_source_path: Optional[Path] = None
286+
perf_source_path: Optional[Path] = None
285287

286288

287289
class GeneratedTestsList(BaseModel):
@@ -295,6 +297,8 @@ class TestFile(BaseModel):
295297
original_source: Optional[str] = None
296298
test_type: TestType
297299
tests_in_file: Optional[list[TestsInFile]] = None
300+
behavior_source_path: Optional[Path] = None
301+
perf_source_path: Optional[Path] = None
298302

299303

300304
class TestFiles(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ def generate_and_instrument_tests(
323323
original_source=generated_test.generated_original_test_source,
324324
test_type=TestType.GENERATED_REGRESSION,
325325
tests_in_file=None, # This is currently unused. We can discover the tests in the file if needed.
326+
behavior_source_path=generated_test.behavior_source_path,
327+
perf_source_path=generated_test.perf_source_path,
326328
)
327329
)
328330
logger.info(f"Generated test {i + 1}/{count_tests}:")
@@ -1022,6 +1024,8 @@ def generate_tests_and_optimizations(
10221024
instrumented_perf_test_source,
10231025
test_behavior_path,
10241026
test_perf_path,
1027+
behavior_source_path,
1028+
perf_source_path,
10251029
) = res
10261030
tests.append(
10271031
GeneratedTests(
@@ -1030,6 +1034,8 @@ def generate_tests_and_optimizations(
10301034
instrumented_perf_test_source=instrumented_perf_test_source,
10311035
behavior_file_path=test_behavior_path,
10321036
perf_file_path=test_perf_path,
1037+
behavior_source_path=behavior_source_path,
1038+
perf_source_path=perf_source_path,
10331039
)
10341040
)
10351041
if not tests:

codeflash/verification/verifier.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def generate_tests(
2727
test_index: int,
2828
test_path: Path,
2929
test_perf_path: Path,
30-
) -> tuple[str, str, Path] | None:
30+
) -> tuple[str, str, str, Path, Path, Path | None, Path | None] | None:
3131
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
3232
# class import. Remove the recreation of the class definition
3333
start_time = time.perf_counter()
@@ -43,8 +43,8 @@ def generate_tests(
4343
trace_id=function_trace_id,
4444
test_index=test_index,
4545
)
46-
if response and isinstance(response, tuple) and len(response) == 3:
47-
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response
46+
if response and isinstance(response, tuple) and len(response) == 5:
47+
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, instrumented_behavior_source, instrumented_perf_source = response
4848
temp_run_dir = get_run_tmp_file(Path()).as_posix()
4949

5050
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
@@ -53,6 +53,29 @@ def generate_tests(
5353
instrumented_perf_test_source = instrumented_perf_test_source.replace(
5454
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
5555
)
56+
57+
behavior_source_path = None
58+
perf_source_path = None
59+
60+
if instrumented_behavior_source and function_to_optimize.is_async:
61+
behavior_source_path = get_run_tmp_file(Path(f"instrumented_behavior_source_{test_index}_{function_to_optimize.function_name}.py"))
62+
try:
63+
with behavior_source_path.open("w", encoding="utf8") as f:
64+
f.write(instrumented_behavior_source)
65+
logger.debug(f"Written instrumented behavior source to {behavior_source_path}")
66+
except Exception as e:
67+
logger.warning(f"Failed to write instrumented behavior source file: {e}")
68+
behavior_source_path = None
69+
70+
if instrumented_perf_source and function_to_optimize.is_async:
71+
perf_source_path = get_run_tmp_file(Path(f"instrumented_perf_source_{test_index}_{function_to_optimize.function_name}.py"))
72+
try:
73+
with perf_source_path.open("w", encoding="utf8") as f:
74+
f.write(instrumented_perf_source)
75+
logger.debug(f"Written instrumented performance source to {perf_source_path}")
76+
except Exception as e:
77+
logger.warning(f"Failed to write instrumented performance source file: {e}")
78+
perf_source_path = None
5679
else:
5780
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
5881
return None
@@ -64,6 +87,8 @@ def generate_tests(
6487
instrumented_perf_test_source,
6588
test_path,
6689
test_perf_path,
90+
behavior_source_path,
91+
perf_source_path,
6792
)
6893

6994

0 commit comments

Comments
 (0)