Skip to content

Commit 2200b21

Browse files
committed
bring over changes from other branch
1 parent 2d1696a commit 2200b21

File tree

2 files changed

+114
-17
lines changed

2 files changed

+114
-17
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 91 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,19 @@ def establish_original_code_baseline(
14721472

14731473
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
14741474

1475+
if self.function_to_optimize.is_async:
1476+
from codeflash.code_utils.instrument_existing_tests import (
1477+
instrument_source_module_with_async_decorators,
1478+
)
1479+
1480+
success, instrumented_source = instrument_source_module_with_async_decorators(
1481+
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
1482+
)
1483+
if success and instrumented_source:
1484+
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1485+
f.write(instrumented_source)
1486+
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
1487+
14751488
# Instrument codeflash capture
14761489
with progress_bar("Running tests to establish original code behavior..."):
14771490
try:
@@ -1511,15 +1524,38 @@ def establish_original_code_baseline(
15111524
)
15121525
console.rule()
15131526
with progress_bar("Running performance benchmarks..."):
1514-
benchmarking_results, _ = self.run_and_parse_tests(
1515-
testing_type=TestingMode.PERFORMANCE,
1516-
test_env=test_env,
1517-
test_files=self.test_files,
1518-
optimization_iteration=0,
1519-
testing_time=total_looping_time,
1520-
enable_coverage=False,
1521-
code_context=code_context,
1522-
)
1527+
if self.function_to_optimize.is_async:
1528+
from codeflash.code_utils.instrument_existing_tests import (
1529+
instrument_source_module_with_async_decorators,
1530+
)
1531+
1532+
success, instrumented_source = instrument_source_module_with_async_decorators(
1533+
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
1534+
)
1535+
if success and instrumented_source:
1536+
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1537+
f.write(instrumented_source)
1538+
logger.debug(
1539+
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
1540+
)
1541+
1542+
try:
1543+
benchmarking_results, _ = self.run_and_parse_tests(
1544+
testing_type=TestingMode.PERFORMANCE,
1545+
test_env=test_env,
1546+
test_files=self.test_files,
1547+
optimization_iteration=0,
1548+
testing_time=total_looping_time,
1549+
enable_coverage=False,
1550+
code_context=code_context,
1551+
)
1552+
finally:
1553+
if self.function_to_optimize.is_async:
1554+
self.write_code_and_helpers(
1555+
self.function_to_optimize_source_code,
1556+
original_helper_code,
1557+
self.function_to_optimize.file_path,
1558+
)
15231559
else:
15241560
benchmarking_results = TestResults()
15251561
start_time: float = time.time()
@@ -1614,6 +1650,21 @@ def run_optimized_candidate(
16141650
candidate_helper_code = {}
16151651
for module_abspath in original_helper_code:
16161652
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
1653+
if self.function_to_optimize.is_async:
1654+
from codeflash.code_utils.instrument_existing_tests import (
1655+
instrument_source_module_with_async_decorators,
1656+
)
1657+
1658+
success, instrumented_source = instrument_source_module_with_async_decorators(
1659+
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
1660+
)
1661+
if success and instrumented_source:
1662+
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1663+
f.write(instrumented_source)
1664+
logger.debug(
1665+
f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
1666+
)
1667+
16171668
try:
16181669
instrument_codeflash_capture(
16191670
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
@@ -1651,14 +1702,37 @@ def run_optimized_candidate(
16511702
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
16521703

16531704
if test_framework == "pytest":
1654-
candidate_benchmarking_results, _ = self.run_and_parse_tests(
1655-
testing_type=TestingMode.PERFORMANCE,
1656-
test_env=test_env,
1657-
test_files=self.test_files,
1658-
optimization_iteration=optimization_candidate_index,
1659-
testing_time=total_looping_time,
1660-
enable_coverage=False,
1661-
)
1705+
# For async functions, instrument at definition site for performance benchmarking
1706+
if self.function_to_optimize.is_async:
1707+
from codeflash.code_utils.instrument_existing_tests import (
1708+
instrument_source_module_with_async_decorators,
1709+
)
1710+
1711+
success, instrumented_source = instrument_source_module_with_async_decorators(
1712+
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
1713+
)
1714+
if success and instrumented_source:
1715+
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
1716+
f.write(instrumented_source)
1717+
logger.debug(
1718+
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
1719+
)
1720+
1721+
try:
1722+
candidate_benchmarking_results, _ = self.run_and_parse_tests(
1723+
testing_type=TestingMode.PERFORMANCE,
1724+
test_env=test_env,
1725+
test_files=self.test_files,
1726+
optimization_iteration=optimization_candidate_index,
1727+
testing_time=total_looping_time,
1728+
enable_coverage=False,
1729+
)
1730+
finally:
1731+
# Restore original source if we instrumented it
1732+
if self.function_to_optimize.is_async:
1733+
self.write_code_and_helpers(
1734+
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
1735+
)
16621736
loop_count = (
16631737
max(all_loop_indices)
16641738
if (

codeflash/verification/pytest_plugin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,26 @@ def make_progress_id(i: int, n: int = count) -> str:
450450
metafunc.parametrize(
451451
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
452452
)
453+
454+
@pytest.hookimpl(tryfirst=True)
455+
def pytest_runtest_setup(self, item: pytest.Item) -> None:
456+
"""Set test context environment variables before each test."""
457+
test_module_name = item.module.__name__ if item.module else "unknown_module"
458+
459+
test_class_name = None
460+
if item.cls:
461+
test_class_name = item.cls.__name__
462+
463+
test_function_name = item.name
464+
if "[" in test_function_name:
465+
test_function_name = test_function_name.split("[", 1)[0]
466+
467+
os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
468+
os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
469+
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
470+
471+
@pytest.hookimpl(trylast=True)
472+
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
473+
"""Clean up test context environment variables after each test."""
474+
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
475+
os.environ.pop(var, None)

0 commit comments

Comments
 (0)