|
8 | 8 | filter_test_files_by_imports, |
9 | 9 | ) |
10 | 10 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
11 | | -from codeflash.models.models import TestsInFile, TestType |
| 11 | +from codeflash.models.models import TestsInFile, TestType, FunctionParent |
12 | 12 | from codeflash.verification.verification_utils import TestConfig |
13 | 13 |
|
14 | 14 |
|
@@ -714,6 +714,210 @@ def test_add_with_parameters(self): |
714 | 714 | assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve() |
715 | 715 | assert calculator_test.tests_in_file.test_function == "test_add_with_parameters" |
716 | 716 |
|
| 717 | +def test_unittest_discovery_with_pytest_fixture(): |
| 718 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 719 | + path_obj_tmpdirname = Path(tmpdirname) |
| 720 | + |
| 721 | + # Create a simple code file |
| 722 | + code_file_path = path_obj_tmpdirname / "topological_sort.py" |
| 723 | + code_file_content = """ |
| 724 | +import uuid |
| 725 | +from collections import defaultdict |
| 726 | +
|
| 727 | +
|
| 728 | +class Graph: |
| 729 | + def __init__(self, vertices: int): |
| 730 | + self.vertices=vertices |
| 731 | +
|
| 732 | + def dummy_fn(self): |
| 733 | + return 1 |
| 734 | +
|
| 735 | + def topologicalSort(self): |
| 736 | + return self.vertices |
| 737 | +
|
| 738 | +""" |
| 739 | + code_file_path.write_text(code_file_content) |
| 740 | + |
| 741 | + # Create a unittest test file with parameterized tests |
| 742 | + test_file_path = path_obj_tmpdirname / "test_topological_sort.py" |
| 743 | + test_file_content = """ |
| 744 | +from topological_sort import Graph |
| 745 | +import pytest |
| 746 | +
|
| 747 | +@pytest.fixture |
| 748 | +def g(): |
| 749 | + return Graph(6) |
| 750 | +
|
| 751 | +def test_topological_sort(g): |
| 752 | + assert g.dummy_fn() == 1 |
| 753 | + assert g.topologicalSort() == 6 |
| 754 | +""" |
| 755 | + test_file_path.write_text(test_file_content) |
| 756 | + |
| 757 | + # Configure test discovery |
| 758 | + test_config = TestConfig( |
| 759 | + tests_root=path_obj_tmpdirname, |
| 760 | + project_root_path=path_obj_tmpdirname, |
| 761 | + test_framework="pytest", # Using pytest framework to discover unittest tests |
| 762 | + tests_project_rootdir=path_obj_tmpdirname.parent, |
| 763 | + ) |
| 764 | + fto = FunctionToOptimize(function_name="topologicalSort", file_path=code_file_path, parents=[FunctionParent(name="Graph", type="ClassDef")]) |
| 765 | + # Discover tests |
| 766 | + discovered_tests, _, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file_path: [fto]}) |
| 767 | + |
| 768 | + # Verify the unittest was discovered |
| 769 | + assert len(discovered_tests) == 2 |
| 770 | + assert "topological_sort.Graph.topologicalSort" in discovered_tests |
| 771 | + assert len(discovered_tests["topological_sort.Graph.topologicalSort"]) == 1 |
| 772 | + tpsort_test = next(iter(discovered_tests["topological_sort.Graph.topologicalSort"])) |
| 773 | + assert tpsort_test.tests_in_file.test_file.resolve() == test_file_path.resolve() |
| 774 | + assert tpsort_test.tests_in_file.test_function == "test_topological_sort" |
| 775 | + |
| 776 | +def test_unittest_discovery_with_pytest_class_fixture(): |
| 777 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 778 | + path_obj_tmpdirname = Path(tmpdirname) |
| 779 | + |
| 780 | + # Create a simple code file |
| 781 | + code_file_path = path_obj_tmpdirname / "router_file.py" |
| 782 | + code_file_content = """ |
| 783 | +from __future__ import annotations |
| 784 | +
|
| 785 | +import hashlib |
| 786 | +import json |
| 787 | +
|
| 788 | +class Router: |
| 789 | + model_names: list |
| 790 | + cache_responses = False |
| 791 | + tenacity = None |
| 792 | +
|
| 793 | + def __init__( # noqa: PLR0915 |
| 794 | + self, |
| 795 | + model_list = None, |
| 796 | + ) -> None: |
| 797 | + self.model_list = model_list |
| 798 | + self.model_id_to_deployment_index_map = {} |
| 799 | + self.model_name_to_deployment_indices = {} |
| 800 | + def _generate_model_id(self, model_group, litellm_params): |
| 801 | + # Optimized: Use list and join instead of string concatenation in loop |
| 802 | + # This avoids creating many temporary string objects (O(n) vs O(n²) complexity) |
| 803 | + parts = [model_group] |
| 804 | + for k, v in litellm_params.items(): |
| 805 | + if isinstance(k, str): |
| 806 | + parts.append(k) |
| 807 | + elif isinstance(k, dict): |
| 808 | + parts.append(json.dumps(k)) |
| 809 | + else: |
| 810 | + parts.append(str(k)) |
| 811 | +
|
| 812 | + if isinstance(v, str): |
| 813 | + parts.append(v) |
| 814 | + elif isinstance(v, dict): |
| 815 | + parts.append(json.dumps(v)) |
| 816 | + else: |
| 817 | + parts.append(str(v)) |
| 818 | +
|
| 819 | + concat_str = "".join(parts) |
| 820 | + hash_object = hashlib.sha256(concat_str.encode()) |
| 821 | +
|
| 822 | + return hash_object.hexdigest() |
| 823 | + def _add_model_to_list_and_index_map( |
| 824 | + self, model, model_id = None |
| 825 | + ) -> None: |
| 826 | + idx = len(self.model_list) |
| 827 | + self.model_list.append(model) |
| 828 | +
|
| 829 | + # Update model_id index for O(1) lookup |
| 830 | + if model_id is not None: |
| 831 | + self.model_id_to_deployment_index_map[model_id] = idx |
| 832 | + elif model.get("model_info", {}).get("id") is not None: |
| 833 | + self.model_id_to_deployment_index_map[model["model_info"]["id"]] = idx |
| 834 | +
|
| 835 | + # Update model_name index for O(1) lookup |
| 836 | + model_name = model.get("model_name") |
| 837 | + if model_name: |
| 838 | + if model_name not in self.model_name_to_deployment_indices: |
| 839 | + self.model_name_to_deployment_indices[model_name] = [] |
| 840 | + self.model_name_to_deployment_indices[model_name].append(idx) |
| 841 | +
|
| 842 | + def _build_model_id_to_deployment_index_map(self, model_list): |
| 843 | + # First populate the model_list |
| 844 | + self.model_list = [] |
| 845 | + for _, model in enumerate(model_list): |
| 846 | + # Extract model_info from the model dict |
| 847 | + model_info = model.get("model_info", {}) |
| 848 | + model_id = model_info.get("id") |
| 849 | +
|
| 850 | + # If no ID exists, generate one using the same logic as set_model_list |
| 851 | + if model_id is None: |
| 852 | + model_name = model.get("model_name", "") |
| 853 | + litellm_params = model.get("litellm_params", {}) |
| 854 | + model_id = self._generate_model_id(model_name, litellm_params) |
| 855 | + # Update the model_info in the original list |
| 856 | + if "model_info" not in model: |
| 857 | + model["model_info"] = {} |
| 858 | + model["model_info"]["id"] = model_id |
| 859 | +
|
| 860 | + self._add_model_to_list_and_index_map(model=model, model_id=model_id) |
| 861 | +
|
| 862 | +""" |
| 863 | + code_file_path.write_text(code_file_content) |
| 864 | + |
| 865 | + # Create a unittest test file with parameterized tests |
| 866 | + test_file_path = path_obj_tmpdirname / "test_router_file.py" |
| 867 | + test_file_content = """ |
| 868 | +import pytest |
| 869 | +
|
| 870 | +from router_file import Router |
| 871 | +
|
| 872 | +
|
| 873 | +class TestRouterIndexManagement: |
| 874 | + @pytest.fixture |
| 875 | + def router(self): |
| 876 | + return Router(model_list=[]) |
| 877 | + def test_build_model_id_to_deployment_index_map(self, router): |
| 878 | + model_list = [ |
| 879 | + { |
| 880 | + "model_name": "gpt-3.5-turbo", |
| 881 | + "litellm_params": {"model": "gpt-3.5-turbo"}, |
| 882 | + "model_info": {"id": "model-1"}, |
| 883 | + }, |
| 884 | + { |
| 885 | + "model_name": "gpt-4", |
| 886 | + "litellm_params": {"model": "gpt-4"}, |
| 887 | + "model_info": {"id": "model-2"}, |
| 888 | + }, |
| 889 | + ] |
| 890 | +
|
| 891 | + # Test: Build index from model list |
| 892 | + router._build_model_id_to_deployment_index_map(model_list) |
| 893 | +
|
| 894 | + # Verify: model_list is populated |
| 895 | + assert len(router.model_list) == 2 |
| 896 | + # Verify: model_id_to_deployment_index_map is correctly built |
| 897 | + assert router.model_id_to_deployment_index_map["model-1"] == 0 |
| 898 | + assert router.model_id_to_deployment_index_map["model-2"] == 1 |
| 899 | +""" |
| 900 | + test_file_path.write_text(test_file_content) |
| 901 | + |
| 902 | + # Configure test discovery |
| 903 | + test_config = TestConfig( |
| 904 | + tests_root=path_obj_tmpdirname, |
| 905 | + project_root_path=path_obj_tmpdirname, |
| 906 | + test_framework="pytest", # Using pytest framework to discover unittest tests |
| 907 | + tests_project_rootdir=path_obj_tmpdirname.parent, |
| 908 | + ) |
| 909 | + fto = FunctionToOptimize(function_name="_build_model_id_to_deployment_index_map", file_path=code_file_path, parents=[FunctionParent(name="Router", type="ClassDef")]) |
| 910 | + # Discover tests |
| 911 | + discovered_tests, _, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file_path: [fto]}) |
| 912 | + |
| 913 | + # Verify the unittest was discovered |
| 914 | + assert len(discovered_tests) == 1 |
| 915 | + assert "router_file.Router._build_model_id_to_deployment_index_map" in discovered_tests |
| 916 | + assert len(discovered_tests["router_file.Router._build_model_id_to_deployment_index_map"]) == 1 |
| 917 | + router_test = next(iter(discovered_tests["router_file.Router._build_model_id_to_deployment_index_map"])) |
| 918 | + assert router_test.tests_in_file.test_file.resolve() == test_file_path.resolve() |
| 919 | + assert router_test.tests_in_file.test_function == "test_build_model_id_to_deployment_index_map" |
| 920 | + |
717 | 921 |
|
718 | 922 | def test_unittest_discovery_with_pytest_parameterized(): |
719 | 923 | with tempfile.TemporaryDirectory() as tmpdirname: |
@@ -1335,6 +1539,77 @@ def test_topological_sort(): |
1335 | 1539 |
|
1336 | 1540 | assert should_process is True |
1337 | 1541 |
|
| 1542 | +def test_analyze_imports_fixture(): |
| 1543 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 1544 | + test_file = Path(tmpdirname) / "test_example.py" |
| 1545 | + test_content = """ |
| 1546 | +from code_to_optimize.topological_sort import Graph |
| 1547 | +import pytest |
| 1548 | +
|
| 1549 | +@pytest.fixture |
| 1550 | +def g(): |
| 1551 | + return Graph(6) |
| 1552 | +
|
| 1553 | +def test_topological_sort(g): |
| 1554 | + g.addEdge(5, 2) |
| 1555 | + g.addEdge(5, 0) |
| 1556 | + g.addEdge(4, 0) |
| 1557 | + g.addEdge(4, 1) |
| 1558 | + g.addEdge(2, 3) |
| 1559 | + g.addEdge(3, 1) |
| 1560 | +
|
| 1561 | + assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0] |
| 1562 | +""" |
| 1563 | + test_file.write_text(test_content) |
| 1564 | + |
| 1565 | + target_functions = {"Graph.topologicalSort"} |
| 1566 | + should_process = analyze_imports_in_test_file(test_file, target_functions) |
| 1567 | + |
| 1568 | + assert should_process is True |
| 1569 | + |
| 1570 | +def test_analyze_imports_class_fixture(): |
| 1571 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 1572 | + test_file = Path(tmpdirname) / "test_example.py" |
| 1573 | + test_content = """ |
| 1574 | +import pytest |
| 1575 | +
|
| 1576 | +from router_file import Router |
| 1577 | +
|
| 1578 | +
|
| 1579 | +class TestRouterIndexManagement: |
| 1580 | + @pytest.fixture |
| 1581 | + def router(self): |
| 1582 | + return Router(model_list=[]) |
| 1583 | + def test_build_model_id_to_deployment_index_map(self, router): |
| 1584 | + model_list = [ |
| 1585 | + { |
| 1586 | + "model_name": "gpt-3.5-turbo", |
| 1587 | + "litellm_params": {"model": "gpt-3.5-turbo"}, |
| 1588 | + "model_info": {"id": "model-1"}, |
| 1589 | + }, |
| 1590 | + { |
| 1591 | + "model_name": "gpt-4", |
| 1592 | + "litellm_params": {"model": "gpt-4"}, |
| 1593 | + "model_info": {"id": "model-2"}, |
| 1594 | + }, |
| 1595 | + ] |
| 1596 | +
|
| 1597 | + # Test: Build index from model list |
| 1598 | + router._build_model_id_to_deployment_index_map(model_list) |
| 1599 | +
|
| 1600 | + # Verify: model_list is populated |
| 1601 | + assert len(router.model_list) == 2 |
| 1602 | + # Verify: model_id_to_deployment_index_map is correctly built |
| 1603 | + assert router.model_id_to_deployment_index_map["model-1"] == 0 |
| 1604 | + assert router.model_id_to_deployment_index_map["model-2"] == 1 |
| 1605 | +""" |
| 1606 | + test_file.write_text(test_content) |
| 1607 | + |
| 1608 | + target_functions = {"Router._build_model_id_to_deployment_index_map"} |
| 1609 | + should_process = analyze_imports_in_test_file(test_file, target_functions) |
| 1610 | + |
| 1611 | + assert should_process is True |
| 1612 | + |
1338 | 1613 | def test_analyze_imports_aliased_class_method_negative(): |
1339 | 1614 | with tempfile.TemporaryDirectory() as tmpdirname: |
1340 | 1615 | test_file = Path(tmpdirname) / "test_example.py" |
|
0 commit comments