From 2a920fb7c9e98eb2243edade3d5dd7eeb3ae9d99 Mon Sep 17 00:00:00 2001 From: Steboss Date: Wed, 2 Jul 2025 15:50:46 +0100 Subject: [PATCH 01/50] update the triage tool to deal with non-linear history --- .github/triage/jax_toolbox_triage/args.py | 36 ++++++++--- .github/triage/jax_toolbox_triage/main.py | 75 ++++++++++++++++++++--- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 3080e66be..34d570fb0 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -225,10 +225,30 @@ def parse_args(args=None) -> argparse.Namespace: help="Container runtime used, can be docker, pyxis, or local.", type=lambda s: s.lower(), ) - args = parser.parse_args(args=args) - assert args.container_runtime in {"docker", "pyxis", "local"}, ( - args.container_runtime + parser.add_argument( + "--main-branch", + type=str, + default="main", + help="The name of the main branch, linear branch to be used for bisection", ) + parser.add_argument( + "--feature-branch-name", + type=str, + default=None, + help="The name of the feature branch (e.g. blackwell-devel) to derive cherry-picks from", + ) + parser.add_argument( + "--cherry-pick-commits", + type=list, + default=None, + help="List of commits to cherry-pick from the feature branch to the main branch", + ) + args = parser.parse_args(args=args) + assert args.container_runtime in { + "docker", + "pyxis", + "local", + }, args.container_runtime # --{passing,failing}-commits are deprecated aliases for --{passing,failing}-versions. for prefix in ["passing", "failing"]: commits = getattr(args, f"{prefix}_commits") @@ -250,9 +270,7 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), ( - "For local runtime, --passing-versions and --failing-versions must be provided." - ) + ), "For local runtime, --passing-versions and --failing-versions must be provided." assert ( args.container is None and args.start_date is None @@ -297,7 +315,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert args.container is not None, ( - "--container must be passed for the container-level search" - ) + assert ( + args.container is not None + ), "--container must be passed for the container-level search" return args diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0bad0a00e..0fbff41a1 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -135,9 +135,7 @@ def test_output_directory( ) ) out_dir = args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {args.output_prefix}?" - ) + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -345,7 +343,9 @@ def check_container( assert bisection_versions is not None # Get the full lists of JAX/XLA commits and dates - def get_commit_history(worker, start, end, dir): + def get_commit_history( + worker, start, end, dir, main_branch=None, feature_branch_name=None + ): # In particular the end commit might not already be known if the older, # passing, container is being used for triage. commits_known = worker.exec( @@ -361,6 +361,41 @@ def get_commit_history(worker, start, end, dir): worker.check_exec( ["git", "fetch"], policy="once_per_container", workdir=dir ) + + # here we're considering the case of non-linear history + if feature_branch_name: + logger.info( + f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" + ) + + # 1. find the linear range on the main branch + passing_main_commit_cmd = f"git merge-base {start} {end}" + failing_main_commit_cmd = f"git merge-base {end} origin/{args.main_branch}" + + passing_main_commit = worker.check_exec( + ["sh ", "-c", passing_main_commit_cmd], workdir=dir + ).stdout.strip() + failing_main_commit = worker.check_exec( + ["sh", "-c", failing_main_commit_cmd], workdir=dir + ).stdout.strip() + + # 2. find commits to cherry-pick from the failing branch + cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" + cherry_pick_commits_str = worker.check_exec( + ["sh", "-c", cherry_pick_cmd], workdir=dir + ).stdout.strip() + cherry_pick_commits = cherry_pick_commits_str.splitlines() + # log for testing + logger.info(f"Cherry-pick commits: {cherry_pick_commits}") + + # 3. now we can use the main branch commits for bisection + start = passing_main_commit + end = failing_main_commit + # and store the cherry picks + args.cherry_pick_commits = cherry_pick_commits + else: + args.cherry_pick_commits = [] + result = worker.check_exec( [ "git", @@ -407,6 +442,8 @@ def get_commit_history(worker, start, end, dir): passing_versions[package], failing_versions[package], package_dirs[package], + args.main_branch, + args.feature_branch_name, ) # Confirm they're sorted by commit date assert all( @@ -487,12 +524,30 @@ def build_and_test( bisection_versions[package] = version changed.append(f"{package}@{version}") if package in package_dirs: - # A git repository that exists in the container. - git_commands += [ - f"cd {package_dirs[package]}", - "git stash", - f"git checkout {version}", - ] + # in case of non-linear history - should we limit this to XLA and JAX only? + if args.feature_branch_name and package in ["jax", "xla"]: + logger.info("Working on a non-linear history") + git_commands.append(f"cd {package_dirs[package]}") + git_commands.append("git stash") + # this is a checkout on the main branch + git_commands.append(f"git checkout {version}") + + # cherry-picking + if args.cherry_pick_commits: + cherry_pick_str = " ".join(args.cherry_pick_commits) + git_commands.append( + f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" + ) + + else: + # Linear history + # A git repository that exists in the container. + git_commands += [ + f"cd {package_dirs[package]}", + "git stash", + f"git checkout {version}", + ] + else: # Another software package, `version` is probably a version number. # Installation of this version is delegated to an installPACKAGE.sh From a17918266b4835a1e21e34981184752a0b8d9866 Mon Sep 17 00:00:00 2001 From: Steboss Date: Wed, 2 Jul 2025 15:57:25 +0100 Subject: [PATCH 02/50] update changes --- .github/triage/jax_toolbox_triage/main.py | 39 +++++++++++------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0fbff41a1..7aec362bd 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -344,7 +344,7 @@ def check_container( # Get the full lists of JAX/XLA commits and dates def get_commit_history( - worker, start, end, dir, main_branch=None, feature_branch_name=None + worker, package, start, end, dir, main_branch=None, feature_branch_name=None ): # In particular the end commit might not already be known if the older, # passing, container is being used for triage. @@ -363,7 +363,8 @@ def get_commit_history( ) # here we're considering the case of non-linear history - if feature_branch_name: + # limit for the moment to JAX and XLA + if feature_branch_name and package in ["jax", "xla"]: logger.info( f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" ) @@ -381,20 +382,18 @@ def get_commit_history( # 2. find commits to cherry-pick from the failing branch cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" - cherry_pick_commits_str = worker.check_exec( - ["sh", "-c", cherry_pick_cmd], workdir=dir - ).stdout.strip() - cherry_pick_commits = cherry_pick_commits_str.splitlines() - # log for testing - logger.info(f"Cherry-pick commits: {cherry_pick_commits}") + cherry_pick_commits_list = ( + worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) + .stdout.strip() + .splitlines() + ) + if cherry_pick_commits_list: + args.cherry_pick_commits[package] = cherry_pick_commits_list + logger.info(f"Cherry-pick commits: {cherry_pick_commits_list}") # 3. now we can use the main branch commits for bisection start = passing_main_commit end = failing_main_commit - # and store the cherry picks - args.cherry_pick_commits = cherry_pick_commits - else: - args.cherry_pick_commits = [] result = worker.check_exec( [ @@ -439,6 +438,7 @@ def get_commit_history( continue package_versions[package] = get_commit_history( worker, + package, passing_versions[package], failing_versions[package], package_dirs[package], @@ -525,20 +525,17 @@ def build_and_test( changed.append(f"{package}@{version}") if package in package_dirs: # in case of non-linear history - should we limit this to XLA and JAX only? - if args.feature_branch_name and package in ["jax", "xla"]: + package_cherry_picks = args.cherry_pick_commits.get(package, []) + if package_cherry_picks: logger.info("Working on a non-linear history") git_commands.append(f"cd {package_dirs[package]}") git_commands.append("git stash") # this is a checkout on the main branch git_commands.append(f"git checkout {version}") - - # cherry-picking - if args.cherry_pick_commits: - cherry_pick_str = " ".join(args.cherry_pick_commits) - git_commands.append( - f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" - ) - + cherry_pick_str = " ".join(package_cherry_picks) + git_commands.append( + f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" + ) else: # Linear history # A git repository that exists in the container. From e848e29d08ed8bfa0874a32480980e18f8cc56a9 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 3 Jul 2025 10:57:27 +0100 Subject: [PATCH 03/50] fix args --- .github/triage/jax_toolbox_triage/args.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 34d570fb0..eba988361 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -237,12 +237,6 @@ def parse_args(args=None) -> argparse.Namespace: default=None, help="The name of the feature branch (e.g. blackwell-devel) to derive cherry-picks from", ) - parser.add_argument( - "--cherry-pick-commits", - type=list, - default=None, - help="List of commits to cherry-pick from the feature branch to the main branch", - ) args = parser.parse_args(args=args) assert args.container_runtime in { "docker", From cf1345ceb0252c54a3ce71b19e7d828064cc6947 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 4 Jul 2025 12:38:18 +0100 Subject: [PATCH 04/50] start drafting the test for the non-linear bisection --- .../triage/tests/mock_scripts/build-jax.sh | 9 + .../triage/tests/mock_scripts/test-case.sh | 20 ++ .../tests/test_triage_history_bisection.py | 331 ++++++++++++++++++ 3 files changed, 360 insertions(+) create mode 100755 .github/triage/tests/mock_scripts/build-jax.sh create mode 100755 .github/triage/tests/mock_scripts/test-case.sh create mode 100644 .github/triage/tests/test_triage_history_bisection.py diff --git a/.github/triage/tests/mock_scripts/build-jax.sh b/.github/triage/tests/mock_scripts/build-jax.sh new file mode 100755 index 000000000..5d19da55f --- /dev/null +++ b/.github/triage/tests/mock_scripts/build-jax.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +if [ ! -f "feature_file.txt" ]; then + echo "Build FAILED: The feature commit was not applied (feature_file.txt is missing)." + exit 1 +fi + +echo "Mock build script: Build successful (feature commit is present)." +exit 0 diff --git a/.github/triage/tests/mock_scripts/test-case.sh b/.github/triage/tests/mock_scripts/test-case.sh new file mode 100755 index 000000000..355cf71dd --- /dev/null +++ b/.github/triage/tests/mock_scripts/test-case.sh @@ -0,0 +1,20 @@ +#!/bin/bash + + +REPO_PATH=$1 +BAD_COMMIT=$2 + +if [ -z "$REPO_PATH" ] || [ -z "$BAD_COMMIT" ]; then + echo "Usage: $0 " + exit 1 +fi + +cd ${REPO_PATH} + +if git merge-base --is-ancestor ${BAD_COMMIT} HEAD; then + echo "The commit ${BAD_COMMIT} is an ancestor of the current HEAD." + exit 1 +else + echo "The commit ${BAD_COMMIT} is NOT an ancestor of the current HEAD." + exit 0 +fi diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py new file mode 100644 index 000000000..50780e8b1 --- /dev/null +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -0,0 +1,331 @@ +import subprocess +import tempfile +import pathlib +import os +import logging +from collections import OrderedDict +import pytest +import datetime + +# for the moment avoid using this, because we can't import it +# then we'll refactor the main code +# from jax_toolbox_triage.main import get_commit_history +from jax_toolbox_triage.logic import version_search, TestResult +from jax_toolbox_triage.container import Container + + +def run_command(command, cwd=None, env=None): + """Simple function to run a command in a subprocess. + + Args: + command (list): The command to run as a list of strings. + cwd (str, optional): The working directory to run the command in. + env (dict, optional): Environment variables to set for the command. + Returns: + str: The standard output of the command. + """ + try: + result = subprocess.run( + command, cwd=cwd, env=env, check=True, capture_output=True, text=True + ) + return result.stdout.strip() + except subprocess.CalledProcessError as e: + logging.error(f"Command '{' '.join(command)}' failed with error: {e}") + raise e + + +class MockContainer(Container): + """A mock container class for testing purposes.""" + + def __init__(self, mock_scripts_path, logger): + super().__init__(logger=logger) + self.mock_scripts_path = mock_scripts_path + self._env = os.environ.copy() + self._env["PATH"] = f"{self.mock_scripts_path}:{self._env['PATH']}" + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + def __repr__(self): + return "MockContainer" + + def check_exec(self, cmd, **kwargs): + """Override the check_exec""" + return super().check_exec(cmd, **kwargs) + + def exec( + self, + command, + *, + policy="default", + stderr="interleave", + workdir=None, + log_level=logging.DEBUG, + ): + self._logger.debug(f"Executing command: {command} in {workdir}") + is_shell_command = command[0] == "sh" and command[1] == "-c" + cmd_to_run = command[2] if is_shell_command else command + try: + return subprocess.run( + cmd_to_run, + capture_output=True, + text=True, + cwd=workdir, + env=self._env, + shell=is_shell_command, + ) + except FileNotFoundError as e: + return subprocess.CompletedProcess(command, 127, stderr=str(e)) + + def exists(self) -> bool: + return True + + +def get_commit_history( + worker, package, start, end, dir, main_branch, feature_branch_name, args, logger +): + """ + This is a local copy of the get_commit_history logic from main.py, + For the moment we don't want to import it, we'll then refactor the main code + """ + # In particular the end commit might not already be known if the older, + # passing, container is being used for triage. + commits_known = worker.exec( + [ + "sh", + "-c", + f"git cat-file commit {start} && git cat-file commit {end}", + ], + policy="once_per_container", + workdir=dir, + ) + if commits_known.returncode != 0: + worker.check_exec(["git", "fetch"], policy="once_per_container", workdir=dir) + + if feature_branch_name and package in ["jax", "xla"]: + logger.info( + f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" + ) + passing_main_commit_cmd = f"git merge-base {start} {end}" + failing_main_commit_cmd = f"git merge-base {end} origin/{main_branch}" + + # In a local test, origin doesn't exist, so we use the local main branch ref. + failing_main_commit_cmd = f"git merge-base {end} {main_branch}" + + passing_main_commit = worker.check_exec( + ["sh", "-c", passing_main_commit_cmd], workdir=dir + ).stdout.strip() + failing_main_commit = worker.check_exec( + ["sh", "-c", failing_main_commit_cmd], workdir=dir + ).stdout.strip() + + cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" + cherry_pick_commits_list = ( + worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) + .stdout.strip() + .splitlines() + ) + if cherry_pick_commits_list: + args.cherry_pick_commits[package] = cherry_pick_commits_list + + start = passing_main_commit + end = failing_main_commit + + result = worker.check_exec( + [ + "git", + "log", + "--first-parent", + "--reverse", + "--format=%H %cI", + f"{start}^..{end}", + ], + policy="once", + stderr=subprocess.PIPE, + workdir=dir, + ) + data = [] + for line in result.stdout.splitlines(): + commit, date_str = line.split() + date = datetime.datetime.fromisoformat(date_str).astimezone( + datetime.timezone.utc + ) + data.append((commit, date)) + return data + + +@pytest.fixture +def triage_test_env(): + """ + Fixture to set up the test environment for triage tests. + + The fixture creates a temp directory and a git repo with a + defined linear and non-linear history. + + The fixture yields a dictionary of paths and commit hashes + """ + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + repo_path = temp_path / "repos" + output_path = temp_path / "output" + mock_scripts_path = temp_path / "mock_scripts" + repo_path.mkdir() + output_path.mkdir() + mock_scripts_path.mkdir() + + # Generation of mock scripts + # build-jax.sh + source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" + build_script_content = (source_scripts_dir / "build-jax.sh").read_text() + (mock_scripts_path / "build-jax.sh").write_text(build_script_content) + os.chmod(mock_scripts_path / "build-jax.sh", 0o755) + # test-case.sh helper test script + test_case_content = (source_scripts_dir / "test-case.sh").read_text() + (mock_scripts_path / "test-case.sh").write_text(test_case_content) + os.chmod(mock_scripts_path / "test-case.sh", 0o755) + + # Create a git repository + jax_repo_path = repo_path / "jax" + jax_repo_path.mkdir() + + def git_cmd(command, *args): + return ( + run_command(["git", command, *args], cwd=jax_repo_path).stdout().strip() + ) + + # main + git_cmd("init", "-b", "main") + git_cmd("config", "user.name", "Test User") + git_cmd("config", "user.email", "test@user.it") + # Create a linear commit history + m1 = git_cmd("commit", "--allow-empty", "-m", "M1") + m2 = git_cmd("commit", "--allow-empty", "-m", "M2") # good commit + m3 = git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit + # create a feature branch + git_cmd("checkout", "-b", "feature", m1) + (jax_repo_path / "feature_file.txt").write_text("feature") + git_cmd("add", "feature_file.txt") + f1 = git_cmd("commit", "-m", "F1") + + git_cmd("checkout", "-b", "passing_nonlinear", m2) + git_cmd("cherry-pick", f1) + passing_nonlinear = git_cmd("rev-parse", "HEAD") + git_cmd("checkout", "-b", "failing_nonlinear", m3) + git_cmd("cherry-pick", f1) + failing_nonlinear = git_cmd("rev-parse", "HEAD") + git_cmd("checkout", "main") + + # yield all the info + yield { + "paths": { + "repo": repo_path, + "output": output_path, + "scripts": mock_scripts_path, + }, + "commits": { + "good_main": m2, + "bad_main": m3, + "feature": f1, + "passing_nonlinear": passing_nonlinear, + "failing_nonlinear": failing_nonlinear, + }, + } + + +# Do we need to parametrize the test cases? +@pytest.mark.parametrize( + "scenario, passing_commit_key, failing_commit_key, use_nonlinear_flags, expected_good_key, expected_bad_key", + [ + ( + "Non-Linear History", # scenario + "passing_nonlinear", + "failing_nonlinear", # bisection range + True, # use the new flag + "good_main", + "bad_main", # expected results + ), + ("Linear History", "good_main", "bad_main", False, "good_main", "bad_main"), + ], +) +def test_traige_scenarios( + triage_test_env, + scenario, + passing_commit_key, + failing_commit_key, + use_nonlinear_flags, + expected_good_key, + expected_bad_key, +): + """Check if we nee dot restructure this + add types""" + paths = triage_test_env["paths"] + all_commits = triage_test_env["commits"] + + class MockArgs: + main_branch = "main" + feature_branch_name = "feature" if use_nonlinear_flags else None + bazel_cache = "" + build_scripts_path = None + test_command = ["test-case.sh", str(paths["repo"]), all_commits["bad_main"]] + cherry_pick_commits = {} + + args = MockArgs() + logger = logging.getLogger(f"Scenario-{scenario}") + logging.basicConfig(level=logging.INFO) + + passing_versions = {"jax": all_commits[passing_commit_key]} + failing_versions = {"jax": all_commits[failing_commit_key]} + package_dirs = {"jax": str(paths["repo"])} + mock_container = MockContainer(paths["scripts"], logger) + # call the get_commit_history + package_versions = OrderedDict() + package_versions["jax"] = get_commit_history( + worker=mock_container, + package="jax", + start=passing_versions["jax"], + end=failing_versions["jax"], + dir=package_dirs["jax"], + main_branch=args.main_branch, + feature_branch_name=args.feature_branch_name, + args=args, + logger=logger, + ) + + # build and test + def build_and_test_wrapper(*, versions, test_output_log_level=logging.DEBUG): + git_commands = [f"cd {package_dirs['jax']}", "git stash --include-untracked"] + if use_nonlinear_flags: + build_script = paths["scripts"] / "build-jax.sh" + git_commands.append(f"git checkout {versions['jax']}") + cherry_picks = args.cherry_pick_commits.get("jax", []) + if cherry_picks: + git_commands.append(f"git cherry-pick { ' '.join(cherry_picks)}") + else: + build_script = paths["scripts"] / "build-jax-linear.sh" + build_script.write_text("#!/bin/sh\nexit 0") + os.chmod(build_script, 0o755) + git_commands.append(f"git checkout {versions['jax']}") + + mock_container.check_exec(["sh", "-c", " && ".join(git_commands)]) + mock_container.check_exec([str(build_script)]) + result = mock_container.exec(args.test_command, workdir=package_dirs["jax"]) + return TestResult( + host_output_directory=paths["output"], + result=result.returncode == 0, + stdouterr=" ", + ) + + # bisection + result, _, _ = version_search( + versions=package_versions, + build_and_test=build_and_test_wrapper, + logger=logger, + skip_precondition_checks=False, + ) + + # test + assert result.get("jax_good") == all_commits[expected_good_key] + assert result.get("jax_bad") == all_commits[expected_bad_key] From 259b4acc24032fe53f875e3009306a8cb7181b0c Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 4 Jul 2025 14:40:46 +0100 Subject: [PATCH 05/50] fix the test and prepare for refactoring the main --- .../tests/test_triage_history_bisection.py | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 50780e8b1..49480e7af 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -61,7 +61,7 @@ def exec( command, *, policy="default", - stderr="interleave", + stderr="interleaved", workdir=None, log_level=logging.DEBUG, ): @@ -102,8 +102,10 @@ def get_commit_history( policy="once_per_container", workdir=dir, ) + if commits_known.returncode != 0: - worker.check_exec(["git", "fetch"], policy="once_per_container", workdir=dir) + logger.error("ERROR!") + logger.error(f"{commits_known.stderr}") if feature_branch_name and package in ["jax", "xla"]: logger.info( @@ -193,23 +195,25 @@ def triage_test_env(): jax_repo_path.mkdir() def git_cmd(command, *args): - return ( - run_command(["git", command, *args], cwd=jax_repo_path).stdout().strip() - ) + return run_command(["git", command, *args], cwd=jax_repo_path) # main git_cmd("init", "-b", "main") git_cmd("config", "user.name", "Test User") git_cmd("config", "user.email", "test@user.it") # Create a linear commit history - m1 = git_cmd("commit", "--allow-empty", "-m", "M1") - m2 = git_cmd("commit", "--allow-empty", "-m", "M2") # good commit - m3 = git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit + git_cmd("commit", "--allow-empty", "-m", "M1") + m1 = git_cmd("rev-parse", "HEAD") + git_cmd("commit", "--allow-empty", "-m", "M2") # good commit + m2 = git_cmd("rev-parse", "HEAD") + git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit + m3 = git_cmd("rev-parse", "HEAD") # create a feature branch git_cmd("checkout", "-b", "feature", m1) (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") - f1 = git_cmd("commit", "-m", "F1") + git_cmd("commit", "-m", "F1") + f1 = git_cmd("rev-parse", "HEAD") git_cmd("checkout", "-b", "passing_nonlinear", m2) git_cmd("cherry-pick", f1) @@ -251,7 +255,7 @@ def git_cmd(command, *args): ("Linear History", "good_main", "bad_main", False, "good_main", "bad_main"), ], ) -def test_traige_scenarios( +def test_triage_scenarios( triage_test_env, scenario, passing_commit_key, @@ -263,13 +267,14 @@ def test_traige_scenarios( """Check if we nee dot restructure this + add types""" paths = triage_test_env["paths"] all_commits = triage_test_env["commits"] + jax_repo_path = paths["repo"] / "jax" class MockArgs: main_branch = "main" feature_branch_name = "feature" if use_nonlinear_flags else None bazel_cache = "" build_scripts_path = None - test_command = ["test-case.sh", str(paths["repo"]), all_commits["bad_main"]] + test_command = ["test-case.sh", str(jax_repo_path), all_commits["bad_main"]] cherry_pick_commits = {} args = MockArgs() @@ -278,7 +283,7 @@ class MockArgs: passing_versions = {"jax": all_commits[passing_commit_key]} failing_versions = {"jax": all_commits[failing_commit_key]} - package_dirs = {"jax": str(paths["repo"])} + package_dirs = {"jax": str(jax_repo_path)} mock_container = MockContainer(paths["scripts"], logger) # call the get_commit_history package_versions = OrderedDict() @@ -296,22 +301,31 @@ class MockArgs: # build and test def build_and_test_wrapper(*, versions, test_output_log_level=logging.DEBUG): - git_commands = [f"cd {package_dirs['jax']}", "git stash --include-untracked"] + workdir = package_dirs["jax"] + mock_container.check_exec( + ["git", "stash", "--include-untracked"], workdir=workdir + ) + if use_nonlinear_flags: build_script = paths["scripts"] / "build-jax.sh" - git_commands.append(f"git checkout {versions['jax']}") + mock_container.check_exec( + ["git", "checkout", versions["jax"]], workdir=workdir + ) cherry_picks = args.cherry_pick_commits.get("jax", []) if cherry_picks: - git_commands.append(f"git cherry-pick { ' '.join(cherry_picks)}") + mock_container.check_exec( + ["git", "cherry-pick"] + cherry_picks, workdir=workdir + ) else: build_script = paths["scripts"] / "build-jax-linear.sh" build_script.write_text("#!/bin/sh\nexit 0") os.chmod(build_script, 0o755) - git_commands.append(f"git checkout {versions['jax']}") + mock_container.check_exec( + ["git", "checkout", versions["jax"]], workdir=workdir + ) - mock_container.check_exec(["sh", "-c", " && ".join(git_commands)]) - mock_container.check_exec([str(build_script)]) - result = mock_container.exec(args.test_command, workdir=package_dirs["jax"]) + mock_container.check_exec([str(build_script)], workdir=workdir) + result = mock_container.exec(args.test_command, workdir=workdir) return TestResult( host_output_directory=paths["output"], result=result.returncode == 0, From c6554e4f1f4097beaece7bbceecf5fd54748d630 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 4 Jul 2025 17:04:30 +0100 Subject: [PATCH 06/50] refactor main script in multiple ones and start testing --- .github/triage/jax_toolbox_triage/bisect.py | 82 +++ .../jax_toolbox_triage/container_factory.py | 28 + .github/triage/jax_toolbox_triage/main.py | 645 +----------------- .github/triage/jax_toolbox_triage/summary.py | 78 +++ .../triage/jax_toolbox_triage/triage_tool.py | 613 +++++++++++++++++ .github/triage/jax_toolbox_triage/versions.py | 89 +++ 6 files changed, 904 insertions(+), 631 deletions(-) create mode 100644 .github/triage/jax_toolbox_triage/bisect.py create mode 100644 .github/triage/jax_toolbox_triage/container_factory.py create mode 100644 .github/triage/jax_toolbox_triage/summary.py create mode 100644 .github/triage/jax_toolbox_triage/triage_tool.py create mode 100644 .github/triage/jax_toolbox_triage/versions.py diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py new file mode 100644 index 000000000..8ca2b02de --- /dev/null +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -0,0 +1,82 @@ +import datetime +import subprocess + + +def get_commit_history( + worker, + package, + start, + end, + dir, + main_branch=None, + feature_branch_name=None, + logger=None, + args=None, +): + # In particular the end commit might not already be known if the older, + # passing, container is being used for triage. + commits_known = worker.exec( + [ + "sh", + "-c", + f"git cat-file commit {start} && git cat-file commit {end}", + ], + policy="once_per_container", + workdir=dir, + ) + if commits_known.returncode != 0: + worker.check_exec(["git", "fetch"], policy="once_per_container", workdir=dir) + + # here we're considering the case of non-linear history + # limit for the moment to JAX and XLA + if feature_branch_name and package in ["jax", "xla"]: + logger.info( + f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" + ) + + # 1. find the linear range on the main branch + passing_main_commit_cmd = f"git merge-base {start} {end}" + failing_main_commit_cmd = f"git merge-base {end} origin/{main_branch}" + + passing_main_commit = worker.check_exec( + ["sh ", "-c", passing_main_commit_cmd], workdir=dir + ).stdout.strip() + failing_main_commit = worker.check_exec( + ["sh", "-c", failing_main_commit_cmd], workdir=dir + ).stdout.strip() + + # 2. find commits to cherry-pick from the failing branch + cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" + cherry_pick_commits_list = ( + worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) + .stdout.strip() + .splitlines() + ) + if cherry_pick_commits_list: + args.cherry_pick_commits[package] = cherry_pick_commits_list + logger.info(f"Cherry-pick commits: {cherry_pick_commits_list}") + + # 3. now we can use the main branch commits for bisection + start = passing_main_commit + end = failing_main_commit + + result = worker.check_exec( + [ + "git", + "log", + "--first-parent", + "--reverse", + "--format=%H %cI", + f"{start}^..{end}", + ], + policy="once", + stderr=subprocess.PIPE, + workdir=dir, + ) + logger.debug(f"stderr: {result.stderr.strip()}") + data = [] + for line in result.stdout.splitlines(): + commit, date = line.split() + date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc) + data.append((commit, date)) + return data diff --git a/.github/triage/jax_toolbox_triage/container_factory.py b/.github/triage/jax_toolbox_triage/container_factory.py new file mode 100644 index 000000000..1a2febac6 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/container_factory.py @@ -0,0 +1,28 @@ +import logging +from .container import Container +from .docker import DockerContainer +from .pyxis import PyxisContainer +from .local import LocalContainer + + +def make_container( + runtime: str, url: str, mounts: list, logger: logging.Logger, **kwargs +) -> Container: + """ + This function craetes a container objects, based on the specified runtime + + Args: + runtime (str): The container runtime to use (e.g., 'docker', 'pyxis', 'local'). + url (str): The URL of the container. + mounts (list): List of mounts to be used in the container. + logger (logging.Logger): Logger instance for logging messages. + **kwargs: Additional keyword arguments for specific container types. + + Returns: + Container: A container class associated with the specified runtime. + """ + if runtime == "local": + return LocalContainer(logger=logger) + + container_impl = DockerContainer if runtime == "docker" else PyxisContainer + return container_impl(url, logger=logger, mounts=mounts, **kwargs) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 7aec362bd..222869d6a 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -1,643 +1,26 @@ -import collections -import datetime -import functools -import hashlib -import itertools -import json -import logging -import pathlib -import subprocess -import time -import typing +from .args import parse_args +from .utils import get_logger +from .triage_tool import TriageTool -from .args import compulsory_software, optional_software, parse_args -from .container import Container -from .docker import DockerContainer -from .local import LocalContainer -from .logic import container_search, TestResult, version_search -from .pyxis import PyxisContainer -from .utils import ( - container_url as container_url_base, - get_logger, - prepare_bazel_cache_mounts, -) - -def get_env(worker: Container) -> typing.Dict[str, str]: - """ - Get the runtime environment in the given container. - - Returns: {env_var: value} dictionary, sorted by key. - """ - - def impl() -> typing.Dict[str, str]: - kvs = ( - worker.check_exec(["env", "-0"], policy="once", stderr="separate") - .stdout[:-1] # skip the trailing \0 - .split("\0") - ) - return dict(kv.split("=", 1) for kv in kvs) - - # Remove any environment variables that differ between consecutive `env` calls, for - # example some step-specific Slurm variables. - env1, env2 = impl(), impl() - # sorted(...) for run-to-run determinism - return {k: env1[k] for k in sorted(env1.keys() & env2.keys()) if env1[k] == env2[k]} - - -def get_commits_and_dirs( - worker: Container, -) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]: +def main() -> None: """ - Get the git repository paths and current HEAD commits in the given environment of - the software packages named in `compulsory_software` and `optional_software`. - - Returns: ({package: commit}, {package: directory}) + Main entry point for the triage tool. """ - # Formulated this way to avoid paying too many times for container startup. - cmds = [] - for package in compulsory_software + optional_software: - bits = [ - f"(cd /opt/{package} && git rev-parse HEAD && echo {package} && pwd)", - f"(cd /opt/{package}-source && git rev-parse HEAD && echo {package} && pwd)", - ] - if package in optional_software: - bits.append("true") - cmds.append(f"({' || '.join(bits)})") - result = worker.check_exec( - ["sh", "-c", " && ".join(cmds)], policy="once", stderr="separate" - ) - versions, dirs = {}, {} - # Look over triplets of output lines - for commit, package, dirname in zip(*([iter(result.stdout.splitlines())] * 3)): - dirs[package] = dirname - versions[package] = commit - return versions, dirs - - -def get_versions_dirs_env( - worker: Container, - versions_from_env: bool, -) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str], typing.Dict[str, str]]: - """ - Get software versions in the given [container] environment, git repository paths - where relevant, and the runtime environment. - - The list of software versions is drawn from git repositories at known container - locations and, if `versions_from_env` is True, from the environment. - - Returns: - versions: {package: version or commit}, - dirs: {package: git_repository_dir} - env: {env_var: value} - """ - # Get the git repository paths and commits from the container. - versions, dirs = get_commits_and_dirs(worker) - - # Get the environment variables from the container. - env = get_env(worker) - - if versions_from_env: - # Promote any XXX_VERSION environment variables into `versions` if `XXX` is - # not already there. - for k, v in env.items(): - if not len(v) or not k.endswith("_VERSION"): - continue - package = k[:-8] - assert package not in versions, (versions, package) - versions[package] = v - return versions, dirs, env - - -def main() -> None: args = parse_args() - bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache) logger = get_logger(args.output_prefix) - logger.info("Arguments:") - for k, v in vars(args).items(): - logger.info(f" {k}: {v}") - logger.info( - "Verbose output, including stdout/err of triage commands, will be written to " - f"{(args.output_prefix / 'debug.log').resolve()}" - ) - - def test_output_directory( - url: str, versions: typing.Dict[str, str] = {} - ) -> pathlib.Path: - # Construct an output directory name, on the host, for output files written by - # the test case. - hash_chars = 8 - urlhash = f"container-{hashlib.sha1(url.encode()).hexdigest()[:hash_chars]}" - out_dirname = "-".join( - itertools.chain( - [urlhash], - map(lambda t: f"{t[0]}-{t[1][:hash_chars]}", sorted(versions.items())), - ) - ) - out_dir = args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {args.output_prefix}?" - out_dir.mkdir(mode=0o755) - return out_dir.resolve() - - container_url = functools.partial( - container_url_base, - container=args.container, - template=args.container_url_template, - ) - - def Container( - url, test_output_host_directory: typing.Optional[pathlib.Path] = None - ): - if args.container_runtime == "local": - return LocalContainer(logger=logger) - - Imp = DockerContainer if args.container_runtime == "docker" else PyxisContainer - mounts = bazel_cache_mounts + args.container_mount - if test_output_host_directory is not None: - # This can be used to save useful output from the test case (e.g. HLOs) - mounts.append((test_output_host_directory, "/triage-tool-output")) - return Imp(url, logger=logger, mounts=mounts) - - def get_versions( - container_url: typing.Optional[str], - explicit_versions: typing.Optional[typing.Dict[str, str]], - versions_from_env: bool, - ) -> typing.Tuple[ - typing.Dict[str, str], - typing.Optional[typing.Dict[str, str]], - typing.Optional[typing.Dict[str, str]], - typing.Optional[typing.Dict[str, str]], - ]: - """ - Given an optional container URL (e.g. --failing-container) and an optional set - of overriden versions (e.g. --failing-versions), obtain a list of software - versions to bookend the triage range. - - Also returns the container's runtime environment variables for diagnostic purposes. - If *only* overrides are given, those are returned verbatim and all other return - values are None. + try: + tool = TriageTool(args, logger) + tool.prepare() - Otherwise, the software versions, git repository directories and runtime - environment are extracted from the given container. If overrides are *also* given, - these take precedence over the extracted values. + passing_url, failing_url = tool.find_container_range() - It is an error for both `container_url` and `explicit_versions` to be None. - - Returns: - versions: {packge: version} mapping defining one end of the triage range - url_versions: {package: version} corresponding to the given container (or None) - dirs: {package: git_repository_dir} (or None) - env: {env_var: value} (or None) - """ - if explicit_versions is not None and container_url is None: - return explicit_versions, None, None, None - assert container_url is not None - logger.info(f"Extracting versions from {container_url} ...") - with Container(container_url) as worker: - url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env) - overriden_versions = url_versions.copy() - if explicit_versions is not None: - overriden_versions.update(explicit_versions) - return overriden_versions, url_versions, dirs, env - - def add_summary_record( - section: str, - record: typing.Mapping[str, typing.Union[bool, float, str]], - scalar: bool = False, - ): - """ - Add a record to the output JSON file. This is intended to provide a useful record - even in case of a fatal error. - """ - summary_filename = args.output_prefix / "summary.json" - try: - with open(summary_filename, "r") as ifile: - data = json.load(ifile) - except FileNotFoundError: - data = {} - if scalar: - if section in data: - logging.warning(f"Overwriting summary data in section {section}") - data[section] = record - else: - if section not in data: - data[section] = [] - data[section].append(record) - with open(summary_filename, "w") as ofile: - json.dump(data, ofile) - - versions_from_env = args.build_scripts_path is not None - - def check_container( - date: datetime.date, *, test_output_log_level: int = logging.DEBUG - ) -> TestResult: - """ - See if the test passes in the given dated container. - """ - before = time.monotonic() - out_dir = test_output_directory(container_url(date)) - with Container( - container_url(date), test_output_host_directory=out_dir - ) as worker: - versions, _, _ = get_versions_dirs_env(worker, versions_from_env) - # This will stream interleaved stdout/stderr into the logger - result = worker.exec(args.test_command, log_level=test_output_log_level) - test_time = time.monotonic() - before - test_pass = result.returncode == 0 - logger.info( - f"Ran test case in {worker} in {test_time:.1f}s, pass={test_pass}" - ) - add_summary_record( - "container", - { - "container": container_url(date), - "output_directory": out_dir.as_posix(), - "result": test_pass, - "test_time": test_time, - } - | versions, + passing_versions, failing_versions = tool.gather_version_info( + passing_url, failing_url ) - return TestResult( - host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout - ) - - if args.container_runtime == "local": - passing_url = "local" - failing_url = "local" - elif args.passing_container is None and args.failing_container is None: - # Search through the published containers, narrowing down to a pair of dates with - # the property that the test passed on `range_start` and fails on `range_end`. - range_start, range_end = container_search( - container_exists=lambda date: Container(container_url(date)).exists(), - container_passes=check_container, - start_date=args.start_date, - end_date=args.end_date, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - threshold_days=args.threshold_days, - ) - passing_url = container_url(range_start) - failing_url = container_url(range_end) - else: - # Skip the container-level search because at lease one explicit end point was - # given - passing_url = args.passing_container - failing_url = args.failing_container - - # Get the versions from the endpoint containers (if they exist), overridden by any - # explicitly passed versions. - passing_versions, original_passing_versions, passing_package_dirs, passing_env = ( - get_versions(passing_url, args.passing_versions, versions_from_env) - ) - failing_versions, original_failing_versions, failing_package_dirs, failing_env = ( - get_versions(failing_url, args.failing_versions, versions_from_env) - ) - - # If we have two containers, print the differences between their environments. This - # can be useful in the case that rebuilding the good versions in the bad container, - # or vice versa, does not reproduce the expected result. - if passing_env is not None and failing_env is not None: - logger.info(f"Environment differences between {passing_url} and {failing_url}") - for key in passing_env.keys() - failing_env.keys(): - logger.info(f"Only in {passing_url}: {key}={passing_env[key]}") - for key in failing_env.keys() - passing_env.keys(): - logger.info(f"Only in {failing_url}: {key}={failing_env[key]}") - for key in passing_env.keys() & failing_env.keys(): - if passing_env[key] == failing_env[key]: - continue - logger.info( - f"{key}: {passing_env[key]} ({passing_url}) vs. {failing_env[key]} " - f"({failing_url})" - ) - - # We should have versions for all the same software packages at both - # ends of the range, one way or another. TODO: this could be relaxed. - assert passing_versions.keys() == failing_versions.keys(), ( - passing_versions, - failing_versions, - ) - - # Which packages have versions that are not always the same? - dynamic_packages = { - pkg for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items()) - } - - # Choose an environment to do the version-level bisection in; use directory names that - # match it, and track what the initial versions of the different packages are - if args.container_runtime == "local": - bisection_url = "local" - bisection_versions = original_failing_versions - package_dirs = failing_package_dirs - elif failing_url is not None: - bisection_url = failing_url - bisection_versions = original_failing_versions - package_dirs = failing_package_dirs - else: - assert passing_url is not None - bisection_url = passing_url - bisection_versions = original_passing_versions - package_dirs = passing_package_dirs - assert package_dirs is not None - # This is the set of versions that are already installed - assert bisection_versions is not None - - # Get the full lists of JAX/XLA commits and dates - def get_commit_history( - worker, package, start, end, dir, main_branch=None, feature_branch_name=None - ): - # In particular the end commit might not already be known if the older, - # passing, container is being used for triage. - commits_known = worker.exec( - [ - "sh", - "-c", - f"git cat-file commit {start} && git cat-file commit {end}", - ], - policy="once_per_container", - workdir=dir, - ) - if commits_known.returncode != 0: - worker.check_exec( - ["git", "fetch"], policy="once_per_container", workdir=dir - ) - - # here we're considering the case of non-linear history - # limit for the moment to JAX and XLA - if feature_branch_name and package in ["jax", "xla"]: - logger.info( - f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" - ) - # 1. find the linear range on the main branch - passing_main_commit_cmd = f"git merge-base {start} {end}" - failing_main_commit_cmd = f"git merge-base {end} origin/{args.main_branch}" - - passing_main_commit = worker.check_exec( - ["sh ", "-c", passing_main_commit_cmd], workdir=dir - ).stdout.strip() - failing_main_commit = worker.check_exec( - ["sh", "-c", failing_main_commit_cmd], workdir=dir - ).stdout.strip() - - # 2. find commits to cherry-pick from the failing branch - cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" - cherry_pick_commits_list = ( - worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) - .stdout.strip() - .splitlines() - ) - if cherry_pick_commits_list: - args.cherry_pick_commits[package] = cherry_pick_commits_list - logger.info(f"Cherry-pick commits: {cherry_pick_commits_list}") - - # 3. now we can use the main branch commits for bisection - start = passing_main_commit - end = failing_main_commit - - result = worker.check_exec( - [ - "git", - "log", - "--first-parent", - "--reverse", - "--format=%H %cI", - f"{start}^..{end}", - ], - policy="once", - stderr=subprocess.PIPE, - workdir=dir, - ) - logger.debug(f"stderr: {result.stderr.strip()}") - data = [] - for line in result.stdout.splitlines(): - commit, date = line.split() - date = datetime.datetime.fromisoformat(date).astimezone( - datetime.timezone.utc - ) - data.append((commit, date)) - return data - - # Fire up the container that will be used for the version-level search and use it to - # extract the relevant history of the repositories that will be triaged. - with Container(bisection_url) as worker: - packages = passing_versions.keys() - log_str = "Bisecting" - for package in packages: - log_str += ( - f" {package} [{passing_versions[package]}, {failing_versions[package]}]" - ) - log_str += f" using {worker}" - logger.info(log_str) - # Get lists of (commit_hash, commit_date) pairs - package_versions = collections.OrderedDict() - for package in packages: - if package not in package_dirs: - # This is a version that came from the container environment, not a git - # checkout directory in the container. Handle those below. - continue - package_versions[package] = get_commit_history( - worker, - package, - passing_versions[package], - failing_versions[package], - package_dirs[package], - args.main_branch, - args.feature_branch_name, - ) - # Confirm they're sorted by commit date - assert all( - b[1] >= a[1] - for a, b in zip( - package_versions[package], package_versions[package][1:] - ) - ) - # Confirm the end values are included as expected - assert passing_versions[package] == package_versions[package][0][0] - assert failing_versions[package] == package_versions[package][-1][0] - # For the packages that just have one or two version numbers, associate those - # version numbers with the earliest and, if appropriate, latest XLA dates. This - # is only relevant to packages that do not have git repositories checked out and - # are managed via installPACKAGE.sh scripts and PACKAGE_VERSION environment vars - for package in packages: - if package in package_versions: - continue - package_versions[package] = [ - (passing_versions[package], package_versions["xla"][0][1]), - ] - if passing_versions[package] != failing_versions[package]: - package_versions[package].append( - (failing_versions[package], package_versions["xla"][-1][1]) - ) - - # Check up-front whether the installation scripts exist for the packages that - # are being triaged by version + script rather than from a git repo + build. - if args.build_scripts_path is not None: - known_scripts = worker.check_exec( - [ - "find", - args.build_scripts_path, - "-maxdepth", - "1", - "-executable", - "-print0", - ], - policy="once", - stderr="separate", - ).stdout.split("\0") - logger.debug(f"Found {known_scripts} inside {worker}") - packages_with_scripts = { - script[len(args.build_scripts_path) + 8 : -3] - for script in known_scripts - if script.startswith(args.build_scripts_path + "/install") - and script.endswith(".sh") - } - logger.debug(f"Found installation scripts for {packages_with_scripts}") - packages_needing_scripts = dynamic_packages - package_dirs.keys() - packages_missing_scripts = packages_needing_scripts - packages_with_scripts - if packages_missing_scripts: - logger.warning( - "No installation scripts found for: " - f"{' '.join(packages_missing_scripts)}, whose version(s) change " - "across the bisection range. These will be excluded from the " - "bisection, which may cause it not to converge!" - ) - dynamic_packages -= packages_missing_scripts - - def build_and_test( - *, versions: typing.Dict[str, str], test_output_log_level: int = logging.DEBUG - ) -> TestResult: - """ - The main body of the bisection loop. Update JAX/XLA/... versions, rebuild, and - run the test command. Throws on error when checking out or building, and returns - the status of the test command. - """ - # Amortise container startup overhead by batching together git commands - git_commands, changed, skipped = [], [], [] - for package in sorted(dynamic_packages): - version = versions[package] - if bisection_versions[package] == version: - # If the existing version is the desired one, do nothing. - skipped.append(f"{package}@{version}") - continue - # Cache which version is now going to be checked out in the container - bisection_versions[package] = version - changed.append(f"{package}@{version}") - if package in package_dirs: - # in case of non-linear history - should we limit this to XLA and JAX only? - package_cherry_picks = args.cherry_pick_commits.get(package, []) - if package_cherry_picks: - logger.info("Working on a non-linear history") - git_commands.append(f"cd {package_dirs[package]}") - git_commands.append("git stash") - # this is a checkout on the main branch - git_commands.append(f"git checkout {version}") - cherry_pick_str = " ".join(package_cherry_picks) - git_commands.append( - f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" - ) - else: - # Linear history - # A git repository that exists in the container. - git_commands += [ - f"cd {package_dirs[package]}", - "git stash", - f"git checkout {version}", - ] - - else: - # Another software package, `version` is probably a version number. - # Installation of this version is delegated to an installPACKAGE.sh - # script that is assumed to be available in `args.build_scripts_path`. - assert args.build_scripts_path is not None - assert package in packages_with_scripts, ( - package, - packages_with_scripts, - ) - extra_env = { - # Need the static part for the .bc library - "NVSHMEM": "DEVEL=1 STATIC=1", - }.get(package, "DEVEL=1") # Always need development headers to rebuild - git_commands += [ - f"{extra_env} {args.build_scripts_path}/install{package}.sh {version}" - ] - # Keep the pathnames shorter by only including packages that actually have - # multiple versions in the bisection range. - brief_versions = { - p: ver for p, ver in versions.items() if p in dynamic_packages - } - out_dir = test_output_directory(bisection_url, versions=brief_versions) - with Container(bisection_url, test_output_host_directory=out_dir) as worker: - change_str = " ".join(changed) if len(changed) else "" - info_str = f"Checking out {change_str} in {worker}" - if len(skipped): - info_str += f", leaving {' '.join(skipped)} unchanged" - logger.info(info_str) - worker.check_exec( - ["sh", "-c", " && ".join(git_commands)], - policy="once_per_container", - ) - # Build JAX - # TODO: teach the tool how to build TransformerEngine too - # TODO: do not build JAX/XLA/TransformerEngine if we know their versions did not change? - before = time.monotonic() - # Unfortunately the build system does not always seem to handle incremental - # rebuilds correctly, so clean the local cache and rely on the remote one. - build_cmds = [ - "bazel clean --expunge", - f"build-jax.sh --bazel-cache={args.bazel_cache}", - ] - worker.check_exec( - ["sh", "-c", " && ".join(build_cmds)], - policy="once_per_container", - workdir=package_dirs["jax"], - ) - middle = time.monotonic() - logger.info(f"Build completed in {middle - before:.1f}s") - # Run the test - test_result = worker.exec( - args.test_command, log_level=test_output_log_level - ) - test_time = time.monotonic() - middle - add_summary_record( - "versions", - { - "build_time": middle - before, - "container": bisection_url, - "output_directory": out_dir.as_posix(), - "result": test_result.returncode == 0, - "test_time": test_time, - } - | versions, - ) - result_str = "pass" if test_result.returncode == 0 else "fail" - logger.info(f"Test completed in {test_time:.1f}s ({result_str})") - return TestResult( - host_output_directory=out_dir, - result=test_result.returncode == 0, - stdouterr=test_result.stdout, - ) - - # Run the version-level bisection - result, last_known_good, first_known_bad = version_search( - versions=package_versions, - build_and_test=build_and_test, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - ) - - def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None: - if result is None: - return - symlink = (args.output_prefix / symlink_name).resolve() - assert not symlink.exists(), symlink - assert symlink.parent == result.host_output_directory.parent, ( - symlink, - result.host_output_directory, - ) - symlink.symlink_to(result.host_output_directory.name) + tool.run_version_bisection(passing_versions, failing_versions) - symlink(last_known_good, "last-known-good") - symlink(first_known_bad, "first-known-bad") - result["container"] = failing_url - add_summary_record("result", result, scalar=True) + except Exception as e: + logger.fatal(f"Triage process failed: {e}") diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py new file mode 100644 index 000000000..53db602cc --- /dev/null +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -0,0 +1,78 @@ +import json +import logging +import pathlib +import typing +from .logic import TestResult + + +def add_summary_record( + output_prefix: pathlib.Path, + section: str, + record: typing.Union[typing.Dict[str, typing.Any], TestResult], + scalar=False, +): + """ + Add a record to the output JSON file. This is intended to provide a useful record + even in case of a fatal error. + + Args: + output_prefix (pathlib.Path): The prefix for the output directory. + section (str): The section of the summary to which the record belongs. + record (dict or TestResult): The record to be added, either as a dictionary or + as a TestResult object. + scalar (bool): If True, the record is a scalar value; if False, it is a list of + records. Defaults to False. + + Returns: + None + """ + summary_filename = output_prefix / "summary.json" + try: + with open(summary_filename, "r") as ifile: + data = json.load(ifile) + except FileNotFoundError: + data = {} + if scalar: + if section in data: + logging.warning(f"Overwriting summary data in section {section}") + data[section] = record + else: + if section not in data: + data[section] = [] + data[section].append(record) + + with open(summary_filename, "w") as ofile: + json.dump(data, ofile) + + +def create_output_symlinks( + output_prefix: pathlib.Path, + last_known_good: typing.Optional[TestResult], + first_known_bad: typing.Optional[TestResult], +): + """ + Create symlinks to the last-good and first-bad output directories. + versions. + + Args: + output_prefix (pathlib.Path): The prefix for the output directory. + last_known_good (TestResult): The last known good test result. + first_known_bad (TestResult): The first known bad test result. + + Returns: + None + """ + + def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None: + if result is None: + return + symlink_path = (output_prefix / symlink_name).resolve() + assert not symlink_path.exists(), symlink_path + assert symlink_path.parent == result.host_output_directory.parent, ( + symlink_path, + result.host_output_directory, + ) + symlink_path.symlink_to(result.host_output_directory) + + symlink(last_known_good, "last-known-good") + symlink(first_known_bad, "first-known-bad") diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py new file mode 100644 index 000000000..76e7efb52 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -0,0 +1,613 @@ +import collections +import datetime +import functools +import hashlib +import itertools +import logging +import pathlib +import time +import typing + +from .args import compulsory_software, optional_software +from .container import Container +from .logic import container_search, TestResult, version_search +from .versions import get_versions_dirs_env +from .summary import add_summary_record, create_output_symlinks +from .bisect import get_commit_history +from .utils import ( + container_url as container_url_base, + prepare_bazel_cache_mounts, +) +from .container_factory import make_container + + +class TriageTool: + """ + This is the main class that orchestrates the whole triage process. + """ + + def __init__(self, args, logger): + self.args = args + self.logger = logger + self.bazel_cache_mounts = [] + self.bisection_url = None + self.bisection_versions = None + self.package_dirs = None + self.dynamic_packages = set() + # the cherry-pick gets populated only for non-linear cases + self.args.cherry_pick_commits = {} + + def _test_output_directory( + self, url: str, versions: typing.Dict[str, str] = None + ) -> pathlib.Path: + """ + Create a directory for test output based on the container URL and versions. + + Args: + url (str): The URL of the container. + versions (dict): A dictionary of software versions. + Returns: + pathlib.Path: The path to the output directory. + """ + hash_chars = 8 + urlhash = f"container-{hashlib.sha1(url.encode()).hexdigest()[:hash_chars]}" + out_dirname = "-".join( + itertools.chain( + [urlhash], + map(lambda t: f"{t[0]}-{t[1][:hash_chars]}", sorted(versions.items())), + ) + ) + out_dir = self.args.output_prefix / out_dirname + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + out_dir.mkdir(mode=0o755) + return out_dir.resolve() + + def _get_versions( + self, container_url: str, explicit_versions: str, versions_from_env: str + ): + """ + Get the versions of the software packages in the container. + + Args: + container_url (str): The URL of the container. + explicit_versions (str): Explicit versions to use. + versions_from_env (bool): Whether to get versions from environment variables. + Returns: + overriden_versions (dict): The versions with explicit overrides. + url_versions (dict): The versions from the container URL. + dirs (dict): The directories of the software packages. + env (dict): The environment variables in the container. + """ + if explicit_versions is not None and container_url is None: + return explicit_versions, None, None, None + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." + + with make_container( + self.args.container_runtime, + container_url, + self.bazel_cache_mounts, + self.logger, + ) as worker: + url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env) + overriden_versions = url_versions.copy() + if explicit_versions is not None: + overriden_versions.update(explicit_versions) + + return overriden_versions, url_versions, dirs, env + + def _gather_histories( + self, + worker: Container, + passing_versions: typing.Dict[str, str], + failing_versions: typing.Dict[str, str], + ): + """ + Gather the commit histories for the passing and failing versions. + + Args: + worker (Container): The container in which to run the commands. + passing_versions (dict): The versions that passed. + failing_versions (dict): The versions that failed. + Returns: + Tuple[List[str], List[str]]: The commit histories for passing and failing versions. + """ + passing_commits = [] + failing_commits = [] + + for package, version in passing_versions.items(): + if package not in compulsory_software + optional_software: + continue + cmd = f"cd /opt/{package} && git log --pretty=format:'%H' -n 1 {version}" + result = worker.check_exec(["sh", "-c", cmd]) + passing_commits.append(result.stdout.strip()) + + for package, version in failing_versions.items(): + if package not in compulsory_software + optional_software: + continue + cmd = f"cd /opt/{package} && git log --pretty=format:'%H' -n 1 {version}" + result = worker.check_exec(["sh", "-c", cmd]) + failing_commits.append(result.stdout.strip()) + + return passing_commits, failing_commits + + def _log_environment_differences(self, url1: str, url2: str, env1: str, env2: str): + """ + If we have two containers, print the differences between their environments. This + can be useful in the case that rebuilding the good versions in the bad container, + or vice versa, does not reproduce the expected result. + + Args: + url1 (str): The URL of the first container. + url2 (str): The URL of the second container. + env1 (dict): The environment variables of the first container. + env2 (dict): The environment variables of the second container. + + Returns: + None + """ + if env1 is None or env2 is None: + return + self.logger.info(f"Environment differences between {url1} and {url2}") + for key in env1.keys() - env2.keys(): + self.logger.info(f"Only in {url1}: {key}={env1[key]}") + for key in env2.keys() - env1.keys(): + self.logger.info(f"Only in {url2}: {key}={env2[key]}") + for key in env1.keys() & env2.keys(): + if env1[key] != env2[key]: + self.logger.info( + f"{key}: {env1[key]} ({url1}) vs. {env2[key]} ({url2})" + ) + + def _check_container_by_date( + self, date: datetime.date, *, test_output_log_level: int = logging.DEBUG + ) -> TestResult: + """ + See if the test passes in the given dated container. + + Args: + date (datetime.date): The date of the container to check. + test_output_log_level (int): The log level for test output. + Returns: + TestResult: The result of the test, including whether it passed and the output. + """ + container_url_func = functools.partial( + container_url_base, + container=self.args.container, + template=self.args.container_url_template, + ) + container_url = container_url_func(date) + + before = time.monotonic() + out_dir = self._test_output_directory(container_url(date)) + + # this is from the previous Container class implementaiton in main + mounts = self.args.container_mount + [(out_dir, "/triage-tool-output")] + + with make_container( + self.args.container_runtime, container_url, mounts, self.logger + ) as worker: + versions, _, _ = get_versions_dirs_env( + worker, self.args.build_scripts_path is not None + ) + result = worker.exec( + self.args.test_command, log_level=test_output_log_level + ) + test_time = time.monotonic() - before + test_pass = result.returncode == 0 + self.logger.info( + f"Ran test case in {worker} in {test_time:.1f}s, pass={test_pass}" + ) + + add_summary_record( + self.args.output_prefix, + "container", + { + "container": container_url(date), + "output_directory": out_dir.as_posix(), + "result": test_pass, + "test_time": test_time, + } + | versions, + ) + return TestResult( + host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout + ) + + def _gather_histories( + self, + worker: Container, + passing_versions: typing.Dict[str, str], + failing_versions: typing.Dict[str, str], + ) -> typing.Tuple[typing.List[str], typing.List[str]]: + """ + Gather the commit histories for the passing and failing versions. + This function is pivotal for non-linear history logic search + + Args: + worker (Container): The container in which to run the commands. + passing_versions (dict): The versions that passed. + failing_versions (dict): The versions that failed. + + Returns: + Tuple[List[str], List[str]]: The commit histories for passing and failing versions. + """ + packages = passing_versions.keys() + self.logger.info( + f"Bisecting {' '.join(f'{p} [{passing_versions[p]}, {failing_versions[p]}]' for p in packages)} using {worker}" + ) + package_versions = collections.OrderedDict() + + for package in packages: + if package not in self.package_dirs: + continue + package_versions[package] = get_commit_history( + worker, + package, + passing_versions[package], + failing_versions[package], + self.package_dirs[package], + main_branch=self.args.main_branch, + feature_branch_name=self.args.feature_branch_name, + logger=self.logger, + args=self.args, + ) + + if not self.args.cherry_pick_commits.get(package): + # Confirm they're sorted by commit date + assert all( + b[1] >= a[1] + for a, b in zip( + package_versions[package], package_versions[package][1:] + ) + ) + # Confirm the end values are included as expected + assert passing_versions[package] == package_versions[package][0][0] + assert failing_versions[package] == package_versions[package][-1][0] + for package in packages: + if package in package_versions: + continue + package_versions[package] = [ + (passing_versions[package], package_versions["xla"][0][1]) + ] + if passing_versions[package] != failing_versions[package]: + package_versions[package].append( + (failing_versions[package], package_versions["xla"][-1][1]) + ) + + return package_versions + + def _check_installation_scripts(self, worker: Container): + """ + Check for special installation cases, like cuBLAS or cuDNN + + Args: + worker (Container): The container in which to run the commands. + """ + if self.args.build_scripts_path is None: + return + + known_scripts_result = worker.exec( + [ + "find", + self.args.build_scripts_path, + "-maxdepth", + "1", + "-executable", + "-print0", + ], + policy="once", + stderr="separate", + ) + if known_scripts_result.returncode != 0: + self.logger.warning( + f"Failed to find known installation scripts in {self.args.build_scripts_path}: {known_scripts_result.stderr}" + ) + known_scripts = [] + else: + known_scripts = known_scripts_result.stdout.split("\0") + + self.logger.debug(f"Found {known_scripts} inside {worker}") + + self.packages_with_scripts = { + script[len(self.args.build_scripts_path) + 8 : -3] + for script in known_scripts + if script.startswith(self.args.build_scripts_path + "/install") + and script.endswith(".sh") + } + self.logger.debug( + f"Found installation scripts for {self.packages_with_scripts}" + ) + packages_needing_scripts = self.dynamic_packages - self.package_dirs.keys() + packages_missing_scripts = packages_needing_scripts - self.packages_with_scripts + if packages_missing_scripts: + self.logger.warning( + "No installation scripts found for: " + f"{' '.join(packages_missing_scripts)}, whose version(s) change " + "across the bisection range. These will be excluded from the " + "bisection, which may cause it not to converge!" + ) + self.dynamic_packages -= packages_missing_scripts + + def _build_and_test( + self, + *, + versions: typing.Dict[str, str], + test_output_log_level: int = logging.DEBUG, + ) -> TestResult: + """ + The main body of the bisection loop. Update JAX/XLA/... versions, rebuild, and + run the test command. Throws on error when checking out or building, and returns + the status of the test command. + + Args: + versions (dict): The versions of the software packages to use. + test_output_log_level (int): The log level for test output. + + Returns: + TestResult: The result of the test, including whether it passed and the output. + """ + # Amortise container startup overhead by batching together git commands + git_commands, changed, skipped = [], [], [] + for package in sorted(self.dynamic_packages): + version = versions[package] + if self.bisection_versions[package] == version: + # If the existing version is the desired one, do nothing. + skipped.append(f"{package}@{version}") + continue + # Cache which version is now going to be checked out in the container + self.bisection_versions[package] = version + changed.append(f"{package}@{version}") + if package in self.package_dirs: + # in case of non-linear history - should we limit this to XLA and JAX only? + package_cherry_picks = self.args.cherry_pick_commits.get(package, []) + if package_cherry_picks: + self.logger.info("Working on a non-linear history") + git_commands.append(f"cd {self.package_dirs[package]}") + git_commands.append("git stash") + # this is a checkout on the main branch + git_commands.append(f"git checkout {version}") + cherry_pick_str = " ".join(package_cherry_picks) + git_commands.append( + f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" + ) + else: + # Linear history + # A git repository that exists in the container. + git_commands += [ + f"cd {self.package_dirs[package]}", + "git stash", + f"git checkout {version}", + ] + + else: + # Another software package, `version` is probably a version number. + # Installation of this version is delegated to an installPACKAGE.sh + # script that is assumed to be available in `args.build_scripts_path`. + assert self.args.build_scripts_path is not None + assert package in self.packages_with_scripts, ( + package, + self.packages_with_scripts, + ) + extra_env = { + # Need the static part for the .bc library + "NVSHMEM": "DEVEL=1 STATIC=1", + }.get(package, "DEVEL=1") # Always need development headers to rebuild + git_commands += [ + f"{extra_env} {self.args.build_scripts_path}/install{package}.sh {version}" + ] + # Keep the pathnames shorter by only including packages that actually have + # multiple versions in the bisection range. + brief_versions = { + p: ver for p, ver in versions.items() if p in self.dynamic_packages + } + out_dir = self._test_output_directory( + self.bisection_url, versions=brief_versions + ) + with make_container( + self.args.container_runtime, + self.bisection_url, + self.bazel_cache_mounts, + self.logger, + test_output_host_directory=out_dir, + ) as worker: + change_str = " ".join(changed) if len(changed) else "" + info_str = f"Checking out {change_str} in {worker}" + if len(skipped): + info_str += f", leaving {' '.join(skipped)} unchanged" + self.logger.info(info_str) + worker.check_exec( + ["sh", "-c", " && ".join(git_commands)], + policy="once_per_container", + ) + # Build JAX + # TODO: teach the tool how to build TransformerEngine too + # TODO: do not build JAX/XLA/TransformerEngine if we know their versions did not change? + before = time.monotonic() + # Unfortunately the build system does not always seem to handle incremental + # rebuilds correctly, so clean the local cache and rely on the remote one. + build_cmds = [ + "bazel clean --expunge", + f"build-jax.sh --bazel-cache={self.args.bazel_cache}", + ] + worker.check_exec( + ["sh", "-c", " && ".join(build_cmds)], + policy="once_per_container", + workdir=self.package_dirs["jax"], + ) + middle = time.monotonic() + self.logger.info(f"Build completed in {middle - before:.1f}s") + # Run the test + test_result = worker.exec( + self.args.test_command, log_level=test_output_log_level + ) + test_time = time.monotonic() - middle + + add_summary_record( + self.args.output_prefix, + "versions", + { + "build_time": middle - before, + "container": self.bisection_url, + "output_directory": out_dir.as_posix(), + "result": test_result.returncode == 0, + "test_time": test_time, + } + | versions, + ) + result_str = "pass" if test_result.returncode == 0 else "fail" + self.logger.info(f"Test completed in {test_time:.1f}s ({result_str})") + return TestResult( + host_output_directory=out_dir, + result=test_result.returncode == 0, + stdouterr=test_result.stdout, + ) + + def prepare(self): + """ + Function to prepare the triage tool for execution. + """ + self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args) + + def find_container_range(self) -> typing.Tuple[str, str]: + """ + Find the range from the passing and failing containers. + Returns a tuple of the start and end container names. + """ + if self.args.container_runtime == "local": + return "local", "local" + + container_url_func = functools.partial( + container_url_base, + container=self.args.container, + template=self.args.container_url_template, + ) + + if self.args.passing_container is None and self.args.failing_container is None: + range_start, range_end = container_search( + container_exists=lambda date: make_container( + self.args.container_runtime, + container_url_func(date), + [], + self.logger, + ).exists(), + container_passes=self._check_container_by_date, + start_date=self.args.start_date, + end_date=self.args.end_date, + logger=self.logger, + skip_precondition_checks=self.args.skip_precondition_checks, + threshold_days=self.args.threshold_days, + ) + + return container_url_func(range_start), container_url_func(range_end) + else: + return self.args.passing_container, self.args.failing_container + + def gather_version_info(self, passing_url: str, failing_url: str): + """ + Gather version information from the passing and failing containers. + + Args: + passing_url (str): The URL of the passing container. + failing_url (str): The URL of the failing container. + + Returns: + Tuple[dict, dict, dict, dict]: The versions, URLs, directories, and environment variables. + """ + versions_from_env = self.args.build_scripts_path is not None + # Get the versions from the endpoint containers (if they exist), overridden by any + # explicitly passed versions. + ( + passing_versions, + original_passing_versions, + passing_package_dirs, + passing_env, + ) = self._get_versions( + passing_url, self.args.passing_versions, versions_from_env + ) + ( + failing_versions, + original_failing_versions, + failing_package_dirs, + failing_env, + ) = self._get_versions( + failing_url, self.args.failing_versions, versions_from_env + ) + + self._log_environment_differences( + passing_url, failing_url, passing_env, failing_env + ) + + # We should have versions for all the same software packages at both + # ends of the range, one way or another. TODO: this could be relaxed. + assert passing_versions.keys() == failing_versions.keys(), ( + passing_versions, + failing_versions, + ) + # Which packages have versions that are not always the same? + # TODO: DOUBLE CHECK THIS + self.dynamic_packages = { + pkg + for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items()) + } + # Choose an environment to do the version-level bisection in; use directory names that + # match it, and track what the initial versions of the different packages are + if self.args.container_runtime == "local": + self.bisection_url = "local" + self.bisection_versions = original_failing_versions + self.package_dirs = failing_package_dirs + elif failing_url is not None: + self.bisection_url = failing_url + self.bisection_versions = original_failing_versions + self.package_dirs = failing_package_dirs + else: + assert passing_url is not None + self.bisection_url = passing_url + self.bisection_versions = original_passing_versions + self.package_dirs = passing_package_dirs + assert self.package_dirs is not None + # This is the set of versions that are already installed + assert self.bisection_versions is not None + return passing_versions, failing_versions + + def run_version_bisection( + self, + passing_versions: typing.Dict[str, str], + failing_versions: typing.Dict[str, str], + ) -> typing.Tuple[typing.Dict[str, str], TestResult]: + """ + Run the version bisection process. + + Args: + passing_versions (dict): The versions that passed. + failing_versions (dict): The versions that failed. + + Returns: + Tuple[dict, TestResult]: The final versions and the test result. + """ + # Prepare the container for the bisection + with make_container( + self.args.container_runtime, + self.bisection_url, + self.bazel_cache_mounts, + self.logger, + ) as worker: + package_versions = self._gather_histories( + worker, passing_versions, failing_versions + ) + self._check_installation_scripts(worker) + + # Run the version-level bisection + result, last_known_good, first_known_bad = version_search( + versions=package_versions, + build_and_test=self._build_and_test, + logger=self.logger, + skip_precondition_checks=self.args.skip_precondition_checks, + ) + # Write final summary + create_output_symlinks( + self.args.output_prefix, last_known_good, first_known_bad + ) + result["container"] = self.bisection_url + add_summary_record(self.args.output_prefix, "result", result, scalar=True) diff --git a/.github/triage/jax_toolbox_triage/versions.py b/.github/triage/jax_toolbox_triage/versions.py new file mode 100644 index 000000000..1c8c37cca --- /dev/null +++ b/.github/triage/jax_toolbox_triage/versions.py @@ -0,0 +1,89 @@ +import typing +from .container import Container +from .args import compulsory_software, optional_software + + +def get_env(worker: Container) -> typing.Dict[str, str]: + """ + Get the runtime environment in the given container. + + Returns: {env_var: value} dictionary, sorted by key. + """ + + def impl() -> typing.Dict[str, str]: + kvs = ( + worker.check_exec(["env", "-0"], policy="once", stderr="separate") + .stdout[:-1] # skip the trailing \0 + .split("\0") + ) + return dict(kv.split("=", 1) for kv in kvs) + + # Remove any environment variables that differ between consecutive `env` calls, for + # example some step-specific Slurm variables. + env1, env2 = impl(), impl() + # sorted(...) for run-to-run determinism + return {k: env1[k] for k in sorted(env1.keys() & env2.keys()) if env1[k] == env2[k]} + + +def get_commits_and_dirs( + worker: Container, +) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]: + """ + Get the git repository paths and current HEAD commits in the given environment of + the software packages named in `compulsory_software` and `optional_software`. + + Returns: ({package: commit}, {package: directory}) + """ + # Formulated this way to avoid paying too many times for container startup. + cmds = [] + for package in compulsory_software + optional_software: + bits = [ + f"(cd /opt/{package} && git rev-parse HEAD && echo {package} && pwd)", + f"(cd /opt/{package}-source && git rev-parse HEAD && echo {package} && pwd)", + ] + if package in optional_software: + bits.append("true") + cmds.append(f"({' || '.join(bits)})") + result = worker.check_exec( + ["sh", "-c", " && ".join(cmds)], policy="once", stderr="separate" + ) + versions, dirs = {}, {} + # Look over triplets of output lines + for commit, package, dirname in zip(*([iter(result.stdout.splitlines())] * 3)): + dirs[package] = dirname + versions[package] = commit + return versions, dirs + + +def get_versions_dirs_env( + worker: Container, + versions_from_env: bool, +) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str], typing.Dict[str, str]]: + """ + Get software versions in the given [container] environment, git repository paths + where relevant, and the runtime environment. + + The list of software versions is drawn from git repositories at known container + locations and, if `versions_from_env` is True, from the environment. + + Returns: + versions: {package: version or commit}, + dirs: {package: git_repository_dir} + env: {env_var: value} + """ + # Get the git repository paths and commits from the container. + versions, dirs = get_commits_and_dirs(worker) + + # Get the environment variables from the container. + env = get_env(worker) + + if versions_from_env: + # Promote any XXX_VERSION environment variables into `versions` if `XXX` is + # not already there. + for k, v in env.items(): + if not len(v) or not k.endswith("_VERSION"): + continue + package = k[:-8] + assert package not in versions, (versions, package) + versions[package] = v + return versions, dirs, env From acd5bb375e84f16264e9d3752845d4656b776f87 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 7 Jul 2025 11:08:54 +0100 Subject: [PATCH 07/50] save this version before testing it --- .github/triage/jax_toolbox_triage/bisect.py | 17 +++++ .../jax_toolbox_triage/container_factory.py | 2 +- .../triage/jax_toolbox_triage/triage_tool.py | 63 ++++++++++++------- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 8ca2b02de..927ce169d 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -13,6 +13,23 @@ def get_commit_history( logger=None, args=None, ): + """ + Get the commit history for a given package between two commits. + + Args: + worker (Container): The container worker to execute commands. + package (str): The name of the package. + start (str): The starting commit hash. + end (str): The ending commit hash. + dir (str): The directory where the git repository is located. + main_branch (str, optional): The main branch name. Defaults to None. + feature_branch_name (str, optional): The feature branch name. Defaults to None. + logger (Logger, optional): Logger for debug information. Defaults to None. + args: Additional arguments that may contain cherry-pick commits. + + Returns: + list: A list of tuples containing commit hashes and their corresponding dates. + """ # In particular the end commit might not already be known if the older, # passing, container is being used for triage. commits_known = worker.exec( diff --git a/.github/triage/jax_toolbox_triage/container_factory.py b/.github/triage/jax_toolbox_triage/container_factory.py index 1a2febac6..0c254ba0d 100644 --- a/.github/triage/jax_toolbox_triage/container_factory.py +++ b/.github/triage/jax_toolbox_triage/container_factory.py @@ -9,7 +9,7 @@ def make_container( runtime: str, url: str, mounts: list, logger: logging.Logger, **kwargs ) -> Container: """ - This function craetes a container objects, based on the specified runtime + This function creates a container object, based on the specified runtime Args: runtime (str): The container runtime to use (e.g., 'docker', 'pyxis', 'local'). diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 76e7efb52..58d8d377f 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -8,7 +8,6 @@ import time import typing -from .args import compulsory_software, optional_software from .container import Container from .logic import container_search, TestResult, version_search from .versions import get_versions_dirs_env @@ -102,7 +101,7 @@ def _gather_histories( worker: Container, passing_versions: typing.Dict[str, str], failing_versions: typing.Dict[str, str], - ): + ) -> collections.OrderedDict: """ Gather the commit histories for the passing and failing versions. @@ -111,26 +110,48 @@ def _gather_histories( passing_versions (dict): The versions that passed. failing_versions (dict): The versions that failed. Returns: - Tuple[List[str], List[str]]: The commit histories for passing and failing versions. + collections.OrderDict: The commit histories for passing and failing versions. """ - passing_commits = [] - failing_commits = [] + packages = passing_versions.keys() + package_versions = collections.OrderedDict() - for package, version in passing_versions.items(): - if package not in compulsory_software + optional_software: + for package in packages: + if package not in self.package_dirs: continue - cmd = f"cd /opt/{package} && git log --pretty=format:'%H' -n 1 {version}" - result = worker.check_exec(["sh", "-c", cmd]) - passing_commits.append(result.stdout.strip()) + package_versions[package] = get_commit_history( + worker, + package, + passing_versions[package], + failing_versions[package], + self.package_dirs[package], + main_branch=self.args.main_branch, + feature_branch_name=self.args.feature_branch_name, + logger=self.logger, + args=self.args, + ) + + if not self.args.cherry_pick_commits.get(package): + assert all( + b[1] >= a[1] + for a, b in zip( + package_versions[package], package_versions[package][1:] + ) + ) + assert passing_versions[package] == package_versions[package][0][0] + assert failing_versions[package] == package_versions[package][-1][0] - for package, version in failing_versions.items(): - if package not in compulsory_software + optional_software: + for package in packages: + if package in package_versions: continue - cmd = f"cd /opt/{package} && git log --pretty=format:'%H' -n 1 {version}" - result = worker.check_exec(["sh", "-c", cmd]) - failing_commits.append(result.stdout.strip()) + package_versions[package] = [ + (passing_versions[package], package_versions["xla"][0][1]) + ] + if passing_versions[package] != failing_versions[package]: + package_versions[package].append( + (failing_versions[package], package_versions["xla"][-1][1]) + ) - return passing_commits, failing_commits + return package_versions def _log_environment_differences(self, url1: str, url2: str, env1: str, env2: str): """ @@ -180,9 +201,9 @@ def _check_container_by_date( container_url = container_url_func(date) before = time.monotonic() - out_dir = self._test_output_directory(container_url(date)) + out_dir = self._test_output_directory(container_url) - # this is from the previous Container class implementaiton in main + # this is from the previous Container class implementation in main mounts = self.args.container_mount + [(out_dir, "/triage-tool-output")] with make_container( @@ -467,6 +488,7 @@ def _build_and_test( def prepare(self): """ Function to prepare the triage tool for execution. + At the moment, we're adding the bazel cache mounts to the tool. """ self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args) @@ -512,8 +534,6 @@ def gather_version_info(self, passing_url: str, failing_url: str): passing_url (str): The URL of the passing container. failing_url (str): The URL of the failing container. - Returns: - Tuple[dict, dict, dict, dict]: The versions, URLs, directories, and environment variables. """ versions_from_env = self.args.build_scripts_path is not None # Get the versions from the endpoint containers (if they exist), overridden by any @@ -546,7 +566,8 @@ def gather_version_info(self, passing_url: str, failing_url: str): failing_versions, ) # Which packages have versions that are not always the same? - # TODO: DOUBLE CHECK THIS + # TODO: DOUBLE CHECK THIS what if: + # pkg for pkg in passing_versions if passing_versions[pkg] != failing_versions[pkg] self.dynamic_packages = { pkg for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items()) From 15ff577aaf2d9ef6d530dcd51d2242e79eb2c76a Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 7 Jul 2025 11:39:27 +0100 Subject: [PATCH 08/50] fix argument --- .github/triage/jax_toolbox_triage/triage_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 58d8d377f..1012ffa90 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -490,7 +490,7 @@ def prepare(self): Function to prepare the triage tool for execution. At the moment, we're adding the bazel cache mounts to the tool. """ - self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args) + self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) def find_container_range(self) -> typing.Tuple[str, str]: """ From 6bdc74c2431719e857911b40c4ef0adf9dc74057 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 7 Jul 2025 14:28:34 +0100 Subject: [PATCH 09/50] fix directory mount --- .github/triage/jax_toolbox_triage/triage_tool.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 1012ffa90..cda1ba7f3 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -426,12 +426,17 @@ def _build_and_test( out_dir = self._test_output_directory( self.bisection_url, versions=brief_versions ) + mounts = ( + self.bazel_cache_mounts + + self.args.container_mount + + [(out_dir, "/triage-tool-output")] + ) + with make_container( self.args.container_runtime, self.bisection_url, - self.bazel_cache_mounts, + mounts, self.logger, - test_output_host_directory=out_dir, ) as worker: change_str = " ".join(changed) if len(changed) else "" info_str = f"Checking out {change_str} in {worker}" From 6bc0f85cf0292bd942ef83e6a0c1930d57e41049 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 7 Jul 2025 16:48:59 +0100 Subject: [PATCH 10/50] finish off the testing with real examples --- .github/triage/jax_toolbox_triage/triage_tool.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index cda1ba7f3..4ad3b172c 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -33,6 +33,7 @@ def __init__(self, args, logger): self.bisection_versions = None self.package_dirs = None self.dynamic_packages = set() + self.packages_with_scripts = set() # the cherry-pick gets populated only for non-linear cases self.args.cherry_pick_commits = {} @@ -502,6 +503,7 @@ def find_container_range(self) -> typing.Tuple[str, str]: Find the range from the passing and failing containers. Returns a tuple of the start and end container names. """ + self.logger.info("Finding container range...") if self.args.container_runtime == "local": return "local", "local" @@ -540,6 +542,8 @@ def gather_version_info(self, passing_url: str, failing_url: str): failing_url (str): The URL of the failing container. """ + self.logger.info("Gathering version information...") + self.logger.info(f"Using {self.bisection_url} for version-level bisection...") versions_from_env = self.args.build_scripts_path is not None # Get the versions from the endpoint containers (if they exist), overridden by any # explicitly passed versions. @@ -571,8 +575,6 @@ def gather_version_info(self, passing_url: str, failing_url: str): failing_versions, ) # Which packages have versions that are not always the same? - # TODO: DOUBLE CHECK THIS what if: - # pkg for pkg in passing_versions if passing_versions[pkg] != failing_versions[pkg] self.dynamic_packages = { pkg for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items()) @@ -612,6 +614,7 @@ def run_version_bisection( Returns: Tuple[dict, TestResult]: The final versions and the test result. """ + self.logger.info("Running version-level bisection...") # Prepare the container for the bisection with make_container( self.args.container_runtime, From c5f758ef1f59a6687124f176b289b2387983db27 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 13:39:09 +0100 Subject: [PATCH 11/50] fix tests --- .github/triage/jax_toolbox_triage/bisect.py | 9 +- .../triage/jax_toolbox_triage/triage_tool.py | 6 +- .../tests/test_triage_history_bisection.py | 88 ++----------------- 3 files changed, 15 insertions(+), 88 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 927ce169d..bbb09b3a4 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -42,7 +42,12 @@ def get_commit_history( workdir=dir, ) if commits_known.returncode != 0: - worker.check_exec(["git", "fetch"], policy="once_per_container", workdir=dir) + if worker.exec(["git", "remote"]).stout.strip(): + worker.check_exec( + ["git", "fetch"], policy="once_per_container", workdir=dir + ) + else: + logger.warning("No remote found, skipping fetch.") # here we're considering the case of non-linear history # limit for the moment to JAX and XLA @@ -56,7 +61,7 @@ def get_commit_history( failing_main_commit_cmd = f"git merge-base {end} origin/{main_branch}" passing_main_commit = worker.check_exec( - ["sh ", "-c", passing_main_commit_cmd], workdir=dir + ["sh", "-c", passing_main_commit_cmd], workdir=dir ).stdout.strip() failing_main_commit = worker.check_exec( ["sh", "-c", failing_main_commit_cmd], workdir=dir diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 4ad3b172c..620399ec3 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -171,11 +171,11 @@ def _log_environment_differences(self, url1: str, url2: str, env1: str, env2: st """ if env1 is None or env2 is None: return - self.logger.info(f"Environment differences between {url1} and {url2}") + self.logger.info(f"Environment differences between {url1} and {url2}:") for key in env1.keys() - env2.keys(): - self.logger.info(f"Only in {url1}: {key}={env1[key]}") + self.logger.info(f"\tOnly in {url1}: {key}={env1[key]}") for key in env2.keys() - env1.keys(): - self.logger.info(f"Only in {url2}: {key}={env2[key]}") + self.logger.info(f"\tOnly in {url2}: {key}={env2[key]}") for key in env1.keys() & env2.keys(): if env1[key] != env2[key]: self.logger.info( diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 49480e7af..0bb5f5767 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -5,11 +5,8 @@ import logging from collections import OrderedDict import pytest -import datetime -# for the moment avoid using this, because we can't import it -# then we'll refactor the main code -# from jax_toolbox_triage.main import get_commit_history +from jax_toolbox_triage.bisect import get_commit_history from jax_toolbox_triage.logic import version_search, TestResult from jax_toolbox_triage.container import Container @@ -84,81 +81,6 @@ def exists(self) -> bool: return True -def get_commit_history( - worker, package, start, end, dir, main_branch, feature_branch_name, args, logger -): - """ - This is a local copy of the get_commit_history logic from main.py, - For the moment we don't want to import it, we'll then refactor the main code - """ - # In particular the end commit might not already be known if the older, - # passing, container is being used for triage. - commits_known = worker.exec( - [ - "sh", - "-c", - f"git cat-file commit {start} && git cat-file commit {end}", - ], - policy="once_per_container", - workdir=dir, - ) - - if commits_known.returncode != 0: - logger.error("ERROR!") - logger.error(f"{commits_known.stderr}") - - if feature_branch_name and package in ["jax", "xla"]: - logger.info( - f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" - ) - passing_main_commit_cmd = f"git merge-base {start} {end}" - failing_main_commit_cmd = f"git merge-base {end} origin/{main_branch}" - - # In a local test, origin doesn't exist, so we use the local main branch ref. - failing_main_commit_cmd = f"git merge-base {end} {main_branch}" - - passing_main_commit = worker.check_exec( - ["sh", "-c", passing_main_commit_cmd], workdir=dir - ).stdout.strip() - failing_main_commit = worker.check_exec( - ["sh", "-c", failing_main_commit_cmd], workdir=dir - ).stdout.strip() - - cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" - cherry_pick_commits_list = ( - worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) - .stdout.strip() - .splitlines() - ) - if cherry_pick_commits_list: - args.cherry_pick_commits[package] = cherry_pick_commits_list - - start = passing_main_commit - end = failing_main_commit - - result = worker.check_exec( - [ - "git", - "log", - "--first-parent", - "--reverse", - "--format=%H %cI", - f"{start}^..{end}", - ], - policy="once", - stderr=subprocess.PIPE, - workdir=dir, - ) - data = [] - for line in result.stdout.splitlines(): - commit, date_str = line.split() - date = datetime.datetime.fromisoformat(date_str).astimezone( - datetime.timezone.utc - ) - data.append((commit, date)) - return data - - @pytest.fixture def triage_test_env(): """ @@ -240,17 +162,17 @@ def git_cmd(command, *args): } -# Do we need to parametrize the test cases? +# TEST CASES @pytest.mark.parametrize( "scenario, passing_commit_key, failing_commit_key, use_nonlinear_flags, expected_good_key, expected_bad_key", [ ( "Non-Linear History", # scenario "passing_nonlinear", - "failing_nonlinear", # bisection range + "failing_nonlinear", True, # use the new flag "good_main", - "bad_main", # expected results + "bad_main", ), ("Linear History", "good_main", "bad_main", False, "good_main", "bad_main"), ], @@ -285,7 +207,7 @@ class MockArgs: failing_versions = {"jax": all_commits[failing_commit_key]} package_dirs = {"jax": str(jax_repo_path)} mock_container = MockContainer(paths["scripts"], logger) - # call the get_commit_history + package_versions = OrderedDict() package_versions["jax"] = get_commit_history( worker=mock_container, From 3ed49f23d7d147bc083f33d0fa8f88e958267277 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 14:02:01 +0100 Subject: [PATCH 12/50] fix test with origin main --- .github/triage/jax_toolbox_triage/bisect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index bbb09b3a4..76094d451 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -58,7 +58,7 @@ def get_commit_history( # 1. find the linear range on the main branch passing_main_commit_cmd = f"git merge-base {start} {end}" - failing_main_commit_cmd = f"git merge-base {end} origin/{main_branch}" + failing_main_commit_cmd = f"git merge-base {end} {main_branch}" passing_main_commit = worker.check_exec( ["sh", "-c", passing_main_commit_cmd], workdir=dir From fbc6faafbe885ed73e82707b1085072a7002e7c3 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 14:33:22 +0100 Subject: [PATCH 13/50] fix test and code, so we can avoid having feature-branch --- .github/triage/jax_toolbox_triage/args.py | 8 +------ .github/triage/jax_toolbox_triage/bisect.py | 18 ++++++++------- .../triage/jax_toolbox_triage/triage_tool.py | 2 -- .../tests/test_triage_history_bisection.py | 22 ++++++++----------- 4 files changed, 20 insertions(+), 30 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index eba988361..e6afadf1f 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -229,13 +229,7 @@ def parse_args(args=None) -> argparse.Namespace: "--main-branch", type=str, default="main", - help="The name of the main branch, linear branch to be used for bisection", - ) - parser.add_argument( - "--feature-branch-name", - type=str, - default=None, - help="The name of the feature branch (e.g. blackwell-devel) to derive cherry-picks from", + help="The name of the main branch (e.g. main) to derive cherry-picks from", ) args = parser.parse_args(args=args) assert args.container_runtime in { diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 76094d451..d44524db0 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -9,7 +9,6 @@ def get_commit_history( end, dir, main_branch=None, - feature_branch_name=None, logger=None, args=None, ): @@ -23,7 +22,6 @@ def get_commit_history( end (str): The ending commit hash. dir (str): The directory where the git repository is located. main_branch (str, optional): The main branch name. Defaults to None. - feature_branch_name (str, optional): The feature branch name. Defaults to None. logger (Logger, optional): Logger for debug information. Defaults to None. args: Additional arguments that may contain cherry-pick commits. @@ -49,12 +47,16 @@ def get_commit_history( else: logger.warning("No remote found, skipping fetch.") - # here we're considering the case of non-linear history - # limit for the moment to JAX and XLA - if feature_branch_name and package in ["jax", "xla"]: - logger.info( - f"Using non-linear history logic with main branch {main_branch} and feature branch {feature_branch_name}" - ) + # detect non-linear history + is_ancestor_cmd = f"git merge-base --is-ancestor {start} {end}" + is_ancestor_result = worker.exec( + ["sh", "-c", is_ancestor_cmd], + workdir=dir, + ) + is_linear = is_ancestor_result.returncode == 0 + + if not is_linear and package in ["jax", "xla"]: + logger.info(f"Using non-linear history logic with main branch {main_branch}") # 1. find the linear range on the main branch passing_main_commit_cmd = f"git merge-base {start} {end}" diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 620399ec3..ae9753ebe 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -126,7 +126,6 @@ def _gather_histories( failing_versions[package], self.package_dirs[package], main_branch=self.args.main_branch, - feature_branch_name=self.args.feature_branch_name, logger=self.logger, args=self.args, ) @@ -271,7 +270,6 @@ def _gather_histories( failing_versions[package], self.package_dirs[package], main_branch=self.args.main_branch, - feature_branch_name=self.args.feature_branch_name, logger=self.logger, args=self.args, ) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 0bb5f5767..3d50300ab 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -121,6 +121,7 @@ def git_cmd(command, *args): # main git_cmd("init", "-b", "main") + git_cmd("remote", "add", "origin", str(jax_repo_path)) git_cmd("config", "user.name", "Test User") git_cmd("config", "user.email", "test@user.it") # Create a linear commit history @@ -164,17 +165,16 @@ def git_cmd(command, *args): # TEST CASES @pytest.mark.parametrize( - "scenario, passing_commit_key, failing_commit_key, use_nonlinear_flags, expected_good_key, expected_bad_key", + "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", [ ( - "Non-Linear History", # scenario + "Non-Linear History", "passing_nonlinear", "failing_nonlinear", - True, # use the new flag "good_main", "bad_main", ), - ("Linear History", "good_main", "bad_main", False, "good_main", "bad_main"), + ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), ], ) def test_triage_scenarios( @@ -182,7 +182,6 @@ def test_triage_scenarios( scenario, passing_commit_key, failing_commit_key, - use_nonlinear_flags, expected_good_key, expected_bad_key, ): @@ -193,7 +192,6 @@ def test_triage_scenarios( class MockArgs: main_branch = "main" - feature_branch_name = "feature" if use_nonlinear_flags else None bazel_cache = "" build_scripts_path = None test_command = ["test-case.sh", str(jax_repo_path), all_commits["bad_main"]] @@ -216,7 +214,6 @@ class MockArgs: end=failing_versions["jax"], dir=package_dirs["jax"], main_branch=args.main_branch, - feature_branch_name=args.feature_branch_name, args=args, logger=logger, ) @@ -227,17 +224,16 @@ def build_and_test_wrapper(*, versions, test_output_log_level=logging.DEBUG): mock_container.check_exec( ["git", "stash", "--include-untracked"], workdir=workdir ) + cherry_pick_commits = args.cherry_pick_commits.get("jax", []) - if use_nonlinear_flags: + if cherry_pick_commits: build_script = paths["scripts"] / "build-jax.sh" mock_container.check_exec( ["git", "checkout", versions["jax"]], workdir=workdir ) - cherry_picks = args.cherry_pick_commits.get("jax", []) - if cherry_picks: - mock_container.check_exec( - ["git", "cherry-pick"] + cherry_picks, workdir=workdir - ) + mock_container.check_exec( + ["git", "cherry-pick"] + cherry_pick_commits, workdir=workdir + ) else: build_script = paths["scripts"] / "build-jax-linear.sh" build_script.write_text("#!/bin/sh\nexit 0") From 5845bfe3974f11f236f31d33495c726d46336adb Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 15:06:01 +0100 Subject: [PATCH 14/50] fix initialization of package and create a more robust test --- .../triage/jax_toolbox_triage/triage_tool.py | 2 +- .../tests/test_triage_history_bisection.py | 2 +- .../triage/tests/test_triage_tool_class.py | 231 ++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 .github/triage/tests/test_triage_tool_class.py diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index ae9753ebe..8508d20a5 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -30,7 +30,7 @@ def __init__(self, args, logger): self.logger = logger self.bazel_cache_mounts = [] self.bisection_url = None - self.bisection_versions = None + self.bisection_versions = {} self.package_dirs = None self.dynamic_packages = set() self.packages_with_scripts = set() diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 3d50300ab..92544845d 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -185,7 +185,7 @@ def test_triage_scenarios( expected_good_key, expected_bad_key, ): - """Check if we nee dot restructure this + add types""" + """Test the get_commit_history for linear and non-linear histories.""" paths = triage_test_env["paths"] all_commits = triage_test_env["commits"] jax_repo_path = paths["repo"] / "jax" diff --git a/.github/triage/tests/test_triage_tool_class.py b/.github/triage/tests/test_triage_tool_class.py new file mode 100644 index 000000000..7e41524da --- /dev/null +++ b/.github/triage/tests/test_triage_tool_class.py @@ -0,0 +1,231 @@ +import subprocess +import tempfile +import pathlib +import os +import logging +import pytest + +from jax_toolbox_triage.triage_tool import TriageTool +from jax_toolbox_triage.logic import version_search +from jax_toolbox_triage.container import Container + + +def run_command(command, cwd=None, env=None): + """Simple function to run a command in a subprocess. + + Args: + command (list): The command to run as a list of strings. + cwd (str, optional): The working directory to run the command in. + env (dict, optional): Environment variables to set for the command. + Returns: + str: The standard output of the command. + """ + try: + result = subprocess.run( + command, cwd=cwd, env=env, check=True, capture_output=True, text=True + ) + return result.stdout.strip() + except subprocess.CalledProcessError as e: + logging.error(f"Command '{' '.join(command)}' failed with error: {e}") + raise e + + +class MockContainer(Container): + """A mock container class for testing purposes.""" + + def __init__(self, mock_scripts_path, logger): + super().__init__(logger=logger) + self.mock_scripts_path = mock_scripts_path + self._env = os.environ.copy() + self._env["PATH"] = f"{self.mock_scripts_path}:{self._env['PATH']}" + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + def __repr__(self): + return "MockContainer" + + def check_exec(self, cmd, **kwargs): + """Override the check_exec""" + return super().check_exec(cmd, **kwargs) + + def exec( + self, + command, + *, + policy="default", + stderr="interleaved", + workdir=None, + log_level=logging.DEBUG, + ): + self._logger.debug(f"Executing command: {command} in {workdir}") + is_shell_command = command[0] == "sh" and command[1] == "-c" + cmd_to_run = command[2] if is_shell_command else command + try: + return subprocess.run( + cmd_to_run, + capture_output=True, + text=True, + cwd=workdir, + env=self._env, + shell=is_shell_command, + ) + except FileNotFoundError as e: + return subprocess.CompletedProcess(command, 127, stderr=str(e)) + + def exists(self) -> bool: + return True + + +@pytest.fixture +def triage_test_env(): + """ + Fixture to set up the test environment for triage tests. + + The fixture creates a temp directory and a git repo with a + defined linear and non-linear history. + + The fixture yields a dictionary of paths and commit hashes + """ + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + repo_path = temp_path / "repos" + output_path = temp_path / "output" + mock_scripts_path = temp_path / "mock_scripts" + repo_path.mkdir() + output_path.mkdir() + mock_scripts_path.mkdir() + + source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" + build_script_content = (source_scripts_dir / "build-jax.sh").read_text() + (mock_scripts_path / "build-jax.sh").write_text(build_script_content) + os.chmod(mock_scripts_path / "build-jax.sh", 0o755) + + test_case_content = (source_scripts_dir / "test-case.sh").read_text() + (mock_scripts_path / "test-case.sh").write_text(test_case_content) + os.chmod(mock_scripts_path / "test-case.sh", 0o755) + + jax_repo_path = repo_path / "jax" + jax_repo_path.mkdir() + + def git_cmd(command, *args): + return run_command(["git", command, *args], cwd=jax_repo_path) + + git_cmd("init", "-b", "main") + git_cmd("remote", "add", "origin", str(jax_repo_path)) + git_cmd("config", "user.name", "Test User") + git_cmd("config", "user.email", "test@user.it") + + git_cmd("commit", "--allow-empty", "-m", "M1") + m1 = git_cmd("rev-parse", "HEAD") + + git_cmd("commit", "--allow-empty", "-m", "M2") + m2 = git_cmd("rev-parse", "HEAD") + + git_cmd("commit", "--allow-empty", "-m", "M3") + m3 = git_cmd("rev-parse", "HEAD") + + git_cmd("checkout", "-b", "feature", m1) + (jax_repo_path / "feature_file.txt").write_text("feature") + git_cmd("add", "feature_file.txt") + git_cmd("commit", "-m", "F1") + f1 = git_cmd("rev-parse", "HEAD") + + git_cmd("checkout", "-b", "passing_nonlinear", m2) + git_cmd("cherry-pick", f1) + passing_nonlinear = git_cmd("rev-parse", "HEAD") + + git_cmd("checkout", "-b", "failing_nonlinear", m3) + git_cmd("cherry-pick", f1) + failing_nonlinear = git_cmd("rev-parse", "HEAD") + + git_cmd("checkout", "main") + + yield { + "paths": { + "repo": repo_path, + "output": output_path, + "scripts": mock_scripts_path, + }, + "commits": { + "good_main": m2, + "bad_main": m3, + "feature": f1, + "passing_nonlinear": passing_nonlinear, + "failing_nonlinear": failing_nonlinear, + }, + } + + +@pytest.mark.parametrize( + "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", + [ + ( + "Non-Linear History", + "passing_nonlinear", + "failing_nonlinear", + "good_main", + "bad_main", + ), + ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), + ], +) +def test_triage_scenarios( + triage_test_env, + monkeypatch, + scenario, + passing_commit_key, + failing_commit_key, + expected_good_key, + expected_bad_key, +): + """Tests the TriageTool class.""" + paths = triage_test_env["paths"] + all_commits = triage_test_env["commits"] + jax_repo_path = paths["repo"] / "jax" + + class MockArgs: + main_branch = "main" + bazel_cache = "" + build_scripts_path = None + test_command = ["test-case.sh", str(jax_repo_path), all_commits["bad_main"]] + cherry_pick_commits = {} + output_prefix = paths["output"] + container_runtime = "mock" # Use a mock runtime + container_mount = [] + + args = MockArgs() + logger = logging.getLogger(f"Scenario-{scenario}") + logging.basicConfig(level=logging.INFO) + + tool = TriageTool(args, logger) + tool.package_dirs = {"jax": str(jax_repo_path)} + tool.dynamic_packages = {"jax"} + tool.bisection_url = "mock_url" + + # Set up a monkeypatch for the container creation + # in this way we're using MockContainer rather than make_container + mock_container = MockContainer(paths["scripts"], logger) + monkeypatch.setattr( + "jax_toolbox_triage.triage_tool.make_container", lambda *a, **kw: mock_container + ) + + passing_versions = {"jax": all_commits[passing_commit_key]} + failing_versions = {"jax": all_commits[failing_commit_key]} + + package_versions = tool._gather_histories( + mock_container, passing_versions, failing_versions + ) + # Bisection test + result, _, _ = version_search( + versions=package_versions, + build_and_test=tool._build_and_test, + logger=logger, + skip_precondition_checks=False, + ) + + assert result.get("jax_good") == all_commits[expected_good_key] + assert result.get("jax_bad") == all_commits[expected_bad_key] From c0538841c1e9c5f0176b53b00d78e7e4928eca0c Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 15:20:26 +0100 Subject: [PATCH 15/50] create a test for the entire triage tool class --- .github/triage/jax_toolbox_triage/triage_tool.py | 2 +- .github/triage/tests/test_triage_tool_class.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 8508d20a5..b993f2159 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -372,7 +372,7 @@ def _build_and_test( git_commands, changed, skipped = [], [], [] for package in sorted(self.dynamic_packages): version = versions[package] - if self.bisection_versions[package] == version: + if self.bisection_versions.get(package) == version: # If the existing version is the desired one, do nothing. skipped.append(f"{package}@{version}") continue diff --git a/.github/triage/tests/test_triage_tool_class.py b/.github/triage/tests/test_triage_tool_class.py index 7e41524da..3fd212baa 100644 --- a/.github/triage/tests/test_triage_tool_class.py +++ b/.github/triage/tests/test_triage_tool_class.py @@ -108,6 +108,11 @@ def triage_test_env(): (mock_scripts_path / "test-case.sh").write_text(test_case_content) os.chmod(mock_scripts_path / "test-case.sh", 0o755) + # Create a fake bazel executable + (mock_scripts_path / "bazel").write_text("#!/bin/sh\necho bazel") + os.chmod(mock_scripts_path / "bazel", 0o755) + + # setup the jax repo path jax_repo_path = repo_path / "jax" jax_repo_path.mkdir() @@ -213,6 +218,14 @@ class MockArgs: "jax_toolbox_triage.triage_tool.make_container", lambda *a, **kw: mock_container ) + # In case of linear-history scenario, we need a fake jax script too + if scenario == "Linear History": + linear_build_script_path = paths["scripts"] / "build-jax.sh" + linear_build_script_path.write_text( + "#!/bin/sh\necho 'Mock linear build successful.'\nexit 0" + ) + os.chmod(linear_build_script_path, 0o755) + passing_versions = {"jax": all_commits[passing_commit_key]} failing_versions = {"jax": all_commits[failing_commit_key]} From e6334de611f85769e505b8cc1ea72b8ccc3a92d9 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 15:20:47 +0100 Subject: [PATCH 16/50] remove previous test --- .../tests/test_triage_history_bisection.py | 263 ------------------ 1 file changed, 263 deletions(-) delete mode 100644 .github/triage/tests/test_triage_history_bisection.py diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py deleted file mode 100644 index 92544845d..000000000 --- a/.github/triage/tests/test_triage_history_bisection.py +++ /dev/null @@ -1,263 +0,0 @@ -import subprocess -import tempfile -import pathlib -import os -import logging -from collections import OrderedDict -import pytest - -from jax_toolbox_triage.bisect import get_commit_history -from jax_toolbox_triage.logic import version_search, TestResult -from jax_toolbox_triage.container import Container - - -def run_command(command, cwd=None, env=None): - """Simple function to run a command in a subprocess. - - Args: - command (list): The command to run as a list of strings. - cwd (str, optional): The working directory to run the command in. - env (dict, optional): Environment variables to set for the command. - Returns: - str: The standard output of the command. - """ - try: - result = subprocess.run( - command, cwd=cwd, env=env, check=True, capture_output=True, text=True - ) - return result.stdout.strip() - except subprocess.CalledProcessError as e: - logging.error(f"Command '{' '.join(command)}' failed with error: {e}") - raise e - - -class MockContainer(Container): - """A mock container class for testing purposes.""" - - def __init__(self, mock_scripts_path, logger): - super().__init__(logger=logger) - self.mock_scripts_path = mock_scripts_path - self._env = os.environ.copy() - self._env["PATH"] = f"{self.mock_scripts_path}:{self._env['PATH']}" - - def __enter__(self): - return self - - def __exit__(self, *exc_info): - pass - - def __repr__(self): - return "MockContainer" - - def check_exec(self, cmd, **kwargs): - """Override the check_exec""" - return super().check_exec(cmd, **kwargs) - - def exec( - self, - command, - *, - policy="default", - stderr="interleaved", - workdir=None, - log_level=logging.DEBUG, - ): - self._logger.debug(f"Executing command: {command} in {workdir}") - is_shell_command = command[0] == "sh" and command[1] == "-c" - cmd_to_run = command[2] if is_shell_command else command - try: - return subprocess.run( - cmd_to_run, - capture_output=True, - text=True, - cwd=workdir, - env=self._env, - shell=is_shell_command, - ) - except FileNotFoundError as e: - return subprocess.CompletedProcess(command, 127, stderr=str(e)) - - def exists(self) -> bool: - return True - - -@pytest.fixture -def triage_test_env(): - """ - Fixture to set up the test environment for triage tests. - - The fixture creates a temp directory and a git repo with a - defined linear and non-linear history. - - The fixture yields a dictionary of paths and commit hashes - """ - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = pathlib.Path(temp_dir) - repo_path = temp_path / "repos" - output_path = temp_path / "output" - mock_scripts_path = temp_path / "mock_scripts" - repo_path.mkdir() - output_path.mkdir() - mock_scripts_path.mkdir() - - # Generation of mock scripts - # build-jax.sh - source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" - build_script_content = (source_scripts_dir / "build-jax.sh").read_text() - (mock_scripts_path / "build-jax.sh").write_text(build_script_content) - os.chmod(mock_scripts_path / "build-jax.sh", 0o755) - # test-case.sh helper test script - test_case_content = (source_scripts_dir / "test-case.sh").read_text() - (mock_scripts_path / "test-case.sh").write_text(test_case_content) - os.chmod(mock_scripts_path / "test-case.sh", 0o755) - - # Create a git repository - jax_repo_path = repo_path / "jax" - jax_repo_path.mkdir() - - def git_cmd(command, *args): - return run_command(["git", command, *args], cwd=jax_repo_path) - - # main - git_cmd("init", "-b", "main") - git_cmd("remote", "add", "origin", str(jax_repo_path)) - git_cmd("config", "user.name", "Test User") - git_cmd("config", "user.email", "test@user.it") - # Create a linear commit history - git_cmd("commit", "--allow-empty", "-m", "M1") - m1 = git_cmd("rev-parse", "HEAD") - git_cmd("commit", "--allow-empty", "-m", "M2") # good commit - m2 = git_cmd("rev-parse", "HEAD") - git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit - m3 = git_cmd("rev-parse", "HEAD") - # create a feature branch - git_cmd("checkout", "-b", "feature", m1) - (jax_repo_path / "feature_file.txt").write_text("feature") - git_cmd("add", "feature_file.txt") - git_cmd("commit", "-m", "F1") - f1 = git_cmd("rev-parse", "HEAD") - - git_cmd("checkout", "-b", "passing_nonlinear", m2) - git_cmd("cherry-pick", f1) - passing_nonlinear = git_cmd("rev-parse", "HEAD") - git_cmd("checkout", "-b", "failing_nonlinear", m3) - git_cmd("cherry-pick", f1) - failing_nonlinear = git_cmd("rev-parse", "HEAD") - git_cmd("checkout", "main") - - # yield all the info - yield { - "paths": { - "repo": repo_path, - "output": output_path, - "scripts": mock_scripts_path, - }, - "commits": { - "good_main": m2, - "bad_main": m3, - "feature": f1, - "passing_nonlinear": passing_nonlinear, - "failing_nonlinear": failing_nonlinear, - }, - } - - -# TEST CASES -@pytest.mark.parametrize( - "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", - [ - ( - "Non-Linear History", - "passing_nonlinear", - "failing_nonlinear", - "good_main", - "bad_main", - ), - ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), - ], -) -def test_triage_scenarios( - triage_test_env, - scenario, - passing_commit_key, - failing_commit_key, - expected_good_key, - expected_bad_key, -): - """Test the get_commit_history for linear and non-linear histories.""" - paths = triage_test_env["paths"] - all_commits = triage_test_env["commits"] - jax_repo_path = paths["repo"] / "jax" - - class MockArgs: - main_branch = "main" - bazel_cache = "" - build_scripts_path = None - test_command = ["test-case.sh", str(jax_repo_path), all_commits["bad_main"]] - cherry_pick_commits = {} - - args = MockArgs() - logger = logging.getLogger(f"Scenario-{scenario}") - logging.basicConfig(level=logging.INFO) - - passing_versions = {"jax": all_commits[passing_commit_key]} - failing_versions = {"jax": all_commits[failing_commit_key]} - package_dirs = {"jax": str(jax_repo_path)} - mock_container = MockContainer(paths["scripts"], logger) - - package_versions = OrderedDict() - package_versions["jax"] = get_commit_history( - worker=mock_container, - package="jax", - start=passing_versions["jax"], - end=failing_versions["jax"], - dir=package_dirs["jax"], - main_branch=args.main_branch, - args=args, - logger=logger, - ) - - # build and test - def build_and_test_wrapper(*, versions, test_output_log_level=logging.DEBUG): - workdir = package_dirs["jax"] - mock_container.check_exec( - ["git", "stash", "--include-untracked"], workdir=workdir - ) - cherry_pick_commits = args.cherry_pick_commits.get("jax", []) - - if cherry_pick_commits: - build_script = paths["scripts"] / "build-jax.sh" - mock_container.check_exec( - ["git", "checkout", versions["jax"]], workdir=workdir - ) - mock_container.check_exec( - ["git", "cherry-pick"] + cherry_pick_commits, workdir=workdir - ) - else: - build_script = paths["scripts"] / "build-jax-linear.sh" - build_script.write_text("#!/bin/sh\nexit 0") - os.chmod(build_script, 0o755) - mock_container.check_exec( - ["git", "checkout", versions["jax"]], workdir=workdir - ) - - mock_container.check_exec([str(build_script)], workdir=workdir) - result = mock_container.exec(args.test_command, workdir=workdir) - return TestResult( - host_output_directory=paths["output"], - result=result.returncode == 0, - stdouterr=" ", - ) - - # bisection - result, _, _ = version_search( - versions=package_versions, - build_and_test=build_and_test_wrapper, - logger=logger, - skip_precondition_checks=False, - ) - - # test - assert result.get("jax_good") == all_commits[expected_good_key] - assert result.get("jax_bad") == all_commits[expected_bad_key] From 883062f94ab8db7d74ed7f3b82e39d2ba0ec521d Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 16:29:19 +0100 Subject: [PATCH 17/50] Ruff and MyPy fixes --- .github/triage/jax_toolbox_triage/args.py | 10 +- .../triage/jax_toolbox_triage/triage_tool.py | 112 +++++------------- 2 files changed, 34 insertions(+), 88 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index e6afadf1f..2df59b38e 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -258,7 +258,9 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), "For local runtime, --passing-versions and --failing-versions must be provided." + ), ( + "For local runtime, --passing-versions and --failing-versions must be provided." + ) assert ( args.container is None and args.start_date is None @@ -303,7 +305,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert ( - args.container is not None - ), "--container must be passed for the container-level search" + assert args.container is not None, ( + "--container must be passed for the container-level search" + ) return args diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index b993f2159..bfc5f0b29 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -2,11 +2,10 @@ import datetime import functools import hashlib -import itertools import logging import pathlib import time -import typing +from typing import Dict, Tuple, Union from .container import Container from .logic import container_search, TestResult, version_search @@ -38,7 +37,7 @@ def __init__(self, args, logger): self.args.cherry_pick_commits = {} def _test_output_directory( - self, url: str, versions: typing.Dict[str, str] = None + self, url: str, versions: Union[Dict[str, str], None] ) -> pathlib.Path: """ Create a directory for test output based on the container URL and versions. @@ -52,18 +51,22 @@ def _test_output_directory( hash_chars = 8 urlhash = f"container-{hashlib.sha1(url.encode()).hexdigest()[:hash_chars]}" out_dirname = "-".join( - itertools.chain( - [urlhash], - map(lambda t: f"{t[0]}-{t[1][:hash_chars]}", sorted(versions.items())), - ) + [urlhash] + + [f"{k}-{v[:hash_chars]}" for k, v in sorted((versions or {}).items())] ) + out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + assert not out_dir.exists(), ( + f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + ) out_dir.mkdir(mode=0o755) return out_dir.resolve() def _get_versions( - self, container_url: str, explicit_versions: str, versions_from_env: str + self, + container_url: str, + explicit_versions: Dict[str, str], + versions_from_env: bool, ): """ Get the versions of the software packages in the container. @@ -80,9 +83,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert ( - container_url is not None - ), "Container URL must be provided if explicit versions are not set." + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) with make_container( self.args.container_runtime, @@ -100,8 +103,8 @@ def _get_versions( def _gather_histories( self, worker: Container, - passing_versions: typing.Dict[str, str], - failing_versions: typing.Dict[str, str], + passing_versions: Dict[str, str], + failing_versions: Dict[str, str], ) -> collections.OrderedDict: """ Gather the commit histories for the passing and failing versions. @@ -153,7 +156,9 @@ def _gather_histories( return package_versions - def _log_environment_differences(self, url1: str, url2: str, env1: str, env2: str): + def _log_environment_differences( + self, url1: str, url2: str, env1: Dict[str, str], env2: Dict[str, str] + ): """ If we have two containers, print the differences between their environments. This can be useful in the case that rebuilding the good versions in the bad container, @@ -201,7 +206,7 @@ def _check_container_by_date( container_url = container_url_func(date) before = time.monotonic() - out_dir = self._test_output_directory(container_url) + out_dir = self._test_output_directory(container_url, None) # this is from the previous Container class implementation in main mounts = self.args.container_mount + [(out_dir, "/triage-tool-output")] @@ -225,7 +230,7 @@ def _check_container_by_date( self.args.output_prefix, "container", { - "container": container_url(date), + "container": container_url, "output_directory": out_dir.as_posix(), "result": test_pass, "test_time": test_time, @@ -236,68 +241,6 @@ def _check_container_by_date( host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout ) - def _gather_histories( - self, - worker: Container, - passing_versions: typing.Dict[str, str], - failing_versions: typing.Dict[str, str], - ) -> typing.Tuple[typing.List[str], typing.List[str]]: - """ - Gather the commit histories for the passing and failing versions. - This function is pivotal for non-linear history logic search - - Args: - worker (Container): The container in which to run the commands. - passing_versions (dict): The versions that passed. - failing_versions (dict): The versions that failed. - - Returns: - Tuple[List[str], List[str]]: The commit histories for passing and failing versions. - """ - packages = passing_versions.keys() - self.logger.info( - f"Bisecting {' '.join(f'{p} [{passing_versions[p]}, {failing_versions[p]}]' for p in packages)} using {worker}" - ) - package_versions = collections.OrderedDict() - - for package in packages: - if package not in self.package_dirs: - continue - package_versions[package] = get_commit_history( - worker, - package, - passing_versions[package], - failing_versions[package], - self.package_dirs[package], - main_branch=self.args.main_branch, - logger=self.logger, - args=self.args, - ) - - if not self.args.cherry_pick_commits.get(package): - # Confirm they're sorted by commit date - assert all( - b[1] >= a[1] - for a, b in zip( - package_versions[package], package_versions[package][1:] - ) - ) - # Confirm the end values are included as expected - assert passing_versions[package] == package_versions[package][0][0] - assert failing_versions[package] == package_versions[package][-1][0] - for package in packages: - if package in package_versions: - continue - package_versions[package] = [ - (passing_versions[package], package_versions["xla"][0][1]) - ] - if passing_versions[package] != failing_versions[package]: - package_versions[package].append( - (failing_versions[package], package_versions["xla"][-1][1]) - ) - - return package_versions - def _check_installation_scripts(self, worker: Container): """ Check for special installation cases, like cuBLAS or cuDNN @@ -353,7 +296,7 @@ def _check_installation_scripts(self, worker: Container): def _build_and_test( self, *, - versions: typing.Dict[str, str], + versions: Dict[str, str], test_output_log_level: int = logging.DEBUG, ) -> TestResult: """ @@ -496,7 +439,7 @@ def prepare(self): """ self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) - def find_container_range(self) -> typing.Tuple[str, str]: + def find_container_range(self) -> Tuple[str, str]: """ Find the range from the passing and failing containers. Returns a tuple of the start and end container names. @@ -599,9 +542,9 @@ def gather_version_info(self, passing_url: str, failing_url: str): def run_version_bisection( self, - passing_versions: typing.Dict[str, str], - failing_versions: typing.Dict[str, str], - ) -> typing.Tuple[typing.Dict[str, str], TestResult]: + passing_versions: Dict[str, str], + failing_versions: Dict[str, str], + ) -> None: """ Run the version bisection process. @@ -638,3 +581,4 @@ def run_version_bisection( ) result["container"] = self.bisection_url add_summary_record(self.args.output_prefix, "result", result, scalar=True) + self.logger.info("Version-level bisection completed") From f1a02d838fa6d1ed7e476d279c8fbc5d6ce29e9b Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 16:55:19 +0100 Subject: [PATCH 18/50] try to fix for python 3.8 --- .github/triage/jax_toolbox_triage/bisect.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index d44524db0..d5beefbf2 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -101,6 +101,9 @@ def get_commit_history( data = [] for line in result.stdout.splitlines(): commit, date = line.split() + # for python < 3.11 we nee dto fix: + if date.endswith("Z"): + date = date[:-1] + "+00:00" date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc) data.append((commit, date)) return data From eb15191d9063a77cb1375715a2eaf4938539e63e Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 17:03:27 +0100 Subject: [PATCH 19/50] try another way to pack dictionaries for 3.8 --- .../triage/jax_toolbox_triage/triage_tool.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index bfc5f0b29..eafacff1f 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -56,9 +56,7 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -83,9 +81,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." with make_container( self.args.container_runtime, @@ -230,12 +228,14 @@ def _check_container_by_date( self.args.output_prefix, "container", { - "container": container_url, - "output_directory": out_dir.as_posix(), - "result": test_pass, - "test_time": test_time, - } - | versions, + **{ + "container": container_url, + "output_directory": out_dir.as_posix(), + "result": test_pass, + "test_time": test_time, + }, + **versions, + }, ) return TestResult( host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout @@ -416,13 +416,15 @@ def _build_and_test( self.args.output_prefix, "versions", { - "build_time": middle - before, - "container": self.bisection_url, - "output_directory": out_dir.as_posix(), - "result": test_result.returncode == 0, - "test_time": test_time, - } - | versions, + **{ + "build_time": middle - before, + "container": self.bisection_url, + "output_directory": out_dir.as_posix(), + "result": test_result.returncode == 0, + "test_time": test_time, + }, + **versions, + }, ) result_str = "pass" if test_result.returncode == 0 else "fail" self.logger.info(f"Test completed in {test_time:.1f}s ({result_str})") From 057f2d12b39a059682d29dffcef4c695cce47d8b Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 8 Jul 2025 17:06:49 +0100 Subject: [PATCH 20/50] ruff fix --- .github/triage/jax_toolbox_triage/triage_tool.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index eafacff1f..554fb0679 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -56,7 +56,9 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + assert not out_dir.exists(), ( + f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + ) out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -81,9 +83,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert ( - container_url is not None - ), "Container URL must be provided if explicit versions are not set." + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) with make_container( self.args.container_runtime, From 25713635fdf8a13272571faeaa9a3c089b6c763d Mon Sep 17 00:00:00 2001 From: Steboss Date: Wed, 9 Jul 2025 10:50:53 +0100 Subject: [PATCH 21/50] add a scheme with all the process --- docs/triage-tool.md | 84 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/docs/triage-tool.md b/docs/triage-tool.md index 349e8b7c8..ba8b9c6da 100644 --- a/docs/triage-tool.md +++ b/docs/triage-tool.md @@ -69,6 +69,88 @@ passed to `salloc` or set via `SLURM_` environment variables so that a bare `sru correctly launch the test case. If `--container-runtime=local` is used, the tool assumes it is already inside a JAX container and will execute all build and test commands directly. +## Triage tool logic scheme + ++--------------------------------+ +| main.py | +| main() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool = TriageTool(args) | +| tool.prepare() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.find_container_range() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.gather_version_info() | +| # extract commit hashes, | +| # compare environments | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.run_version_bisection() | +| # main part of the logic | +| # it runs all the steps below | ++--------------------------------+ + | + v ++------------------------------------------------+ +| tool._gather_histories() | +| (Calls bisect.py -> get_commit_history) | +| # get the list of all the commits between p&f | ++------------------------------------------------+ + | + | <--- Inside get_commit_history() + | ++----------------------------------------------------------------------------+ +| Is history linear? (git merge-base --is-ancestor) | ++--------------------------+-------------------------------------------------+ +| YES | NO | +| (Linear History Path) | (Non-Linear History Path) | ++--------------------------+-------------------------------------------------+ + | | + v v ++--------------------------+ +-------------------------------------------------+ +| get_commit_history() | | get_commit_history() | +| returns commits from | | - Uses 'git merge-base' to find linear range | +| main branch directly. | | - Finds cherry-picks to apply to args | +| | | - Returns commits from the *main* branch | ++--------------------------+ +-------------------------------------------------+ + | | + | | + +--------------------------------+ + | + v ++------------------------------------------------+ +| logic.py -> version_search() loop | +| (Repeatedly calls tool._build_and_test) | ++------------------------------------------------+ + | + | <--- Inside _build_and_test() + | ++----------------------------------------------------------------------------+ +| Does args.cherry_pick_commits exist? --> | ++--------------------------+-------------------------------------------------+ +| NO | YES | +| (Linear History Path) | (Non-Linear History Path) | ++--------------------------+-------------------------------------------------+ + | | + v v ++--------------------------+ +-------------------------------------------------+ +| _build_and_test() | | _build_and_test() | +| - git checkout | | - git checkout | +| - build & test | | - git cherry-pick | +| | | - build & test | ++--------------------------+ +-------------------------------------------------+ + ## Usage To use the tool, there are two compulsory inputs: @@ -161,7 +243,7 @@ paths and `http`/`https`/`grpc` URLs. If `--skip-precondition-checks` is passed, a sanity check that the failure can be reproduced after rebuilding the JAX/XLA commits from the first-known-bad container -inside that container will be skipped. +inside that container will be skipped. ## Example From 457595b7122204a4d9574e14409d361620c4c62e Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 10 Jul 2025 10:34:42 +0100 Subject: [PATCH 22/50] Update .github/triage/jax_toolbox_triage/bisect.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/bisect.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index d5beefbf2..1a495c998 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -48,9 +48,8 @@ def get_commit_history( logger.warning("No remote found, skipping fetch.") # detect non-linear history - is_ancestor_cmd = f"git merge-base --is-ancestor {start} {end}" is_ancestor_result = worker.exec( - ["sh", "-c", is_ancestor_cmd], + ["git", "merge-base", "--is-ancestor", start, end], workdir=dir, ) is_linear = is_ancestor_result.returncode == 0 From 0e59bbff58292a03728d726782ddd810701b5b10 Mon Sep 17 00:00:00 2001 From: Steboss Date: Thu, 10 Jul 2025 12:18:50 +0100 Subject: [PATCH 23/50] start addressing comments --- .github/triage/jax_toolbox_triage/bisect.py | 28 ++++++---- .github/triage/jax_toolbox_triage/main.py | 1 - .../triage/jax_toolbox_triage/triage_tool.py | 55 ++++++------------- 3 files changed, 34 insertions(+), 50 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 1a495c998..961692e85 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -8,7 +8,7 @@ def get_commit_history( start, end, dir, - main_branch=None, + main_branch, logger=None, args=None, ): @@ -21,7 +21,7 @@ def get_commit_history( start (str): The starting commit hash. end (str): The ending commit hash. dir (str): The directory where the git repository is located. - main_branch (str, optional): The main branch name. Defaults to None. + main_branch (str): The main branch name. Defaults is the default branch of the repo. logger (Logger, optional): Logger for debug information. Defaults to None. args: Additional arguments that may contain cherry-pick commits. @@ -54,19 +54,23 @@ def get_commit_history( ) is_linear = is_ancestor_result.returncode == 0 - if not is_linear and package in ["jax", "xla"]: + if not is_linear: logger.info(f"Using non-linear history logic with main branch {main_branch}") # 1. find the linear range on the main branch - passing_main_commit_cmd = f"git merge-base {start} {end}" - failing_main_commit_cmd = f"git merge-base {end} {main_branch}" - - passing_main_commit = worker.check_exec( - ["sh", "-c", passing_main_commit_cmd], workdir=dir - ).stdout.strip() - failing_main_commit = worker.check_exec( - ["sh", "-c", failing_main_commit_cmd], workdir=dir - ).stdout.strip() + passing_and_failing_cmd = ( + worker.check_exec( + [ + "sh", + "-c", + f"git merge-base {start} {end} && git merge-base {end} {main_branch}", + ], + workdir=dir, + ) + .stdout() + .strip() + ) + passing_main_commit, failing_main_commit = passing_and_failing_cmd.splitlines() # 2. find commits to cherry-pick from the failing branch cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 222869d6a..fff2e99d0 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -12,7 +12,6 @@ def main() -> None: try: tool = TriageTool(args, logger) - tool.prepare() passing_url, failing_url = tool.find_container_range() diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 554fb0679..c6bf7c101 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -27,12 +27,12 @@ class TriageTool: def __init__(self, args, logger): self.args = args self.logger = logger - self.bazel_cache_mounts = [] self.bisection_url = None self.bisection_versions = {} self.package_dirs = None self.dynamic_packages = set() self.packages_with_scripts = set() + self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) # the cherry-pick gets populated only for non-linear cases self.args.cherry_pick_commits = {} @@ -56,9 +56,7 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -83,9 +81,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." with make_container( self.args.container_runtime, @@ -133,15 +131,14 @@ def _gather_histories( args=self.args, ) - if not self.args.cherry_pick_commits.get(package): - assert all( - b[1] >= a[1] - for a, b in zip( - package_versions[package], package_versions[package][1:] - ) + assert all( + b[1] >= a[1] + for a, b in zip( + package_versions[package], package_versions[package][1:] ) - assert passing_versions[package] == package_versions[package][0][0] - assert failing_versions[package] == package_versions[package][-1][0] + ) + assert passing_versions[package] == package_versions[package][0][0] + assert failing_versions[package] == package_versions[package][-1][0] for package in packages: if package in package_versions: @@ -325,26 +322,17 @@ def _build_and_test( self.bisection_versions[package] = version changed.append(f"{package}@{version}") if package in self.package_dirs: - # in case of non-linear history - should we limit this to XLA and JAX only? package_cherry_picks = self.args.cherry_pick_commits.get(package, []) + git_commands.append(f"cd {self.package_dirs[package]}") + git_commands.append("git stash") + # this is a checkout on the main branch + git_commands.append(f"git checkout {version}") if package_cherry_picks: self.logger.info("Working on a non-linear history") - git_commands.append(f"cd {self.package_dirs[package]}") - git_commands.append("git stash") - # this is a checkout on the main branch - git_commands.append(f"git checkout {version}") cherry_pick_str = " ".join(package_cherry_picks) git_commands.append( - f"git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1)" + f"(git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1) )" ) - else: - # Linear history - # A git repository that exists in the container. - git_commands += [ - f"cd {self.package_dirs[package]}", - "git stash", - f"git checkout {version}", - ] else: # Another software package, `version` is probably a version number. @@ -436,13 +424,6 @@ def _build_and_test( stdouterr=test_result.stdout, ) - def prepare(self): - """ - Function to prepare the triage tool for execution. - At the moment, we're adding the bazel cache mounts to the tool. - """ - self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) - def find_container_range(self) -> Tuple[str, str]: """ Find the range from the passing and failing containers. @@ -564,7 +545,7 @@ def run_version_bisection( with make_container( self.args.container_runtime, self.bisection_url, - self.bazel_cache_mounts, + self.args.container_mount, self.logger, ) as worker: package_versions = self._gather_histories( From c86388a5b0cffa2170bb6e3c43f8ede9758be5f7 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Jul 2025 11:51:28 +0100 Subject: [PATCH 24/50] fix comments --- .github/triage/jax_toolbox_triage/args.py | 34 ++++++++++++++--- .github/triage/jax_toolbox_triage/bisect.py | 37 +++++++++++-------- .../triage/jax_toolbox_triage/triage_tool.py | 9 +++-- 3 files changed, 56 insertions(+), 24 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 2df59b38e..d52339597 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -26,6 +26,23 @@ def parse_version_argument(s: str) -> typing.Dict[str, str]: return ret +def parse_override_remotes(s: str) -> typing.Dict[str, str]: + """Function to parse the override remote + + Inputs: + s: (str) e.g. https://@host/repo.git + + Returns: + ret: (typing.Dict[str,str]) Dictionary with software as key and git-url as value. + """ + ret: typing.Dict[str, str] = {} + for part in s.split(","): + sw, url = part.split(":", 1) + assert sw not in ret, ret + ret[sw] = url + return ret + + def parse_args(args=None) -> argparse.Namespace: parser = argparse.ArgumentParser( description=""" @@ -208,6 +225,13 @@ def parse_args(args=None) -> argparse.Namespace: in question has different versions at the endpoints of the bisection range. """, ) + version_search_args.add_argument( + "--override-remotes", + type=parse_override_remotes, + help="""Remote URLs to be used for fetching, including auth token. E.g.: + jax:https://@host/repo.git,xla:https://@host/repo.git + """, + ) parser.add_argument( "-v", "--container-mount", @@ -258,9 +282,7 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), ( - "For local runtime, --passing-versions and --failing-versions must be provided." - ) + ), "For local runtime, --passing-versions and --failing-versions must be provided." assert ( args.container is None and args.start_date is None @@ -305,7 +327,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert args.container is not None, ( - "--container must be passed for the container-level search" - ) + assert ( + args.container is not None + ), "--container must be passed for the container-level search" return args diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 961692e85..2bb2aef9d 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -40,12 +40,25 @@ def get_commit_history( workdir=dir, ) if commits_known.returncode != 0: - if worker.exec(["git", "remote"]).stout.strip(): - worker.check_exec( - ["git", "fetch"], policy="once_per_container", workdir=dir - ) + if args.override_remotes and package in args.override_remotes: + remote_url = args.override_remotes[package] + fetch_cmd = ["git", "fetch", remote_url, start, end] + worker.check_exec(fetch_cmd, workdir=dir) else: - logger.warning("No remote found, skipping fetch.") + # default behaviour + # re: https://stackoverflow.com/questions/4089430/how-to-determine-the-url-that-a-local-git-repository-was-originally-cloned-from + remote_url_result = worker.check_exec( + ["git", "config", "--get", "remote.origin.url"], workdir=dir + ) + if remote_url_result.returncode == 0: + remote_url = remote_url_result.stdout().strip() + fetch_cmd = ["git", "fetch", remote_url, start, end] + worker.check_exec(fetch_cmd, workdir=dir) + else: + logger.error( + "No remote 'origin' found and no override provided. Cannot fetch missing commits" + ) + raise Exception("Cannot find commits and no remote is configured") # detect non-linear history is_ancestor_result = worker.exec( @@ -53,6 +66,7 @@ def get_commit_history( workdir=dir, ) is_linear = is_ancestor_result.returncode == 0 + cherry_pick_range = {} if not is_linear: logger.info(f"Using non-linear history logic with main branch {main_branch}") @@ -73,15 +87,7 @@ def get_commit_history( passing_main_commit, failing_main_commit = passing_and_failing_cmd.splitlines() # 2. find commits to cherry-pick from the failing branch - cherry_pick_cmd = f"git rev-list --reverse {failing_main_commit}..{end}" - cherry_pick_commits_list = ( - worker.check_exec(["sh", "-c", cherry_pick_cmd], workdir=dir) - .stdout.strip() - .splitlines() - ) - if cherry_pick_commits_list: - args.cherry_pick_commits[package] = cherry_pick_commits_list - logger.info(f"Cherry-pick commits: {cherry_pick_commits_list}") + cherry_pick_range[package] = f"{failing_main_commit}..{end}" # 3. now we can use the main branch commits for bisection start = passing_main_commit @@ -109,4 +115,5 @@ def get_commit_history( date = date[:-1] + "+00:00" date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc) data.append((commit, date)) - return data + + return data, cherry_pick_range diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index c6bf7c101..82df4a0d6 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -34,7 +34,7 @@ def __init__(self, args, logger): self.packages_with_scripts = set() self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache) # the cherry-pick gets populated only for non-linear cases - self.args.cherry_pick_commits = {} + self.cherry_pick_commits = {} def _test_output_directory( self, url: str, versions: Union[Dict[str, str], None] @@ -120,7 +120,7 @@ def _gather_histories( for package in packages: if package not in self.package_dirs: continue - package_versions[package] = get_commit_history( + history, cherry_pick_range = get_commit_history( worker, package, passing_versions[package], @@ -130,6 +130,9 @@ def _gather_histories( logger=self.logger, args=self.args, ) + package_versions[package] = history + if cherry_pick_range: + self.cherry_pick_commits[package] = cherry_pick_range assert all( b[1] >= a[1] @@ -322,7 +325,7 @@ def _build_and_test( self.bisection_versions[package] = version changed.append(f"{package}@{version}") if package in self.package_dirs: - package_cherry_picks = self.args.cherry_pick_commits.get(package, []) + package_cherry_picks = self.cherry_pick_commits.get(package, []) git_commands.append(f"cd {self.package_dirs[package]}") git_commands.append("git stash") # this is a checkout on the main branch From 27326d0f28878c05111d27c39906fafc16d80bd7 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Jul 2025 12:36:49 +0100 Subject: [PATCH 25/50] start fixing the tests too --- .github/triage/jax_toolbox_triage/bisect.py | 3 +- .../tests/test_triage_history_bisection.py | 174 ++++++++++++++++++ 2 files changed, 176 insertions(+), 1 deletion(-) create mode 100644 .github/triage/tests/test_triage_history_bisection.py diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 2bb2aef9d..b36f1227a 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -26,7 +26,8 @@ def get_commit_history( args: Additional arguments that may contain cherry-pick commits. Returns: - list: A list of tuples containing commit hashes and their corresponding dates. + data: list, list of all the commits + cherry_pick_range: str, range of cherry pick commits if any """ # In particular the end commit might not already be known if the older, # passing, container is being used for triage. diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py new file mode 100644 index 000000000..2374645e1 --- /dev/null +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -0,0 +1,174 @@ +import subprocess +import tempfile +import pathlib +import os +import logging +import json +import pytest + +from jax_toolbox_triage.args import parse_args +from jax_toolbox_triage.triage_tool import TriageTool + + +def run_command(command, cwd=None, env=None): + """Simple function to run a command in a subprocess. + + Args: + command (list): The command to run as a list of strings. + cwd (str, optional): The working directory to run the command in. + env (dict, optional): Environment variables to set for the command. + Returns: + str: The standard output of the command. + """ + try: + result = subprocess.run( + command, cwd=cwd, env=env, check=True, capture_output=True, text=True + ) + return result.stdout.strip() + except subprocess.CalledProcessError as e: + logging.error(f"Command '{' '.join(command)}' failed with error: {e}") + raise e + + +@pytest.fixture +def triage_test_env(): + """ + Fixture to set up the test environment for triage tests. + + The fixture creates a temp directory and a git repo with a + defined linear and non-linear history. + + The fixture yields a dictionary of paths and commit hashes + """ + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) + repo_path = temp_path / "repos" + output_path = temp_path / "output" + mock_scripts_path = temp_path / "mock_scripts" + repo_path.mkdir() + output_path.mkdir() + mock_scripts_path.mkdir() + source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" + # fake build-jax + build_script_content = (source_scripts_dir / "build-jax.sh").read_text() + (mock_scripts_path / "build-jax.sh").write_text(build_script_content) + os.chmod(mock_scripts_path / "build-jax.sh", 0o755) + # test-case.sh helper test script + test_case_content = (source_scripts_dir / "test-case.sh").read_text() + (mock_scripts_path / "test-case.sh").write_text(test_case_content) + os.chmod(mock_scripts_path / "test-case.sh", 0o755) + + # Create a git repository + jax_repo_path = repo_path / "jax" + jax_repo_path.mkdir() + + def git_cmd(command, *args): + return run_command(["git", command, *args], cwd=jax_repo_path) + + # main + # why don't we push the scripts and use them in repo + git_cmd("init", "-b", "main") + git_cmd("config", "user.name", "Test User") + git_cmd("config", "user.email", "test@user.it") + # Create a linear commit history + git_cmd("commit", "--allow-empty", "-m", "M1") + m1 = git_cmd("rev-parse", "HEAD") + git_cmd("commit", "--allow-empty", "-m", "M2") # good commit + m2 = git_cmd("rev-parse", "HEAD") + git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit + m3 = git_cmd("rev-parse", "HEAD") + # create a feature branch + git_cmd("checkout", "-b", "feature", m1) + (jax_repo_path / "feature_file.txt").write_text("feature") + git_cmd("add", "feature_file.txt") + git_cmd("commit", "-m", "F1") + f1 = git_cmd("rev-parse", "HEAD") + # here we're applying a feature to the good f1 commit + git_cmd("checkout", "-b", "passing_nonlinear", m2) + git_cmd("cherry-pick", f1) + passing_nonlinear = git_cmd("rev-parse", "HEAD") + # and then we apply the feature to the bad commit + # this simulated the rebase scenario + git_cmd("checkout", "-b", "failing_nonlinear", m3) + git_cmd("cherry-pick", f1) + failing_nonlinear = git_cmd("rev-parse", "HEAD") + git_cmd("checkout", "main") + + # yield all the info + yield { + "paths": { + "repo": repo_path, + "output": output_path, + "scripts": mock_scripts_path, + }, + "commits": { + "good_main": m2, + "bad_main": m3, + "feature": f1, + "passing_nonlinear": passing_nonlinear, + "failing_nonlinear": failing_nonlinear, + }, + } + + +@pytest.mark.parametrize( + "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", + [ + ( + "Non-Linear History", + "passing_nonlinear", + "failing_nonlinear", + "good_main", + "bad_main", + ), + ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), + ], +) +def test_triage_scenarios( + triage_env, + scenario, + passing_commit_key, + failing_commit_key, + expected_good_key, + expected_bad_key, +): + """Test the get_commit_history for linear and non-linear histories.""" + paths = triage_env["paths"] + all_commits = triage_env["commits"] + jax_repo_path = paths["repo"] / "jax" + + arg_list = [ + "--main-branch", + "main", + "--output-prefix", + str(paths["output"]), + "--container-runtime", + "local", + str(paths["scripts"] / "test-case.sh"), + str(jax_repo_path), + all_commits["bad_main"], + ] + args = parse_args(arg_list) + logger = logging.getLogger(f"Scenario-{scenario}") + logging.basicConfig(level=logging.INFO) + + tool = TriageTool(args, logger) + tool.package_dirs = {"jax": str(jax_repo_path)} + tool.dynamic_packages = {"jax"} + tool.bisection_url = "local" + + passing_versions = {"jax": all_commits[passing_commit_key]} + failing_versions = {"jax": all_commits[failing_commit_key]} + + tool.run_version_bisection(passing_versions, failing_versions) + summary_file = paths["output"] / "summary.json" + assert summary_file.exists(), "The summary file was not created" + with open(summary_file, "r") as f: + summary_data = json.load(f) + + assert "result" in summary_data, "No result section was created" + result = summary_data["result"] + + assert result.get("jax_good") == all_commits[expected_good_key] + assert result.get("jax_bad") == all_commits[expected_bad_key] From b3c1baec8d4f57d6254187e479690bba725db13c Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Jul 2025 16:05:50 +0100 Subject: [PATCH 26/50] fix tests and code itself. then remove logging --- .github/triage/jax_toolbox_triage/bisect.py | 25 +++++++------- .../triage/jax_toolbox_triage/triage_tool.py | 25 ++++++++++---- .../tests/test_triage_history_bisection.py | 34 ++++++++++++++++--- 3 files changed, 60 insertions(+), 24 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index b36f1227a..7f728c3cb 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -70,21 +70,17 @@ def get_commit_history( cherry_pick_range = {} if not is_linear: - logger.info(f"Using non-linear history logic with main branch {main_branch}") + logger.info(f"Using non-linear history logic with branch {main_branch}") # 1. find the linear range on the main branch - passing_and_failing_cmd = ( - worker.check_exec( - [ - "sh", - "-c", - f"git merge-base {start} {end} && git merge-base {end} {main_branch}", - ], - workdir=dir, - ) - .stdout() - .strip() - ) + passing_and_failing_cmd = worker.check_exec( + [ + "sh", + "-c", + f"git merge-base {start} {end} && git merge-base {end} {main_branch}", + ], + workdir=dir, + ).stdout.strip() passing_main_commit, failing_main_commit = passing_and_failing_cmd.splitlines() # 2. find commits to cherry-pick from the failing branch @@ -117,4 +113,7 @@ def get_commit_history( date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc) data.append((commit, date)) + logger.info(f"Data: {data}") + logger.info(f"Cherry pick {cherry_pick_range}") + return data, cherry_pick_range diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 82df4a0d6..57eeb8a03 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -132,7 +132,7 @@ def _gather_histories( ) package_versions[package] = history if cherry_pick_range: - self.cherry_pick_commits[package] = cherry_pick_range + self.cherry_pick_commits[package] = cherry_pick_range[package] assert all( b[1] >= a[1] @@ -140,8 +140,14 @@ def _gather_histories( package_versions[package], package_versions[package][1:] ) ) - assert passing_versions[package] == package_versions[package][0][0] - assert failing_versions[package] == package_versions[package][-1][0] + + # this check works only for linera-case + if not self.cherry_pick_commits.get(package): + self.logger.info(f"package_versions: {package_versions}") + self.logger.info(f"passing versions: {passing_versions}") + self.logger.info(f"failing versions: {failing_versions}") + assert passing_versions[package] == package_versions[package][0][0] + assert failing_versions[package] == package_versions[package][-1][0] for package in packages: if package in package_versions: @@ -325,17 +331,22 @@ def _build_and_test( self.bisection_versions[package] = version changed.append(f"{package}@{version}") if package in self.package_dirs: - package_cherry_picks = self.cherry_pick_commits.get(package, []) + self.logger.info( + f"Package working on {package} form {self.package_dirs}" + ) + self.logger.info(f"And cherry pick commits {self.cherry_pick_commits}") + cherry_pick_range = self.cherry_pick_commits.get(package) + self.logger.info(f"Cherry pick range {cherry_pick_range}") git_commands.append(f"cd {self.package_dirs[package]}") git_commands.append("git stash") # this is a checkout on the main branch git_commands.append(f"git checkout {version}") - if package_cherry_picks: + if cherry_pick_range: self.logger.info("Working on a non-linear history") - cherry_pick_str = " ".join(package_cherry_picks) git_commands.append( - f"(git cherry-pick {cherry_pick_str} || (echo 'Cherry-pick failed' && exit 1) )" + f"(git cherry-pick {cherry_pick_range} || (echo 'Cherry-pick failed' && exit 1) )" ) + self.logger.info(f"Full command {git_commands}") else: # Another software package, `version` is probably a version number. diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 2374645e1..7425bd41e 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -31,7 +31,7 @@ def run_command(command, cwd=None, env=None): @pytest.fixture -def triage_test_env(): +def triage_env(): """ Fixture to set up the test environment for triage tests. @@ -58,6 +58,9 @@ def triage_test_env(): test_case_content = (source_scripts_dir / "test-case.sh").read_text() (mock_scripts_path / "test-case.sh").write_text(test_case_content) os.chmod(mock_scripts_path / "test-case.sh", 0o755) + # fake bazel + (mock_scripts_path / "bazel").write_text("#!/bin/sh\nexit 0") + os.chmod(mock_scripts_path / "bazel", 0o755) # Create a git repository jax_repo_path = repo_path / "jax" @@ -74,25 +77,31 @@ def git_cmd(command, *args): # Create a linear commit history git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") + logging.info(f"M1: {m1}") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit m2 = git_cmd("rev-parse", "HEAD") + logging.info(f"M2: {m2}") git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") + logging.info(f"M3: {m3}") # create a feature branch git_cmd("checkout", "-b", "feature", m1) (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") git_cmd("commit", "-m", "F1") f1 = git_cmd("rev-parse", "HEAD") + logging.info(f"F1: {f1}") # here we're applying a feature to the good f1 commit git_cmd("checkout", "-b", "passing_nonlinear", m2) git_cmd("cherry-pick", f1) passing_nonlinear = git_cmd("rev-parse", "HEAD") + logging.info(f"Passing non-linear {passing_nonlinear}") # and then we apply the feature to the bad commit # this simulated the rebase scenario git_cmd("checkout", "-b", "failing_nonlinear", m3) git_cmd("cherry-pick", f1) failing_nonlinear = git_cmd("rev-parse", "HEAD") + logging.info(f"failing non linear {failing_nonlinear}") git_cmd("checkout", "main") # yield all the info @@ -127,6 +136,7 @@ def git_cmd(command, *args): ) def test_triage_scenarios( triage_env, + monkeypatch, scenario, passing_commit_key, failing_commit_key, @@ -137,6 +147,10 @@ def test_triage_scenarios( paths = triage_env["paths"] all_commits = triage_env["commits"] jax_repo_path = paths["repo"] / "jax" + passing_versions = {"jax": all_commits[passing_commit_key]} + failing_versions = {"jax": all_commits[failing_commit_key]} + passing_versions_str = f"jax:{all_commits[passing_commit_key]}" + failing_versions_str = f"jax:{all_commits[failing_commit_key]}" arg_list = [ "--main-branch", @@ -145,6 +159,10 @@ def test_triage_scenarios( str(paths["output"]), "--container-runtime", "local", + "--passing-versions", + passing_versions_str, + "--failing-versions", + failing_versions_str, str(paths["scripts"] / "test-case.sh"), str(jax_repo_path), all_commits["bad_main"], @@ -152,15 +170,23 @@ def test_triage_scenarios( args = parse_args(arg_list) logger = logging.getLogger(f"Scenario-{scenario}") logging.basicConfig(level=logging.INFO) + logger.info(f"Inputs args: {arg_list}") + # mp to mock bazel + original_path = os.environ.get("PATH", "") + monkeypatch.setenv("PATH", f"{paths['scripts']}:{original_path}") tool = TriageTool(args, logger) tool.package_dirs = {"jax": str(jax_repo_path)} tool.dynamic_packages = {"jax"} tool.bisection_url = "local" + if scenario == "Linear History": + linear_build_script_path = paths["scripts"] / "build-jax.sh" + linear_build_script_path.write_text( + "#!/bin/sh\necho 'Mock linear build successful.'\nexit 0" + ) + os.chmod(linear_build_script_path, 0o755) - passing_versions = {"jax": all_commits[passing_commit_key]} - failing_versions = {"jax": all_commits[failing_commit_key]} - + # run the bisection tool.run_version_bisection(passing_versions, failing_versions) summary_file = paths["output"] / "summary.json" assert summary_file.exists(), "The summary file was not created" From a1f49fd0308e464f0793a369b14a5425ee4b7c54 Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Jul 2025 16:28:11 +0100 Subject: [PATCH 27/50] start fixing CI, the test still fails because of the relative path --- .github/triage/jax_toolbox_triage/bisect.py | 3 --- .github/triage/jax_toolbox_triage/triage_tool.py | 9 --------- .../triage/tests/test_triage_history_bisection.py | 13 ++++++------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 7f728c3cb..cda524441 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -113,7 +113,4 @@ def get_commit_history( date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc) data.append((commit, date)) - logger.info(f"Data: {data}") - logger.info(f"Cherry pick {cherry_pick_range}") - return data, cherry_pick_range diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 57eeb8a03..02a388f92 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -143,9 +143,6 @@ def _gather_histories( # this check works only for linera-case if not self.cherry_pick_commits.get(package): - self.logger.info(f"package_versions: {package_versions}") - self.logger.info(f"passing versions: {passing_versions}") - self.logger.info(f"failing versions: {failing_versions}") assert passing_versions[package] == package_versions[package][0][0] assert failing_versions[package] == package_versions[package][-1][0] @@ -331,12 +328,7 @@ def _build_and_test( self.bisection_versions[package] = version changed.append(f"{package}@{version}") if package in self.package_dirs: - self.logger.info( - f"Package working on {package} form {self.package_dirs}" - ) - self.logger.info(f"And cherry pick commits {self.cherry_pick_commits}") cherry_pick_range = self.cherry_pick_commits.get(package) - self.logger.info(f"Cherry pick range {cherry_pick_range}") git_commands.append(f"cd {self.package_dirs[package]}") git_commands.append("git stash") # this is a checkout on the main branch @@ -346,7 +338,6 @@ def _build_and_test( git_commands.append( f"(git cherry-pick {cherry_pick_range} || (echo 'Cherry-pick failed' && exit 1) )" ) - self.logger.info(f"Full command {git_commands}") else: # Another software package, `version` is probably a version number. diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 7425bd41e..8b5442a42 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -77,31 +77,25 @@ def git_cmd(command, *args): # Create a linear commit history git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") - logging.info(f"M1: {m1}") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit m2 = git_cmd("rev-parse", "HEAD") - logging.info(f"M2: {m2}") git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") - logging.info(f"M3: {m3}") # create a feature branch git_cmd("checkout", "-b", "feature", m1) (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") git_cmd("commit", "-m", "F1") f1 = git_cmd("rev-parse", "HEAD") - logging.info(f"F1: {f1}") # here we're applying a feature to the good f1 commit git_cmd("checkout", "-b", "passing_nonlinear", m2) git_cmd("cherry-pick", f1) passing_nonlinear = git_cmd("rev-parse", "HEAD") - logging.info(f"Passing non-linear {passing_nonlinear}") # and then we apply the feature to the bad commit # this simulated the rebase scenario git_cmd("checkout", "-b", "failing_nonlinear", m3) git_cmd("cherry-pick", f1) failing_nonlinear = git_cmd("rev-parse", "HEAD") - logging.info(f"failing non linear {failing_nonlinear}") git_cmd("checkout", "main") # yield all the info @@ -152,6 +146,9 @@ def test_triage_scenarios( passing_versions_str = f"jax:{all_commits[passing_commit_key]}" failing_versions_str = f"jax:{all_commits[failing_commit_key]}" + bazel_cache_path = (paths["output"] / "bazel-cache").resolve() + bazel_cache_path.mkdir() + arg_list = [ "--main-branch", "main", @@ -163,6 +160,9 @@ def test_triage_scenarios( passing_versions_str, "--failing-versions", failing_versions_str, + "--bazel-cache", + str(bazel_cache_path), + "--", str(paths["scripts"] / "test-case.sh"), str(jax_repo_path), all_commits["bad_main"], @@ -170,7 +170,6 @@ def test_triage_scenarios( args = parse_args(arg_list) logger = logging.getLogger(f"Scenario-{scenario}") logging.basicConfig(level=logging.INFO) - logger.info(f"Inputs args: {arg_list}") # mp to mock bazel original_path = os.environ.get("PATH", "") monkeypatch.setenv("PATH", f"{paths['scripts']}:{original_path}") From 535aedb86f9747d644934044e7e70dfc2dd0a62e Mon Sep 17 00:00:00 2001 From: Steboss Date: Fri, 11 Jul 2025 16:29:51 +0100 Subject: [PATCH 28/50] fix CI --- .github/triage/jax_toolbox_triage/args.py | 10 ++++++---- .github/triage/jax_toolbox_triage/triage_tool.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index d52339597..3b0e6268e 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -282,7 +282,9 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), "For local runtime, --passing-versions and --failing-versions must be provided." + ), ( + "For local runtime, --passing-versions and --failing-versions must be provided." + ) assert ( args.container is None and args.start_date is None @@ -327,7 +329,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert ( - args.container is not None - ), "--container must be passed for the container-level search" + assert args.container is not None, ( + "--container must be passed for the container-level search" + ) return args diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 02a388f92..d53efabec 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -56,7 +56,9 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + assert not out_dir.exists(), ( + f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + ) out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -81,9 +83,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert ( - container_url is not None - ), "Container URL must be provided if explicit versions are not set." + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) with make_container( self.args.container_runtime, From 3a6a3d4e2823492f56add795b3ad4e62174b683e Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:46:19 +0100 Subject: [PATCH 29/50] fix comments --- .github/triage/jax_toolbox_triage/args.py | 11 +++++------ .github/triage/jax_toolbox_triage/bisect.py | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index 3b0e6268e..fd32d859d 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -228,6 +228,7 @@ def parse_args(args=None) -> argparse.Namespace: version_search_args.add_argument( "--override-remotes", type=parse_override_remotes, + default={}, help="""Remote URLs to be used for fetching, including auth token. E.g.: jax:https://@host/repo.git,xla:https://@host/repo.git """, @@ -282,9 +283,7 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), ( - "For local runtime, --passing-versions and --failing-versions must be provided." - ) + ), "For local runtime, --passing-versions and --failing-versions must be provided." assert ( args.container is None and args.start_date is None @@ -329,7 +328,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert args.container is not None, ( - "--container must be passed for the container-level search" - ) + assert ( + args.container is not None + ), "--container must be passed for the container-level search" return args diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index cda524441..b3e8c275d 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -42,7 +42,7 @@ def get_commit_history( ) if commits_known.returncode != 0: if args.override_remotes and package in args.override_remotes: - remote_url = args.override_remotes[package] + remote_url = args.override_remotes.get(package, "origin") fetch_cmd = ["git", "fetch", remote_url, start, end] worker.check_exec(fetch_cmd, workdir=dir) else: @@ -84,6 +84,7 @@ def get_commit_history( passing_main_commit, failing_main_commit = passing_and_failing_cmd.splitlines() # 2. find commits to cherry-pick from the failing branch + # TODO: as an alternative approach we may need to consider `{passing_main_commit}..{start}` cherry_pick_range[package] = f"{failing_main_commit}..{end}" # 3. now we can use the main branch commits for bisection From 8342ee09e31acb2ca92d2fbc0ecf2f4b2158d010 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:48:15 +0100 Subject: [PATCH 30/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 8b5442a42..27fb35c22 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -170,9 +170,8 @@ def test_triage_scenarios( args = parse_args(arg_list) logger = logging.getLogger(f"Scenario-{scenario}") logging.basicConfig(level=logging.INFO) - # mp to mock bazel - original_path = os.environ.get("PATH", "") - monkeypatch.setenv("PATH", f"{paths['scripts']}:{original_path}") + # Ensure bazel and build-jax.sh can be found. + monkeypatch.setenv("PATH", paths['scripts'], prepend=':') tool = TriageTool(args, logger) tool.package_dirs = {"jax": str(jax_repo_path)} From 4e56a73c0f3940f028e16fd89118d6c23702a2c7 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:50:19 +0100 Subject: [PATCH 31/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 27fb35c22..5d6c8ff20 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -73,7 +73,7 @@ def git_cmd(command, *args): # why don't we push the scripts and use them in repo git_cmd("init", "-b", "main") git_cmd("config", "user.name", "Test User") - git_cmd("config", "user.email", "test@user.it") + git_cmd("config", "user.email", "test@example.com") # Create a linear commit history git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") From 7d760e43da017e26a33904de027c4b3535711e7a Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:51:25 +0100 Subject: [PATCH 32/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 5d6c8ff20..8be671fd7 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -86,10 +86,6 @@ def git_cmd(command, *args): (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") git_cmd("commit", "-m", "F1") - f1 = git_cmd("rev-parse", "HEAD") - # here we're applying a feature to the good f1 commit - git_cmd("checkout", "-b", "passing_nonlinear", m2) - git_cmd("cherry-pick", f1) passing_nonlinear = git_cmd("rev-parse", "HEAD") # and then we apply the feature to the bad commit # this simulated the rebase scenario From 08029c11e0e94d436ecebb73d54bbdd916347938 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:52:01 +0100 Subject: [PATCH 33/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 8be671fd7..12fbc3b60 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -90,7 +90,7 @@ def git_cmd(command, *args): # and then we apply the feature to the bad commit # this simulated the rebase scenario git_cmd("checkout", "-b", "failing_nonlinear", m3) - git_cmd("cherry-pick", f1) + git_cmd("cherry-pick", passing_nonlinear) failing_nonlinear = git_cmd("rev-parse", "HEAD") git_cmd("checkout", "main") From cd64936dad063c90726ff6632e249d952274fa57 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 16:59:54 +0100 Subject: [PATCH 34/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 12fbc3b60..cbf27b5f7 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -102,7 +102,7 @@ def git_cmd(command, *args): "scripts": mock_scripts_path, }, "commits": { - "good_main": m2, + "good_main": m1, "bad_main": m3, "feature": f1, "passing_nonlinear": passing_nonlinear, From 43b4ce11791f60d432fdcb60947295e596f091ca Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 17:00:38 +0100 Subject: [PATCH 35/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index cbf27b5f7..60382dc15 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -78,7 +78,6 @@ def git_cmd(command, *args): git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit - m2 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") # create a feature branch From 6d125aa7291212bc44183c1ad76455ff677e1fcf Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 17:01:36 +0100 Subject: [PATCH 36/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 60382dc15..f5f556a15 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -80,7 +80,7 @@ def git_cmd(command, *args): git_cmd("commit", "--allow-empty", "-m", "M2") # good commit git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") - # create a feature branch + # create a feature branch; feature_file.txt must exist for the mock build-jax.sh to return true git_cmd("checkout", "-b", "feature", m1) (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") From a97433617141a91fa34e4d107d2e0c5819586b33 Mon Sep 17 00:00:00 2001 From: Steboss Date: Mon, 14 Jul 2025 17:01:52 +0100 Subject: [PATCH 37/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index f5f556a15..86797e171 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -103,7 +103,6 @@ def git_cmd(command, *args): "commits": { "good_main": m1, "bad_main": m3, - "feature": f1, "passing_nonlinear": passing_nonlinear, "failing_nonlinear": failing_nonlinear, }, From 654d0844ae87364e24d0b78d7336da4456557565 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 11:16:45 +0100 Subject: [PATCH 38/50] update tests for linear case, update output of triage tool --- .github/triage/jax_toolbox_triage/bisect.py | 15 ++++- .../triage/jax_toolbox_triage/triage_tool.py | 21 +++--- .../tests/test_triage_history_bisection.py | 65 ++++++++++++------- 3 files changed, 69 insertions(+), 32 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index b3e8c275d..184ec7a37 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -91,6 +91,17 @@ def get_commit_history( start = passing_main_commit end = failing_main_commit + logger.info( + f"INFO: cherry_pick_range {cherry_pick_range}, start: {start} and end {end}" + ) + # check if the start is the root commit. We may have to deal with the very start of the repo + # so we need to handle this case too + parent_check_result = worker.check_exec( + ["git", "rev-list", "--parents", "-n", "1", start], workdir=dir + ) + is_root_commit = len(parent_check_result.stdout.strip().split()) == 1 + + # now create the right git command to retrieve the history between start..end result = worker.check_exec( [ "git", @@ -98,13 +109,13 @@ def get_commit_history( "--first-parent", "--reverse", "--format=%H %cI", - f"{start}^..{end}", + f"{start}..{end}" if is_root_commit else f"{start}^..{end}", ], policy="once", stderr=subprocess.PIPE, workdir=dir, ) - logger.debug(f"stderr: {result.stderr.strip()}") + data = [] for line in result.stdout.splitlines(): commit, date = line.split() diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index d53efabec..914613cda 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -5,7 +5,8 @@ import logging import pathlib import time -from typing import Dict, Tuple, Union +import json +from typing import Dict, Tuple, Union, Any from .container import Container from .logic import container_search, TestResult, version_search @@ -56,9 +57,7 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -83,9 +82,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." with make_container( self.args.container_runtime, @@ -536,7 +535,7 @@ def run_version_bisection( self, passing_versions: Dict[str, str], failing_versions: Dict[str, str], - ) -> None: + ) -> Dict[str, Any]: """ Run the version bisection process. @@ -574,3 +573,9 @@ def run_version_bisection( result["container"] = self.bisection_url add_summary_record(self.args.output_prefix, "result", result, scalar=True) self.logger.info("Version-level bisection completed") + + summary_file = self.args.output_prefix / "summary.json" + with open(summary_file, "r") as ifile: + summary_data = json.load(ifile) + + return summary_data diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 12fbc3b60..16d2cf747 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -3,7 +3,6 @@ import pathlib import os import logging -import json import pytest from jax_toolbox_triage.args import parse_args @@ -69,7 +68,7 @@ def triage_env(): def git_cmd(command, *args): return run_command(["git", command, *args], cwd=jax_repo_path) - # main + # NON-LINEAR HISTORY # why don't we push the scripts and use them in repo git_cmd("init", "-b", "main") git_cmd("config", "user.name", "Test User") @@ -81,18 +80,38 @@ def git_cmd(command, *args): m2 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") - # create a feature branch + # create a feature branch from M1 git_cmd("checkout", "-b", "feature", m1) (jax_repo_path / "feature_file.txt").write_text("feature") git_cmd("add", "feature_file.txt") git_cmd("commit", "-m", "F1") - passing_nonlinear = git_cmd("rev-parse", "HEAD") + passing_nonlinear = git_cmd("rev-parse", "HEAD") # F1 # and then we apply the feature to the bad commit # this simulated the rebase scenario git_cmd("checkout", "-b", "failing_nonlinear", m3) git_cmd("cherry-pick", passing_nonlinear) - failing_nonlinear = git_cmd("rev-parse", "HEAD") - git_cmd("checkout", "main") + failing_nonlinear = git_cmd("rev-parse", "HEAD") # F1' + # so now we have: + # M1 --- M2 --- M3 + # | | + # F1 F1' + # where F1 = passing + # and F2 = failing + + # LINEAR HISTORY + git_cmd("checkout", "-b", "linear_feature_branch", passing_nonlinear) + git_cmd("commit", "--allow-empty", "-m", "L1") + l1_good_commit = git_cmd("rev-parse", "HEAD") # L1 + git_cmd("commit", "--allow-empty", "-m", "L2_BAD") # L2 bad commit + l2_bad_linear_commit = git_cmd("rev-parse", "HEAD") + git_cmd("commit", "--allow-empty", "-m", "L3") # L3 + l3_failing_linear = git_cmd("rev-parse", "HEAD") + # so the linear repo would be + # M1 -- M2 -- M3 + # | + # F1 + # | + # L1 -- L2 -- L3 # yield all the info yield { @@ -104,24 +123,35 @@ def git_cmd(command, *args): "commits": { "good_main": m2, "bad_main": m3, - "feature": f1, + "feature": passing_nonlinear, "passing_nonlinear": passing_nonlinear, "failing_nonlinear": failing_nonlinear, + "good_linear": l1_good_commit, + "bad_commit_for_linear": l2_bad_linear_commit, + "failing_linear": l3_failing_linear, }, } @pytest.mark.parametrize( - "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", + "scenario, passing_commit_key, failing_commit_key, bad_commit_key, expected_good_key, expected_bad_key", [ ( "Non-Linear History", "passing_nonlinear", "failing_nonlinear", + "bad_main", "good_main", "bad_main", ), - ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), + ( + "Linear History", + "good_linear", + "failing_linear", + "bad_commit_for_linear", + "good_linear", + "bad_commit_for_linear", + ), ], ) def test_triage_scenarios( @@ -130,6 +160,7 @@ def test_triage_scenarios( scenario, passing_commit_key, failing_commit_key, + bad_commit_key, expected_good_key, expected_bad_key, ): @@ -161,31 +192,21 @@ def test_triage_scenarios( "--", str(paths["scripts"] / "test-case.sh"), str(jax_repo_path), - all_commits["bad_main"], + all_commits[bad_commit_key], ] args = parse_args(arg_list) logger = logging.getLogger(f"Scenario-{scenario}") logging.basicConfig(level=logging.INFO) # Ensure bazel and build-jax.sh can be found. - monkeypatch.setenv("PATH", paths['scripts'], prepend=':') + monkeypatch.setenv("PATH", paths["scripts"], prepend=":") tool = TriageTool(args, logger) tool.package_dirs = {"jax": str(jax_repo_path)} tool.dynamic_packages = {"jax"} tool.bisection_url = "local" - if scenario == "Linear History": - linear_build_script_path = paths["scripts"] / "build-jax.sh" - linear_build_script_path.write_text( - "#!/bin/sh\necho 'Mock linear build successful.'\nexit 0" - ) - os.chmod(linear_build_script_path, 0o755) # run the bisection - tool.run_version_bisection(passing_versions, failing_versions) - summary_file = paths["output"] / "summary.json" - assert summary_file.exists(), "The summary file was not created" - with open(summary_file, "r") as f: - summary_data = json.load(f) + summary_data = tool.run_version_bisection(passing_versions, failing_versions) assert "result" in summary_data, "No result section was created" result = summary_data["result"] From fc4abfb6eb7f517376ad34aff47f86a141dba7f3 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 11:53:44 +0100 Subject: [PATCH 39/50] fix code with ruff and mypy --- .github/triage/jax_toolbox_triage/args.py | 10 ++++++---- .github/triage/jax_toolbox_triage/bisect.py | 10 ++-------- .github/triage/jax_toolbox_triage/triage_tool.py | 10 ++++++---- .github/triage/tests/test_triage_history_bisection.py | 3 ++- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index fd32d859d..17ee95872 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -283,7 +283,9 @@ def parse_args(args=None) -> argparse.Namespace: if args.container_runtime == "local": assert ( args.passing_versions is not None and args.failing_versions is not None - ), "For local runtime, --passing-versions and --failing-versions must be provided." + ), ( + "For local runtime, --passing-versions and --failing-versions must be provided." + ) assert ( args.container is None and args.start_date is None @@ -328,7 +330,7 @@ def parse_args(args=None) -> argparse.Namespace: else: # None of --{passing,failing}-{versions,container} were passed, make sure the # compulsory arguments for the container-level search were passed - assert ( - args.container is not None - ), "--container must be passed for the container-level search" + assert args.container is not None, ( + "--container must be passed for the container-level search" + ) return args diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 184ec7a37..e1e878509 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -100,17 +100,11 @@ def get_commit_history( ["git", "rev-list", "--parents", "-n", "1", start], workdir=dir ) is_root_commit = len(parent_check_result.stdout.strip().split()) == 1 + log_range = f"{start}..{end}" if is_root_commit else f"{start}^..{end}" # now create the right git command to retrieve the history between start..end result = worker.check_exec( - [ - "git", - "log", - "--first-parent", - "--reverse", - "--format=%H %cI", - f"{start}..{end}" if is_root_commit else f"{start}^..{end}", - ], + ["git", "log", "--first-parent", "--reverse", "--format=%H %cI", log_range], policy="once", stderr=subprocess.PIPE, workdir=dir, diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 914613cda..6e7c4be1e 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -57,7 +57,9 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + assert not out_dir.exists(), ( + f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + ) out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -82,9 +84,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert ( - container_url is not None - ), "Container URL must be provided if explicit versions are not set." + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) with make_container( self.args.container_runtime, diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 3dfd17bb6..c2126759d 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -77,6 +77,7 @@ def git_cmd(command, *args): git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit + m2 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M3") # bad commit m3 = git_cmd("rev-parse", "HEAD") # create a feature branch from M1 @@ -120,7 +121,7 @@ def git_cmd(command, *args): "scripts": mock_scripts_path, }, "commits": { - "good_main": m1, + "good_main": m2, # last good commit "bad_main": m3, "passing_nonlinear": passing_nonlinear, "failing_nonlinear": failing_nonlinear, From 9cd243c9a0a9e7f0272649e32b1f793dac115700 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 11:54:04 +0100 Subject: [PATCH 40/50] remove old test --- .../triage/tests/test_triage_tool_class.py | 244 ------------------ 1 file changed, 244 deletions(-) delete mode 100644 .github/triage/tests/test_triage_tool_class.py diff --git a/.github/triage/tests/test_triage_tool_class.py b/.github/triage/tests/test_triage_tool_class.py deleted file mode 100644 index 3fd212baa..000000000 --- a/.github/triage/tests/test_triage_tool_class.py +++ /dev/null @@ -1,244 +0,0 @@ -import subprocess -import tempfile -import pathlib -import os -import logging -import pytest - -from jax_toolbox_triage.triage_tool import TriageTool -from jax_toolbox_triage.logic import version_search -from jax_toolbox_triage.container import Container - - -def run_command(command, cwd=None, env=None): - """Simple function to run a command in a subprocess. - - Args: - command (list): The command to run as a list of strings. - cwd (str, optional): The working directory to run the command in. - env (dict, optional): Environment variables to set for the command. - Returns: - str: The standard output of the command. - """ - try: - result = subprocess.run( - command, cwd=cwd, env=env, check=True, capture_output=True, text=True - ) - return result.stdout.strip() - except subprocess.CalledProcessError as e: - logging.error(f"Command '{' '.join(command)}' failed with error: {e}") - raise e - - -class MockContainer(Container): - """A mock container class for testing purposes.""" - - def __init__(self, mock_scripts_path, logger): - super().__init__(logger=logger) - self.mock_scripts_path = mock_scripts_path - self._env = os.environ.copy() - self._env["PATH"] = f"{self.mock_scripts_path}:{self._env['PATH']}" - - def __enter__(self): - return self - - def __exit__(self, *exc_info): - pass - - def __repr__(self): - return "MockContainer" - - def check_exec(self, cmd, **kwargs): - """Override the check_exec""" - return super().check_exec(cmd, **kwargs) - - def exec( - self, - command, - *, - policy="default", - stderr="interleaved", - workdir=None, - log_level=logging.DEBUG, - ): - self._logger.debug(f"Executing command: {command} in {workdir}") - is_shell_command = command[0] == "sh" and command[1] == "-c" - cmd_to_run = command[2] if is_shell_command else command - try: - return subprocess.run( - cmd_to_run, - capture_output=True, - text=True, - cwd=workdir, - env=self._env, - shell=is_shell_command, - ) - except FileNotFoundError as e: - return subprocess.CompletedProcess(command, 127, stderr=str(e)) - - def exists(self) -> bool: - return True - - -@pytest.fixture -def triage_test_env(): - """ - Fixture to set up the test environment for triage tests. - - The fixture creates a temp directory and a git repo with a - defined linear and non-linear history. - - The fixture yields a dictionary of paths and commit hashes - """ - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = pathlib.Path(temp_dir) - repo_path = temp_path / "repos" - output_path = temp_path / "output" - mock_scripts_path = temp_path / "mock_scripts" - repo_path.mkdir() - output_path.mkdir() - mock_scripts_path.mkdir() - - source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" - build_script_content = (source_scripts_dir / "build-jax.sh").read_text() - (mock_scripts_path / "build-jax.sh").write_text(build_script_content) - os.chmod(mock_scripts_path / "build-jax.sh", 0o755) - - test_case_content = (source_scripts_dir / "test-case.sh").read_text() - (mock_scripts_path / "test-case.sh").write_text(test_case_content) - os.chmod(mock_scripts_path / "test-case.sh", 0o755) - - # Create a fake bazel executable - (mock_scripts_path / "bazel").write_text("#!/bin/sh\necho bazel") - os.chmod(mock_scripts_path / "bazel", 0o755) - - # setup the jax repo path - jax_repo_path = repo_path / "jax" - jax_repo_path.mkdir() - - def git_cmd(command, *args): - return run_command(["git", command, *args], cwd=jax_repo_path) - - git_cmd("init", "-b", "main") - git_cmd("remote", "add", "origin", str(jax_repo_path)) - git_cmd("config", "user.name", "Test User") - git_cmd("config", "user.email", "test@user.it") - - git_cmd("commit", "--allow-empty", "-m", "M1") - m1 = git_cmd("rev-parse", "HEAD") - - git_cmd("commit", "--allow-empty", "-m", "M2") - m2 = git_cmd("rev-parse", "HEAD") - - git_cmd("commit", "--allow-empty", "-m", "M3") - m3 = git_cmd("rev-parse", "HEAD") - - git_cmd("checkout", "-b", "feature", m1) - (jax_repo_path / "feature_file.txt").write_text("feature") - git_cmd("add", "feature_file.txt") - git_cmd("commit", "-m", "F1") - f1 = git_cmd("rev-parse", "HEAD") - - git_cmd("checkout", "-b", "passing_nonlinear", m2) - git_cmd("cherry-pick", f1) - passing_nonlinear = git_cmd("rev-parse", "HEAD") - - git_cmd("checkout", "-b", "failing_nonlinear", m3) - git_cmd("cherry-pick", f1) - failing_nonlinear = git_cmd("rev-parse", "HEAD") - - git_cmd("checkout", "main") - - yield { - "paths": { - "repo": repo_path, - "output": output_path, - "scripts": mock_scripts_path, - }, - "commits": { - "good_main": m2, - "bad_main": m3, - "feature": f1, - "passing_nonlinear": passing_nonlinear, - "failing_nonlinear": failing_nonlinear, - }, - } - - -@pytest.mark.parametrize( - "scenario, passing_commit_key, failing_commit_key, expected_good_key, expected_bad_key", - [ - ( - "Non-Linear History", - "passing_nonlinear", - "failing_nonlinear", - "good_main", - "bad_main", - ), - ("Linear History", "good_main", "bad_main", "good_main", "bad_main"), - ], -) -def test_triage_scenarios( - triage_test_env, - monkeypatch, - scenario, - passing_commit_key, - failing_commit_key, - expected_good_key, - expected_bad_key, -): - """Tests the TriageTool class.""" - paths = triage_test_env["paths"] - all_commits = triage_test_env["commits"] - jax_repo_path = paths["repo"] / "jax" - - class MockArgs: - main_branch = "main" - bazel_cache = "" - build_scripts_path = None - test_command = ["test-case.sh", str(jax_repo_path), all_commits["bad_main"]] - cherry_pick_commits = {} - output_prefix = paths["output"] - container_runtime = "mock" # Use a mock runtime - container_mount = [] - - args = MockArgs() - logger = logging.getLogger(f"Scenario-{scenario}") - logging.basicConfig(level=logging.INFO) - - tool = TriageTool(args, logger) - tool.package_dirs = {"jax": str(jax_repo_path)} - tool.dynamic_packages = {"jax"} - tool.bisection_url = "mock_url" - - # Set up a monkeypatch for the container creation - # in this way we're using MockContainer rather than make_container - mock_container = MockContainer(paths["scripts"], logger) - monkeypatch.setattr( - "jax_toolbox_triage.triage_tool.make_container", lambda *a, **kw: mock_container - ) - - # In case of linear-history scenario, we need a fake jax script too - if scenario == "Linear History": - linear_build_script_path = paths["scripts"] / "build-jax.sh" - linear_build_script_path.write_text( - "#!/bin/sh\necho 'Mock linear build successful.'\nexit 0" - ) - os.chmod(linear_build_script_path, 0o755) - - passing_versions = {"jax": all_commits[passing_commit_key]} - failing_versions = {"jax": all_commits[failing_commit_key]} - - package_versions = tool._gather_histories( - mock_container, passing_versions, failing_versions - ) - # Bisection test - result, _, _ = version_search( - versions=package_versions, - build_and_test=tool._build_and_test, - logger=logger, - skip_precondition_checks=False, - ) - - assert result.get("jax_good") == all_commits[expected_good_key] - assert result.get("jax_bad") == all_commits[expected_bad_key] From 13c9c0cda90256de9c7bffb9b39d5a8a206c66c0 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:15:02 +0100 Subject: [PATCH 41/50] Update .github/triage/jax_toolbox_triage/bisect.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/bisect.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index e1e878509..c6c6e7d07 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -41,25 +41,7 @@ def get_commit_history( workdir=dir, ) if commits_known.returncode != 0: - if args.override_remotes and package in args.override_remotes: - remote_url = args.override_remotes.get(package, "origin") - fetch_cmd = ["git", "fetch", remote_url, start, end] - worker.check_exec(fetch_cmd, workdir=dir) - else: - # default behaviour - # re: https://stackoverflow.com/questions/4089430/how-to-determine-the-url-that-a-local-git-repository-was-originally-cloned-from - remote_url_result = worker.check_exec( - ["git", "config", "--get", "remote.origin.url"], workdir=dir - ) - if remote_url_result.returncode == 0: - remote_url = remote_url_result.stdout().strip() - fetch_cmd = ["git", "fetch", remote_url, start, end] - worker.check_exec(fetch_cmd, workdir=dir) - else: - logger.error( - "No remote 'origin' found and no override provided. Cannot fetch missing commits" - ) - raise Exception("Cannot find commits and no remote is configured") + worker.check_exec(["git", "fetch", args.override_remotes.get(package, "origin"), start, end], workdir=dir)``` # detect non-linear history is_ancestor_result = worker.exec( From c79112451002d75bbb6596528af2a9c8b34dccbe Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:15:46 +0100 Subject: [PATCH 42/50] Update .github/triage/jax_toolbox_triage/triage_tool.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/triage_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 6e7c4be1e..c94a39bf3 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -145,7 +145,7 @@ def _gather_histories( ) # this check works only for linera-case - if not self.cherry_pick_commits.get(package): + if package not in self.cherry_pick_commits: assert passing_versions[package] == package_versions[package][0][0] assert failing_versions[package] == package_versions[package][-1][0] From 9dcbaf104966b03dd57c396b5e0d768972f43c89 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:15:58 +0100 Subject: [PATCH 43/50] Update .github/triage/jax_toolbox_triage/triage_tool.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/triage_tool.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index c94a39bf3..b34fcf629 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -204,12 +204,7 @@ def _check_container_by_date( Returns: TestResult: The result of the test, including whether it passed and the output. """ - container_url_func = functools.partial( - container_url_base, - container=self.args.container, - template=self.args.container_url_template, - ) - container_url = container_url_func(date) + container_url = container_url_base(date, container=self.args.container, template=self.args.container_url_template) before = time.monotonic() out_dir = self._test_output_directory(container_url, None) From e4654992246cd098e72e3e62f1dc56f64964205f Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:16:32 +0100 Subject: [PATCH 44/50] Update .github/triage/jax_toolbox_triage/summary.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/summary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py index 53db602cc..3db97673c 100644 --- a/.github/triage/jax_toolbox_triage/summary.py +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -43,6 +43,7 @@ def add_summary_record( with open(summary_filename, "w") as ofile: json.dump(data, ofile) + return data def create_output_symlinks( From b07f9604a2e4962205e697fe70f591918d36f7de Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:16:42 +0100 Subject: [PATCH 45/50] Update .github/triage/jax_toolbox_triage/triage_tool.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/triage_tool.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index b34fcf629..ca0f92661 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -568,11 +568,5 @@ def run_version_bisection( self.args.output_prefix, last_known_good, first_known_bad ) result["container"] = self.bisection_url - add_summary_record(self.args.output_prefix, "result", result, scalar=True) self.logger.info("Version-level bisection completed") - - summary_file = self.args.output_prefix / "summary.json" - with open(summary_file, "r") as ifile: - summary_data = json.load(ifile) - - return summary_data + return add_summary_record(self.args.output_prefix, "result", result, scalar=True) From 48665b3ca872b18241f4370d022e6f529a34848c Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:16:49 +0100 Subject: [PATCH 46/50] Update .github/triage/tests/test_triage_history_bisection.py Co-authored-by: Olli Lupton --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index c2126759d..fc39e7d23 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -96,7 +96,7 @@ def git_cmd(command, *args): # | | # F1 F1' # where F1 = passing - # and F2 = failing + # and F1' = failing # LINEAR HISTORY git_cmd("checkout", "-b", "linear_feature_branch", passing_nonlinear) From 881fa0b8f9505b97ff8beb9b5304d05ebd23a155 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:17:15 +0100 Subject: [PATCH 47/50] Update .github/triage/jax_toolbox_triage/bisect.py Co-authored-by: Olli Lupton --- .github/triage/jax_toolbox_triage/bisect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index c6c6e7d07..0942a00cf 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -74,7 +74,7 @@ def get_commit_history( end = failing_main_commit logger.info( - f"INFO: cherry_pick_range {cherry_pick_range}, start: {start} and end {end}" + f"cherry_pick_range: {cherry_pick_range}, start: {start}, end: {end}" ) # check if the start is the root commit. We may have to deal with the very start of the repo # so we need to handle this case too From 729f4d911dce12452d6596bdff9e937afeee3a20 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:51:21 +0100 Subject: [PATCH 48/50] fix errors and comments --- .github/triage/jax_toolbox_triage/bisect.py | 21 +++--- .../triage/jax_toolbox_triage/triage_tool.py | 75 ++++++++++--------- .github/triage/tests/mock_scripts/bazel | 3 + .../tests/test_triage_history_bisection.py | 27 ++----- 4 files changed, 59 insertions(+), 67 deletions(-) create mode 100755 .github/triage/tests/mock_scripts/bazel diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index c6c6e7d07..135328474 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -41,7 +41,10 @@ def get_commit_history( workdir=dir, ) if commits_known.returncode != 0: - worker.check_exec(["git", "fetch", args.override_remotes.get(package, "origin"), start, end], workdir=dir)``` + worker.check_exec( + ["git", "fetch", args.override_remotes.get(package, "origin"), start, end], + workdir=dir, + ) # detect non-linear history is_ancestor_result = worker.exec( @@ -76,17 +79,17 @@ def get_commit_history( logger.info( f"INFO: cherry_pick_range {cherry_pick_range}, start: {start} and end {end}" ) - # check if the start is the root commit. We may have to deal with the very start of the repo - # so we need to handle this case too - parent_check_result = worker.check_exec( - ["git", "rev-list", "--parents", "-n", "1", start], workdir=dir - ) - is_root_commit = len(parent_check_result.stdout.strip().split()) == 1 - log_range = f"{start}..{end}" if is_root_commit else f"{start}^..{end}" # now create the right git command to retrieve the history between start..end result = worker.check_exec( - ["git", "log", "--first-parent", "--reverse", "--format=%H %cI", log_range], + [ + "git", + "log", + "--first-parent", + "--reverse", + "--format=%H %cI", + f"{start}^..{end}", + ], policy="once", stderr=subprocess.PIPE, workdir=dir, diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index ca0f92661..6088bcf45 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -5,8 +5,7 @@ import logging import pathlib import time -import json -from typing import Dict, Tuple, Union, Any +from typing import Dict, Tuple, Union, Any, Optional from .container import Container from .logic import container_search, TestResult, version_search @@ -57,12 +56,29 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" - ) + assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" out_dir.mkdir(mode=0o755) return out_dir.resolve() + def _make_container( + self, url: str, test_output_directory: Optional[pathlib.Path] = None + ) -> Container: + """ + Wrapper for make_container factory + + Args: + url: (str), the input url of the docker image + test_output_directory: (pathlib.Path), the path to the output directory + + Returns: + Container object + """ + mounts = self.bazel_cache_mounts + self.args.container_mount + if test_output_directory is not None: + mounts.append((test_output_directory, "/triage-tool-output")) + + return make_container(self.args.container_runtime, url, mounts, self.logger) + def _get_versions( self, container_url: str, @@ -84,16 +100,11 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert container_url is not None, ( - "Container URL must be provided if explicit versions are not set." - ) + assert ( + container_url is not None + ), "Container URL must be provided if explicit versions are not set." - with make_container( - self.args.container_runtime, - container_url, - self.bazel_cache_mounts, - self.logger, - ) as worker: + with self._make_container(container_url) as worker: url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env) overriden_versions = url_versions.copy() if explicit_versions is not None: @@ -204,16 +215,17 @@ def _check_container_by_date( Returns: TestResult: The result of the test, including whether it passed and the output. """ - container_url = container_url_base(date, container=self.args.container, template=self.args.container_url_template) + container_url = container_url_base( + date, + container=self.args.container, + template=self.args.container_url_template, + ) before = time.monotonic() out_dir = self._test_output_directory(container_url, None) - # this is from the previous Container class implementation in main - mounts = self.args.container_mount + [(out_dir, "/triage-tool-output")] - - with make_container( - self.args.container_runtime, container_url, mounts, self.logger + with self._make_container( + container_url, test_output_directory=out_dir ) as worker: versions, _, _ = get_versions_dirs_env( worker, self.args.build_scripts_path is not None @@ -361,17 +373,9 @@ def _build_and_test( out_dir = self._test_output_directory( self.bisection_url, versions=brief_versions ) - mounts = ( - self.bazel_cache_mounts - + self.args.container_mount - + [(out_dir, "/triage-tool-output")] - ) - with make_container( - self.args.container_runtime, - self.bisection_url, - mounts, - self.logger, + with self._make_container( + self.bisection_url, test_output_directory=out_dir ) as worker: change_str = " ".join(changed) if len(changed) else "" info_str = f"Checking out {change_str} in {worker}" @@ -545,12 +549,7 @@ def run_version_bisection( """ self.logger.info("Running version-level bisection...") # Prepare the container for the bisection - with make_container( - self.args.container_runtime, - self.bisection_url, - self.args.container_mount, - self.logger, - ) as worker: + with self._make_container(self.bisection_url) as worker: package_versions = self._gather_histories( worker, passing_versions, failing_versions ) @@ -569,4 +568,6 @@ def run_version_bisection( ) result["container"] = self.bisection_url self.logger.info("Version-level bisection completed") - return add_summary_record(self.args.output_prefix, "result", result, scalar=True) + return add_summary_record( + self.args.output_prefix, "result", result, scalar=True + ) diff --git a/.github/triage/tests/mock_scripts/bazel b/.github/triage/tests/mock_scripts/bazel new file mode 100755 index 000000000..8c3cbfc39 --- /dev/null +++ b/.github/triage/tests/mock_scripts/bazel @@ -0,0 +1,3 @@ +#!/bin/bash + +exit 0 diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index fc39e7d23..9ab5a125d 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -1,7 +1,6 @@ import subprocess import tempfile import pathlib -import os import logging import pytest @@ -44,24 +43,9 @@ def triage_env(): temp_path = pathlib.Path(temp_dir) repo_path = temp_path / "repos" output_path = temp_path / "output" - mock_scripts_path = temp_path / "mock_scripts" + source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" repo_path.mkdir() output_path.mkdir() - mock_scripts_path.mkdir() - source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts" - # fake build-jax - build_script_content = (source_scripts_dir / "build-jax.sh").read_text() - (mock_scripts_path / "build-jax.sh").write_text(build_script_content) - os.chmod(mock_scripts_path / "build-jax.sh", 0o755) - # test-case.sh helper test script - test_case_content = (source_scripts_dir / "test-case.sh").read_text() - (mock_scripts_path / "test-case.sh").write_text(test_case_content) - os.chmod(mock_scripts_path / "test-case.sh", 0o755) - # fake bazel - (mock_scripts_path / "bazel").write_text("#!/bin/sh\nexit 0") - os.chmod(mock_scripts_path / "bazel", 0o755) - - # Create a git repository jax_repo_path = repo_path / "jax" jax_repo_path.mkdir() @@ -74,6 +58,7 @@ def git_cmd(command, *args): git_cmd("config", "user.name", "Test User") git_cmd("config", "user.email", "test@example.com") # Create a linear commit history + git_cmd("commit" "--allow-empty", "-m", "M0_base") git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit @@ -92,9 +77,9 @@ def git_cmd(command, *args): git_cmd("cherry-pick", passing_nonlinear) failing_nonlinear = git_cmd("rev-parse", "HEAD") # F1' # so now we have: - # M1 --- M2 --- M3 - # | | - # F1 F1' + # M0--M1 --- M2 --- M3 + # | | + # F1 F1' # where F1 = passing # and F1' = failing @@ -118,7 +103,7 @@ def git_cmd(command, *args): "paths": { "repo": repo_path, "output": output_path, - "scripts": mock_scripts_path, + "scripts": source_scripts_dir, }, "commits": { "good_main": m2, # last good commit From d6b4253862744fa059851850d6e95bac5dd57117 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:52:16 +0100 Subject: [PATCH 49/50] fix error --- .github/triage/tests/test_triage_history_bisection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py index 9ab5a125d..59e287e43 100644 --- a/.github/triage/tests/test_triage_history_bisection.py +++ b/.github/triage/tests/test_triage_history_bisection.py @@ -58,7 +58,7 @@ def git_cmd(command, *args): git_cmd("config", "user.name", "Test User") git_cmd("config", "user.email", "test@example.com") # Create a linear commit history - git_cmd("commit" "--allow-empty", "-m", "M0_base") + git_cmd("commit", "--allow-empty", "-m", "M0_base") git_cmd("commit", "--allow-empty", "-m", "M1") m1 = git_cmd("rev-parse", "HEAD") git_cmd("commit", "--allow-empty", "-m", "M2") # good commit From 1e13d19391257a340e1c67f4106d19096063c4a3 Mon Sep 17 00:00:00 2001 From: Steboss Date: Tue, 15 Jul 2025 15:53:43 +0100 Subject: [PATCH 50/50] fix @olupton comments --- .github/triage/jax_toolbox_triage/bisect.py | 4 +--- .github/triage/jax_toolbox_triage/summary.py | 2 +- .github/triage/jax_toolbox_triage/triage_tool.py | 10 ++++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py index 6d28d07fe..bacf1dba1 100644 --- a/.github/triage/jax_toolbox_triage/bisect.py +++ b/.github/triage/jax_toolbox_triage/bisect.py @@ -76,9 +76,7 @@ def get_commit_history( start = passing_main_commit end = failing_main_commit - logger.info( - f"cherry_pick_range: {cherry_pick_range}, start: {start}, end: {end}" - ) + logger.info(f"cherry_pick_range: {cherry_pick_range}, start: {start}, end: {end}") # now create the right git command to retrieve the history between start..end result = worker.check_exec( diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py index 3db97673c..8545d1447 100644 --- a/.github/triage/jax_toolbox_triage/summary.py +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -43,7 +43,7 @@ def add_summary_record( with open(summary_filename, "w") as ofile: json.dump(data, ofile) - return data + return data def create_output_symlinks( diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py index 6088bcf45..2917ec569 100644 --- a/.github/triage/jax_toolbox_triage/triage_tool.py +++ b/.github/triage/jax_toolbox_triage/triage_tool.py @@ -56,7 +56,9 @@ def _test_output_directory( ) out_dir = self.args.output_prefix / out_dirname - assert not out_dir.exists(), f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + assert not out_dir.exists(), ( + f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" + ) out_dir.mkdir(mode=0o755) return out_dir.resolve() @@ -100,9 +102,9 @@ def _get_versions( """ if explicit_versions is not None and container_url is None: return explicit_versions, None, None, None - assert ( - container_url is not None - ), "Container URL must be provided if explicit versions are not set." + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) with self._make_container(container_url) as worker: url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env)