diff --git a/codeflash/discovery/pytest_new_process_discovery.py b/codeflash/discovery/pytest_new_process_discovery.py index 8eb29f16b..9f78818f4 100644 --- a/codeflash/discovery/pytest_new_process_discovery.py +++ b/codeflash/discovery/pytest_new_process_discovery.py @@ -1,6 +1,9 @@ # ruff: noqa import sys +from pathlib import Path from typing import Any +import pickle + # This script should not have any relation to the codeflash package, be careful with imports cwd = sys.argv[1] @@ -11,12 +14,29 @@ sys.path.insert(1, str(cwd)) +def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]: + test_results = [] + for test in pytest_tests: + test_class = None + if test.cls: + test_class = test.parent.name + test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name}) + return test_results + + class PytestCollectionPlugin: def pytest_collection_finish(self, session) -> None: - global pytest_rootdir + global pytest_rootdir, collected_tests + collected_tests.extend(session.items) pytest_rootdir = session.config.rootdir + # Write results immediately since pytest.main() will exit after this callback, not always with a success code + tests = parse_pytest_collection_results(collected_tests) + exit_code = getattr(session.config, "exitstatus", 0) + with Path(pickle_path).open("wb") as f: + pickle.dump((exit_code, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL) + def pytest_collection_modifyitems(self, items) -> None: skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests") for item in items: @@ -24,31 +44,18 @@ def pytest_collection_modifyitems(self, items) -> None: item.add_marker(skip_benchmark) -def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]: - test_results = [] - for test in pytest_tests: - test_class = None - if test.cls: - test_class = test.parent.name - test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name}) - return test_results - - if __name__ == "__main__": - from pathlib import Path - import pytest try: - exitcode = pytest.main( - [tests_root, "-p no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"], + pytest.main( + [tests_root, "-p", "no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"], plugins=[PytestCollectionPlugin()], ) except Exception as e: print(f"Failed to collect tests: {e!s}") - exitcode = -1 - tests = parse_pytest_collection_results(collected_tests) - import pickle - - with Path(pickle_path).open("wb") as f: - pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL) + try: + with Path(pickle_path).open("wb") as f: + pickle.dump((-1, [], None), f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as pickle_error: + print(f"Failed to write failure pickle: {pickle_error!s}", file=sys.stderr)