
Conversation

@dasarchan (Contributor) commented Mar 19, 2025

User description

Use multiprocessing in process_test_files to speed up test discovery. A moderate speedup was seen on the dev server; systems with more CPU cores should see larger benefits.


PR Type

  • Enhancement

Description

  • Fall back to single-threaded processing for fewer than 25 files.
  • Add a module-level worker function for multiprocessing.
  • Refactor test file processing for parallel discovery.
  • Update version metadata.


Changes walkthrough 📝

Relevant files

Enhancement

discover_unit_tests.py: Enhance test discovery with parallel processing
codeflash/discovery/discover_unit_tests.py

  • Check the file count and default to single-threaded processing.
  • Introduce a new module-level worker function, process_file_worker.
  • Refactor process_test_files to use a multiprocessing Pool.
  • Add deduplication of test function results.
  • +308/-1

version.py: Update project version metadata
codeflash/version.py

  • Bump the version string to 0.10.3.post7.dev0+86aa1cd6.
  • Update the version tuple with detailed version parts.
  • +2/-2

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

Archan Das seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you already have a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@github-actions

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 No relevant tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Duplication

The fallback to single-threaded processing when there are fewer than 25 files is implemented in two separate places. Consider refactoring this logic into a shared helper to reduce duplication and ease future maintenance.

    if len(file_to_test_map) < 25:  # default to single-threaded if there aren't that many files
        return process_test_files_single_threaded(file_to_test_map, cfg)
    return process_test_files(file_to_test_map, cfg)
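
A minimal sketch of such a shared helper (the name dispatch_test_file_processing and the threshold parameter are illustrative, not part of this PR):

    def dispatch_test_file_processing(file_to_test_map, cfg, parallel_threshold=25):
        """Route discovery to the single-threaded or multiprocessing path based on file count."""
        if len(file_to_test_map) < parallel_threshold:
            return process_test_files_single_threaded(file_to_test_map, cfg)
        return process_test_files(file_to_test_map, cfg)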
Complexity

The new worker function for multiprocessing (process_file_worker) contains deeply nested try/except blocks and many inline imports, which could complicate error tracing and debugging. Reviewing its error handling and modularizing some parts might improve clarity and maintainability, as sketched after the excerpt below.

    # Add this worker function at the module level (outside any other function)
    def process_file_worker(args_tuple):
        """Worker function for processing a single test file in a separate process.
    
        This must be at the module level (not nested) for multiprocessing to work.
        """
        import jedi
        import re
        import os
        from collections import defaultdict
        from pathlib import Path
    
        # Unpack the arguments
        test_file, functions, config = args_tuple
    
        try:
            # Each process creates its own Jedi project
            jedi_project = jedi.Project(path=config['project_root_path'])
    
            local_results = defaultdict(list)
            tests_found_in_file = 0
    
            # Convert test_file back to Path if necessary
            test_file_path = test_file if isinstance(test_file, Path) else Path(test_file)
    
            try:
                script = jedi.Script(path=test_file, project=jedi_project)
                all_names = script.get_names(all_scopes=True, references=True)
                all_defs = script.get_names(all_scopes=True, definitions=True)
                all_names_top = script.get_names(all_scopes=True)
    
                top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
                top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
            except Exception as e:
                return {
                    'status': 'error',
                    'error_type': 'jedi_script_error',
                    'error_message': str(e),
                    'test_file': test_file,
                    'results': {}
                }
    
            test_functions = set()
    
            if config['test_framework'] == "pytest":
                for function in functions:
                    if "[" in function.test_function:
                        function_name = re.split(r"[\[\]]", function.test_function)[0]
                        parameters = re.split(r"[\[\]]", function.test_function)[1]
                        if function_name in top_level_functions:
                            test_functions.add(
                                TestFunction(function_name, function.test_class, parameters, function.test_type)
                            )
                    elif function.test_function in top_level_functions:
                        test_functions.add(
                            TestFunction(function.test_function, function.test_class, None, function.test_type)
                        )
                    elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
                        # Try to match parameterized unittest functions here
                        base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
                        if base_name in top_level_functions:
                            test_functions.add(
                                TestFunction(
                                    function_name=base_name,
                                    test_class=function.test_class,
                                    parameters=function.test_function,
                                    test_type=function.test_type,
                                )
                            )
    
            elif config['test_framework'] == "unittest":
                functions_to_search = [elem.test_function for elem in functions]
                test_suites = {elem.test_class for elem in functions}
    
                matching_names = set(test_suites) & set(top_level_classes.keys())
                for matched_name in matching_names:
                    for def_name in all_defs:
                        if (
                                def_name.type == "function"
                                and def_name.full_name is not None
                                and f".{matched_name}." in def_name.full_name
                        ):
                            for function in functions_to_search:
                                (is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
    
                                if is_parameterized and new_function == def_name.name:
                                    test_functions.add(
                                        TestFunction(
                                            function_name=def_name.name,
                                            test_class=matched_name,
                                            parameters=parameters,
                                            test_type=functions[0].test_type,
                                        )
                                    )
                                elif function == def_name.name:
                                    test_functions.add(
                                        TestFunction(
                                            function_name=def_name.name,
                                            test_class=matched_name,
                                            parameters=None,
                                            test_type=functions[0].test_type,
                                        )
                                    )
    
            test_functions_list = list(test_functions)
            test_functions_raw = [elem.function_name for elem in test_functions_list]
    
            for name in all_names:
                if name.full_name is None:
                    continue
                m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
                if not m:
                    continue
                scope = m.group(1)
                indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
                for index in indices:
                    scope_test_function = test_functions_list[index].function_name
                    scope_test_class = test_functions_list[index].test_class
                    scope_parameters = test_functions_list[index].parameters
                    test_type = test_functions_list[index].test_type
                    try:
                        definition = name.goto(follow_imports=True, follow_builtin_imports=False)
                    except Exception as e:
                        continue
                    if definition and definition[0].type == "function":
                        definition_path = str(definition[0].module_path)
                        # The definition is part of this project and not defined within the original function
                        if (
                                definition_path.startswith(config['project_root_path'] + os.sep)
                                and definition[0].module_name != name.module_name
                                and definition[0].full_name is not None
                        ):
                            if scope_parameters is not None:
                                if config['test_framework'] == "pytest":
                                    scope_test_function += "[" + scope_parameters + "]"
                                if config['test_framework'] == "unittest":
                                    scope_test_function += "_" + scope_parameters
    
                            # Get module name relative to project root
                            module_name = module_name_from_file_path(definition[0].module_path, config['project_root_path'])
    
                            full_name_without_module_prefix = definition[0].full_name.replace(
                                definition[0].module_name + ".", "", 1
                            )
                            qualified_name_with_modules_from_root = f"{module_name}.{full_name_without_module_prefix}"
    
                            # Create a serializable representation of the result
                            result_entry = {
                                'test_file': str(test_file),
                                'test_class': scope_test_class,
                                'test_function': scope_test_function,
                                'test_type': test_type,
                                'line_no': name.line,
                                'col_no': name.column
                            }
    
                            # Add to local results
                            if qualified_name_with_modules_from_root not in local_results:
                                local_results[qualified_name_with_modules_from_root] = []
                            local_results[qualified_name_with_modules_from_root].append(result_entry)
                            tests_found_in_file += 1
    
            return {
                'status': 'success',
                'test_file': test_file,
                'tests_found': tests_found_in_file,
                'results': dict(local_results)  # Convert defaultdict to dict for serialization
            }
    
        except Exception as e:
            import traceback
            return {
                'status': 'error',
                'error_type': 'general_error',
                'error_message': str(e),
                'traceback': traceback.format_exc(),
                'test_file': test_file,
                'results': {}
            }
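
For instance, the framework-specific matching could be pulled into small helpers so the worker itself stays flat (a sketch; the helper names are hypothetical, not part of this PR):

    def _match_pytest_functions(functions, top_level_functions):
        """pytest branch: return the set of TestFunction entries matched in this file."""
        ...

    def _match_unittest_functions(functions, top_level_classes, all_defs):
        """unittest branch: return the set of TestFunction entries matched in this file."""
        ...

    def process_file_worker(args_tuple):
        test_file, functions, config = args_tuple
        # ... Jedi setup as in the excerpt above ...
        if config["test_framework"] == "pytest":
            test_functions = _match_pytest_functions(functions, top_level_functions)
        else:
            test_functions = _match_unittest_functions(functions, top_level_classes, all_defs)
        # ... reference resolution as in the excerpt above ...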

@github-actions

PR Code Suggestions ✨

No code suggestions found for the PR.

@misrasaurabh1 (Contributor)

I believe making Jedi parallel can involve some potential hiccups; see whether they are a real problem or fine in practice: https://jedi.readthedocs.io/en/stable/docs/development.html#module-jedi.cache

    process_inputs.append((str(test_file), serializable_functions, config_dict))

    # Determine optimal number of processes
    max_processes = min(cpu_count() * 2, len(process_inputs), 32)
Contributor: Why cpu_count() * 2? You don't want to create more processes than the number of CPUs.

Contributor Author: I was under the impression this is largely a disk-I/O-bound task, so multiple processes per core would help.
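
For reference, capping at the core count for the CPU-bound case the reviewer describes would look like this (a sketch, not this PR's code):

    from multiprocessing import Pool, cpu_count

    # For CPU-bound work, one process per core is the usual ceiling;
    # oversubscribing mainly pays off when workers block on I/O.
    max_processes = min(cpu_count(), len(process_inputs), 32)
    with Pool(processes=max_processes) as pool:
        results = pool.map(process_file_worker, process_inputs)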

    # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
    __version__ = "0.10.3"
    __version_tuple__ = (0, 10, 3)
    __version__ = "0.10.3.post7.dev0+86aa1cd6"
Contributor: Don't change this version.

    processed_files += 1

    # Log progress
    if processed_files % 100 == 0 or processed_files == len(process_inputs):
Contributor: Maybe we can do something with a progress bar here.
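
One option is wrapping the Pool's imap_unordered in tqdm (a sketch; it assumes adding tqdm as a dependency):

    from multiprocessing import Pool
    from tqdm import tqdm

    with Pool(processes=max_processes) as pool:
        # imap_unordered yields results as workers finish, so the bar advances smoothly
        results = list(
            tqdm(
                pool.imap_unordered(process_file_worker, process_inputs),
                total=len(process_inputs),
                desc="Processing test files",
            )
        )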


    position = CodePosition(line_no=entry['line_no'], col_no=entry['col_no'])

    function_to_test_map[qualified_name].append(
Contributor: In the initial code, function_to_test_map is a dictionary of lists, but I don't see why we can't make it a dictionary of sets so entries are immediately deduped.
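
i.e., something like this (sketch; it relies on the test entries being hashable, see the next comment):

    from collections import defaultdict

    function_to_test_map = defaultdict(set)
    # ...
    function_to_test_map[qualified_name].add(test)  # duplicates are dropped on insert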


    for test in tests:
        # Create a hashable representation of the test
        test_hash = (
Contributor: Unnecessary; these should be hashable already, since the pydantic models are frozen.
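
For illustration, a frozen pydantic v2 model is hashable out of the box (a generic example, not this repo's model):

    from pydantic import BaseModel, ConfigDict

    class Example(BaseModel):
        model_config = ConfigDict(frozen=True)
        name: str

    # Equal field values hash equally, so sets dedupe them directly
    assert len({Example(name="a"), Example(name="a")}) == 1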


    try:
        script = jedi.Script(path=test_file, project=jedi_project)
        all_names = script.get_names(all_scopes=True, references=True)
Contributor: If we're going for performance here, I wonder if we can do some caching. For example, if two tests use the same function (in this case it'll be a Jedi name in all_names), we shouldn't have to process it twice.
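
A cheap version of that caching would memoize the resolution per fully qualified name (a sketch; resolve_definition is a hypothetical helper):

    definition_cache = {}

    def resolve_definition(name):
        """Resolve a Jedi Name once per full_name and reuse the result."""
        key = name.full_name
        if key not in definition_cache:
            try:
                definition_cache[key] = name.goto(follow_imports=True, follow_builtin_imports=False)
            except Exception:
                definition_cache[key] = None
        return definition_cache[key]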

        continue
    scope = m.group(1)
    indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
    for index in indices:
Contributor: This seems inefficient, calling name.goto once per matching index. Am I missing something here? @misrasaurabh1
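
Indeed, name.goto does not depend on index, so it could be hoisted above the loop (sketch):

    try:
        definition = name.goto(follow_imports=True, follow_builtin_imports=False)
    except Exception:
        definition = None
    if definition and definition[0].type == "function":
        for index in indices:
            # per-index bookkeeping only; no repeated goto calls
            ...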

@alvin-r (Contributor) commented Mar 26, 2025

Revisit this after #72 is done.

@alvin-r closed this Mar 26, 2025