diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py
index 3080e66be..17ee95872 100644
--- a/.github/triage/jax_toolbox_triage/args.py
+++ b/.github/triage/jax_toolbox_triage/args.py
@@ -26,6 +26,23 @@ def parse_version_argument(s: str) -> typing.Dict[str, str]:
     return ret
 
 
+def parse_override_remotes(s: str) -> typing.Dict[str, str]:
+    """Parse the --override-remotes argument.
+
+    Inputs:
+        s: (str) comma-separated software:url pairs, e.g. jax:https://@host/repo.git
+
+    Returns:
+        ret: (typing.Dict[str, str]) Dictionary mapping software name to git URL.
+    """
+    ret: typing.Dict[str, str] = {}
+    for part in s.split(","):
+        sw, url = part.split(":", 1)
+        assert sw not in ret, ret
+        ret[sw] = url
+    return ret
+
+
 def parse_args(args=None) -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="""
@@ -208,6 +225,14 @@ def parse_args(args=None) -> argparse.Namespace:
             in question has different versions at the endpoints of the bisection range.
         """,
     )
+    version_search_args.add_argument(
+        "--override-remotes",
+        type=parse_override_remotes,
+        default={},
+        help="""Remote URLs to be used for fetching, including auth token. E.g.:
+            jax:https://@host/repo.git,xla:https://@host/repo.git
+        """,
+    )
     parser.add_argument(
         "-v",
         "--container-mount",
@@ -225,10 +250,18 @@
         help="Container runtime used, can be docker, pyxis, or local.",
         type=lambda s: s.lower(),
     )
-    args = parser.parse_args(args=args)
-    assert args.container_runtime in {"docker", "pyxis", "local"}, (
-        args.container_runtime
+    parser.add_argument(
+        "--main-branch",
+        type=str,
+        default="main",
+        help="The name of the main branch (e.g. main) to derive cherry-picks from",
     )
+    args = parser.parse_args(args=args)
+    assert args.container_runtime in {
+        "docker",
+        "pyxis",
+        "local",
+    }, args.container_runtime
     # --{passing,failing}-commits are deprecated aliases for --{passing,failing}-versions.
     for prefix in ["passing", "failing"]:
         commits = getattr(args, f"{prefix}_commits")
diff --git a/.github/triage/jax_toolbox_triage/bisect.py b/.github/triage/jax_toolbox_triage/bisect.py
new file mode 100644
index 000000000..bacf1dba1
--- /dev/null
+++ b/.github/triage/jax_toolbox_triage/bisect.py
@@ -0,0 +1,105 @@
+import datetime
+import subprocess
+
+
+def get_commit_history(
+    worker,
+    package,
+    start,
+    end,
+    dir,
+    main_branch,
+    logger=None,
+    args=None,
+):
+    """
+    Get the commit history for a given package between two commits.
+
+    Args:
+        worker (Container): The container worker to execute commands.
+        package (str): The name of the package.
+        start (str): The starting commit hash.
+        end (str): The ending commit hash.
+        dir (str): The directory where the git repository is located.
+        main_branch (str): The name of the main branch (e.g. main) used to
+            linearize a non-linear history.
+        logger (Logger, optional): Logger for debug information. Defaults to None.
+        args: Parsed command-line arguments; consulted for --override-remotes.
+
+    Returns:
+        data: list of (commit_hash, commit_date) tuples, oldest first
+        cherry_pick_range: dict mapping package to a cherry-pick range; populated
+            only when the history is non-linear
+    """
+    # In particular the end commit might not already be known if the older,
+    # passing, container is being used for triage. 
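+    # Probe for both endpoint commits locally first; only if that probe fails,
+    # fetch them explicitly, using any --override-remotes URL configured for
+    # this package (falling back to "origin").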
+    commits_known = worker.exec(
+        [
+            "sh",
+            "-c",
+            f"git cat-file commit {start} && git cat-file commit {end}",
+        ],
+        policy="once_per_container",
+        workdir=dir,
+    )
+    if commits_known.returncode != 0:
+        worker.check_exec(
+            ["git", "fetch", args.override_remotes.get(package, "origin"), start, end],
+            workdir=dir,
+        )
+
+    # detect non-linear history
+    is_ancestor_result = worker.exec(
+        ["git", "merge-base", "--is-ancestor", start, end],
+        workdir=dir,
+    )
+    is_linear = is_ancestor_result.returncode == 0
+    cherry_pick_range = {}
+
+    if not is_linear:
+        logger.info(f"Using non-linear history logic with branch {main_branch}")
+
+        # 1. find the linear range on the main branch
+        passing_and_failing_cmd = worker.check_exec(
+            [
+                "sh",
+                "-c",
+                f"git merge-base {start} {end} && git merge-base {end} {main_branch}",
+            ],
+            workdir=dir,
+        ).stdout.strip()
+        passing_main_commit, failing_main_commit = passing_and_failing_cmd.splitlines()
+
+        # 2. find commits to cherry-pick from the failing branch
+        # TODO: as an alternative approach we may need to consider `{passing_main_commit}..{start}`
+        cherry_pick_range[package] = f"{failing_main_commit}..{end}"
+
+        # 3. now we can use the main branch commits for bisection
+        start = passing_main_commit
+        end = failing_main_commit
+
+    logger.info(f"cherry_pick_range: {cherry_pick_range}, start: {start}, end: {end}")
+
+    # now create the right git command to retrieve the history between start..end
+    result = worker.check_exec(
+        [
+            "git",
+            "log",
+            "--first-parent",
+            "--reverse",
+            "--format=%H %cI",
+            f"{start}^..{end}",
+        ],
+        policy="once",
+        stderr=subprocess.PIPE,
+        workdir=dir,
+    )
+
+    data = []
+    for line in result.stdout.splitlines():
+        commit, date = line.split()
+        # for Python < 3.11 we need to normalize a trailing "Z" (UTC) suffix,
+        # which fromisoformat cannot parse:
+        if date.endswith("Z"):
+            date = date[:-1] + "+00:00"
+        date = datetime.datetime.fromisoformat(date).astimezone(datetime.timezone.utc)
+        data.append((commit, date))
+
+    return data, cherry_pick_range
diff --git a/.github/triage/jax_toolbox_triage/container_factory.py b/.github/triage/jax_toolbox_triage/container_factory.py
new file mode 100644
index 000000000..0c254ba0d
--- /dev/null
+++ b/.github/triage/jax_toolbox_triage/container_factory.py
@@ -0,0 +1,28 @@
+import logging
+from .container import Container
+from .docker import DockerContainer
+from .pyxis import PyxisContainer
+from .local import LocalContainer
+
+
+def make_container(
+    runtime: str, url: str, mounts: list, logger: logging.Logger, **kwargs
+) -> Container:
+    """
+    Create a container object based on the specified runtime.
+
+    Args:
+        runtime (str): The container runtime to use (e.g., 'docker', 'pyxis', 'local').
+        url (str): The URL of the container.
+        mounts (list): List of mounts to be used in the container.
+        logger (logging.Logger): Logger instance for logging messages.
+        **kwargs: Additional keyword arguments for specific container types.
+
+    Returns:
+        Container: A container instance for the specified runtime. 
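+
+    Example (illustrative; assumes an image URL and a configured logging.Logger):
+
+        worker = make_container("docker", url, mounts=[], logger=logger)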
+ """ + if runtime == "local": + return LocalContainer(logger=logger) + + container_impl = DockerContainer if runtime == "docker" else PyxisContainer + return container_impl(url, logger=logger, mounts=mounts, **kwargs) diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 0bad0a00e..fff2e99d0 100644 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -1,591 +1,25 @@ -import collections -import datetime -import functools -import hashlib -import itertools -import json -import logging -import pathlib -import subprocess -import time -import typing +from .args import parse_args +from .utils import get_logger +from .triage_tool import TriageTool -from .args import compulsory_software, optional_software, parse_args -from .container import Container -from .docker import DockerContainer -from .local import LocalContainer -from .logic import container_search, TestResult, version_search -from .pyxis import PyxisContainer -from .utils import ( - container_url as container_url_base, - get_logger, - prepare_bazel_cache_mounts, -) - -def get_env(worker: Container) -> typing.Dict[str, str]: - """ - Get the runtime environment in the given container. - - Returns: {env_var: value} dictionary, sorted by key. - """ - - def impl() -> typing.Dict[str, str]: - kvs = ( - worker.check_exec(["env", "-0"], policy="once", stderr="separate") - .stdout[:-1] # skip the trailing \0 - .split("\0") - ) - return dict(kv.split("=", 1) for kv in kvs) - - # Remove any environment variables that differ between consecutive `env` calls, for - # example some step-specific Slurm variables. - env1, env2 = impl(), impl() - # sorted(...) for run-to-run determinism - return {k: env1[k] for k in sorted(env1.keys() & env2.keys()) if env1[k] == env2[k]} - - -def get_commits_and_dirs( - worker: Container, -) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]: - """ - Get the git repository paths and current HEAD commits in the given environment of - the software packages named in `compulsory_software` and `optional_software`. - - Returns: ({package: commit}, {package: directory}) - """ - # Formulated this way to avoid paying too many times for container startup. - cmds = [] - for package in compulsory_software + optional_software: - bits = [ - f"(cd /opt/{package} && git rev-parse HEAD && echo {package} && pwd)", - f"(cd /opt/{package}-source && git rev-parse HEAD && echo {package} && pwd)", - ] - if package in optional_software: - bits.append("true") - cmds.append(f"({' || '.join(bits)})") - result = worker.check_exec( - ["sh", "-c", " && ".join(cmds)], policy="once", stderr="separate" - ) - versions, dirs = {}, {} - # Look over triplets of output lines - for commit, package, dirname in zip(*([iter(result.stdout.splitlines())] * 3)): - dirs[package] = dirname - versions[package] = commit - return versions, dirs - - -def get_versions_dirs_env( - worker: Container, - versions_from_env: bool, -) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str], typing.Dict[str, str]]: +def main() -> None: """ - Get software versions in the given [container] environment, git repository paths - where relevant, and the runtime environment. - - The list of software versions is drawn from git repositories at known container - locations and, if `versions_from_env` is True, from the environment. 
- - Returns: - versions: {package: version or commit}, - dirs: {package: git_repository_dir} - env: {env_var: value} + Main entry point for the triage tool. """ - # Get the git repository paths and commits from the container. - versions, dirs = get_commits_and_dirs(worker) - - # Get the environment variables from the container. - env = get_env(worker) - - if versions_from_env: - # Promote any XXX_VERSION environment variables into `versions` if `XXX` is - # not already there. - for k, v in env.items(): - if not len(v) or not k.endswith("_VERSION"): - continue - package = k[:-8] - assert package not in versions, (versions, package) - versions[package] = v - return versions, dirs, env - - -def main() -> None: args = parse_args() - bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache) logger = get_logger(args.output_prefix) - logger.info("Arguments:") - for k, v in vars(args).items(): - logger.info(f" {k}: {v}") - logger.info( - "Verbose output, including stdout/err of triage commands, will be written to " - f"{(args.output_prefix / 'debug.log').resolve()}" - ) - - def test_output_directory( - url: str, versions: typing.Dict[str, str] = {} - ) -> pathlib.Path: - # Construct an output directory name, on the host, for output files written by - # the test case. - hash_chars = 8 - urlhash = f"container-{hashlib.sha1(url.encode()).hexdigest()[:hash_chars]}" - out_dirname = "-".join( - itertools.chain( - [urlhash], - map(lambda t: f"{t[0]}-{t[1][:hash_chars]}", sorted(versions.items())), - ) - ) - out_dir = args.output_prefix / out_dirname - assert not out_dir.exists(), ( - f"{out_dir} should not already exist, maybe you are re-using {args.output_prefix}?" - ) - out_dir.mkdir(mode=0o755) - return out_dir.resolve() - - container_url = functools.partial( - container_url_base, - container=args.container, - template=args.container_url_template, - ) - def Container( - url, test_output_host_directory: typing.Optional[pathlib.Path] = None - ): - if args.container_runtime == "local": - return LocalContainer(logger=logger) + try: + tool = TriageTool(args, logger) - Imp = DockerContainer if args.container_runtime == "docker" else PyxisContainer - mounts = bazel_cache_mounts + args.container_mount - if test_output_host_directory is not None: - # This can be used to save useful output from the test case (e.g. HLOs) - mounts.append((test_output_host_directory, "/triage-tool-output")) - return Imp(url, logger=logger, mounts=mounts) + passing_url, failing_url = tool.find_container_range() - def get_versions( - container_url: typing.Optional[str], - explicit_versions: typing.Optional[typing.Dict[str, str]], - versions_from_env: bool, - ) -> typing.Tuple[ - typing.Dict[str, str], - typing.Optional[typing.Dict[str, str]], - typing.Optional[typing.Dict[str, str]], - typing.Optional[typing.Dict[str, str]], - ]: - """ - Given an optional container URL (e.g. --failing-container) and an optional set - of overriden versions (e.g. --failing-versions), obtain a list of software - versions to bookend the triage range. - - Also returns the container's runtime environment variables for diagnostic purposes. - - If *only* overrides are given, those are returned verbatim and all other return - values are None. - - Otherwise, the software versions, git repository directories and runtime - environment are extracted from the given container. If overrides are *also* given, - these take precedence over the extracted values. - - It is an error for both `container_url` and `explicit_versions` to be None. 
- - Returns: - versions: {packge: version} mapping defining one end of the triage range - url_versions: {package: version} corresponding to the given container (or None) - dirs: {package: git_repository_dir} (or None) - env: {env_var: value} (or None) - """ - if explicit_versions is not None and container_url is None: - return explicit_versions, None, None, None - assert container_url is not None - logger.info(f"Extracting versions from {container_url} ...") - with Container(container_url) as worker: - url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env) - overriden_versions = url_versions.copy() - if explicit_versions is not None: - overriden_versions.update(explicit_versions) - return overriden_versions, url_versions, dirs, env - - def add_summary_record( - section: str, - record: typing.Mapping[str, typing.Union[bool, float, str]], - scalar: bool = False, - ): - """ - Add a record to the output JSON file. This is intended to provide a useful record - even in case of a fatal error. - """ - summary_filename = args.output_prefix / "summary.json" - try: - with open(summary_filename, "r") as ifile: - data = json.load(ifile) - except FileNotFoundError: - data = {} - if scalar: - if section in data: - logging.warning(f"Overwriting summary data in section {section}") - data[section] = record - else: - if section not in data: - data[section] = [] - data[section].append(record) - with open(summary_filename, "w") as ofile: - json.dump(data, ofile) - - versions_from_env = args.build_scripts_path is not None - - def check_container( - date: datetime.date, *, test_output_log_level: int = logging.DEBUG - ) -> TestResult: - """ - See if the test passes in the given dated container. - """ - before = time.monotonic() - out_dir = test_output_directory(container_url(date)) - with Container( - container_url(date), test_output_host_directory=out_dir - ) as worker: - versions, _, _ = get_versions_dirs_env(worker, versions_from_env) - # This will stream interleaved stdout/stderr into the logger - result = worker.exec(args.test_command, log_level=test_output_log_level) - test_time = time.monotonic() - before - test_pass = result.returncode == 0 - logger.info( - f"Ran test case in {worker} in {test_time:.1f}s, pass={test_pass}" - ) - add_summary_record( - "container", - { - "container": container_url(date), - "output_directory": out_dir.as_posix(), - "result": test_pass, - "test_time": test_time, - } - | versions, - ) - return TestResult( - host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout - ) - - if args.container_runtime == "local": - passing_url = "local" - failing_url = "local" - elif args.passing_container is None and args.failing_container is None: - # Search through the published containers, narrowing down to a pair of dates with - # the property that the test passed on `range_start` and fails on `range_end`. 
- range_start, range_end = container_search( - container_exists=lambda date: Container(container_url(date)).exists(), - container_passes=check_container, - start_date=args.start_date, - end_date=args.end_date, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - threshold_days=args.threshold_days, - ) - passing_url = container_url(range_start) - failing_url = container_url(range_end) - else: - # Skip the container-level search because at lease one explicit end point was - # given - passing_url = args.passing_container - failing_url = args.failing_container - - # Get the versions from the endpoint containers (if they exist), overridden by any - # explicitly passed versions. - passing_versions, original_passing_versions, passing_package_dirs, passing_env = ( - get_versions(passing_url, args.passing_versions, versions_from_env) - ) - failing_versions, original_failing_versions, failing_package_dirs, failing_env = ( - get_versions(failing_url, args.failing_versions, versions_from_env) - ) - - # If we have two containers, print the differences between their environments. This - # can be useful in the case that rebuilding the good versions in the bad container, - # or vice versa, does not reproduce the expected result. - if passing_env is not None and failing_env is not None: - logger.info(f"Environment differences between {passing_url} and {failing_url}") - for key in passing_env.keys() - failing_env.keys(): - logger.info(f"Only in {passing_url}: {key}={passing_env[key]}") - for key in failing_env.keys() - passing_env.keys(): - logger.info(f"Only in {failing_url}: {key}={failing_env[key]}") - for key in passing_env.keys() & failing_env.keys(): - if passing_env[key] == failing_env[key]: - continue - logger.info( - f"{key}: {passing_env[key]} ({passing_url}) vs. {failing_env[key]} " - f"({failing_url})" - ) - - # We should have versions for all the same software packages at both - # ends of the range, one way or another. TODO: this could be relaxed. - assert passing_versions.keys() == failing_versions.keys(), ( - passing_versions, - failing_versions, - ) - - # Which packages have versions that are not always the same? - dynamic_packages = { - pkg for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items()) - } - - # Choose an environment to do the version-level bisection in; use directory names that - # match it, and track what the initial versions of the different packages are - if args.container_runtime == "local": - bisection_url = "local" - bisection_versions = original_failing_versions - package_dirs = failing_package_dirs - elif failing_url is not None: - bisection_url = failing_url - bisection_versions = original_failing_versions - package_dirs = failing_package_dirs - else: - assert passing_url is not None - bisection_url = passing_url - bisection_versions = original_passing_versions - package_dirs = passing_package_dirs - assert package_dirs is not None - # This is the set of versions that are already installed - assert bisection_versions is not None - - # Get the full lists of JAX/XLA commits and dates - def get_commit_history(worker, start, end, dir): - # In particular the end commit might not already be known if the older, - # passing, container is being used for triage. 
- commits_known = worker.exec( - [ - "sh", - "-c", - f"git cat-file commit {start} && git cat-file commit {end}", - ], - policy="once_per_container", - workdir=dir, - ) - if commits_known.returncode != 0: - worker.check_exec( - ["git", "fetch"], policy="once_per_container", workdir=dir - ) - result = worker.check_exec( - [ - "git", - "log", - "--first-parent", - "--reverse", - "--format=%H %cI", - f"{start}^..{end}", - ], - policy="once", - stderr=subprocess.PIPE, - workdir=dir, + passing_versions, failing_versions = tool.gather_version_info( + passing_url, failing_url ) - logger.debug(f"stderr: {result.stderr.strip()}") - data = [] - for line in result.stdout.splitlines(): - commit, date = line.split() - date = datetime.datetime.fromisoformat(date).astimezone( - datetime.timezone.utc - ) - data.append((commit, date)) - return data - - # Fire up the container that will be used for the version-level search and use it to - # extract the relevant history of the repositories that will be triaged. - with Container(bisection_url) as worker: - packages = passing_versions.keys() - log_str = "Bisecting" - for package in packages: - log_str += ( - f" {package} [{passing_versions[package]}, {failing_versions[package]}]" - ) - log_str += f" using {worker}" - logger.info(log_str) - # Get lists of (commit_hash, commit_date) pairs - package_versions = collections.OrderedDict() - for package in packages: - if package not in package_dirs: - # This is a version that came from the container environment, not a git - # checkout directory in the container. Handle those below. - continue - package_versions[package] = get_commit_history( - worker, - passing_versions[package], - failing_versions[package], - package_dirs[package], - ) - # Confirm they're sorted by commit date - assert all( - b[1] >= a[1] - for a, b in zip( - package_versions[package], package_versions[package][1:] - ) - ) - # Confirm the end values are included as expected - assert passing_versions[package] == package_versions[package][0][0] - assert failing_versions[package] == package_versions[package][-1][0] - # For the packages that just have one or two version numbers, associate those - # version numbers with the earliest and, if appropriate, latest XLA dates. This - # is only relevant to packages that do not have git repositories checked out and - # are managed via installPACKAGE.sh scripts and PACKAGE_VERSION environment vars - for package in packages: - if package in package_versions: - continue - package_versions[package] = [ - (passing_versions[package], package_versions["xla"][0][1]), - ] - if passing_versions[package] != failing_versions[package]: - package_versions[package].append( - (failing_versions[package], package_versions["xla"][-1][1]) - ) - # Check up-front whether the installation scripts exist for the packages that - # are being triaged by version + script rather than from a git repo + build. 
- if args.build_scripts_path is not None: - known_scripts = worker.check_exec( - [ - "find", - args.build_scripts_path, - "-maxdepth", - "1", - "-executable", - "-print0", - ], - policy="once", - stderr="separate", - ).stdout.split("\0") - logger.debug(f"Found {known_scripts} inside {worker}") - packages_with_scripts = { - script[len(args.build_scripts_path) + 8 : -3] - for script in known_scripts - if script.startswith(args.build_scripts_path + "/install") - and script.endswith(".sh") - } - logger.debug(f"Found installation scripts for {packages_with_scripts}") - packages_needing_scripts = dynamic_packages - package_dirs.keys() - packages_missing_scripts = packages_needing_scripts - packages_with_scripts - if packages_missing_scripts: - logger.warning( - "No installation scripts found for: " - f"{' '.join(packages_missing_scripts)}, whose version(s) change " - "across the bisection range. These will be excluded from the " - "bisection, which may cause it not to converge!" - ) - dynamic_packages -= packages_missing_scripts - - def build_and_test( - *, versions: typing.Dict[str, str], test_output_log_level: int = logging.DEBUG - ) -> TestResult: - """ - The main body of the bisection loop. Update JAX/XLA/... versions, rebuild, and - run the test command. Throws on error when checking out or building, and returns - the status of the test command. - """ - # Amortise container startup overhead by batching together git commands - git_commands, changed, skipped = [], [], [] - for package in sorted(dynamic_packages): - version = versions[package] - if bisection_versions[package] == version: - # If the existing version is the desired one, do nothing. - skipped.append(f"{package}@{version}") - continue - # Cache which version is now going to be checked out in the container - bisection_versions[package] = version - changed.append(f"{package}@{version}") - if package in package_dirs: - # A git repository that exists in the container. - git_commands += [ - f"cd {package_dirs[package]}", - "git stash", - f"git checkout {version}", - ] - else: - # Another software package, `version` is probably a version number. - # Installation of this version is delegated to an installPACKAGE.sh - # script that is assumed to be available in `args.build_scripts_path`. - assert args.build_scripts_path is not None - assert package in packages_with_scripts, ( - package, - packages_with_scripts, - ) - extra_env = { - # Need the static part for the .bc library - "NVSHMEM": "DEVEL=1 STATIC=1", - }.get(package, "DEVEL=1") # Always need development headers to rebuild - git_commands += [ - f"{extra_env} {args.build_scripts_path}/install{package}.sh {version}" - ] - # Keep the pathnames shorter by only including packages that actually have - # multiple versions in the bisection range. - brief_versions = { - p: ver for p, ver in versions.items() if p in dynamic_packages - } - out_dir = test_output_directory(bisection_url, versions=brief_versions) - with Container(bisection_url, test_output_host_directory=out_dir) as worker: - change_str = " ".join(changed) if len(changed) else "" - info_str = f"Checking out {change_str} in {worker}" - if len(skipped): - info_str += f", leaving {' '.join(skipped)} unchanged" - logger.info(info_str) - worker.check_exec( - ["sh", "-c", " && ".join(git_commands)], - policy="once_per_container", - ) - # Build JAX - # TODO: teach the tool how to build TransformerEngine too - # TODO: do not build JAX/XLA/TransformerEngine if we know their versions did not change? 
- before = time.monotonic() - # Unfortunately the build system does not always seem to handle incremental - # rebuilds correctly, so clean the local cache and rely on the remote one. - build_cmds = [ - "bazel clean --expunge", - f"build-jax.sh --bazel-cache={args.bazel_cache}", - ] - worker.check_exec( - ["sh", "-c", " && ".join(build_cmds)], - policy="once_per_container", - workdir=package_dirs["jax"], - ) - middle = time.monotonic() - logger.info(f"Build completed in {middle - before:.1f}s") - # Run the test - test_result = worker.exec( - args.test_command, log_level=test_output_log_level - ) - test_time = time.monotonic() - middle - add_summary_record( - "versions", - { - "build_time": middle - before, - "container": bisection_url, - "output_directory": out_dir.as_posix(), - "result": test_result.returncode == 0, - "test_time": test_time, - } - | versions, - ) - result_str = "pass" if test_result.returncode == 0 else "fail" - logger.info(f"Test completed in {test_time:.1f}s ({result_str})") - return TestResult( - host_output_directory=out_dir, - result=test_result.returncode == 0, - stdouterr=test_result.stdout, - ) - - # Run the version-level bisection - result, last_known_good, first_known_bad = version_search( - versions=package_versions, - build_and_test=build_and_test, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - ) - - def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None: - if result is None: - return - symlink = (args.output_prefix / symlink_name).resolve() - assert not symlink.exists(), symlink - assert symlink.parent == result.host_output_directory.parent, ( - symlink, - result.host_output_directory, - ) - symlink.symlink_to(result.host_output_directory.name) + tool.run_version_bisection(passing_versions, failing_versions) - symlink(last_known_good, "last-known-good") - symlink(first_known_bad, "first-known-bad") - result["container"] = failing_url - add_summary_record("result", result, scalar=True) + except Exception as e: + logger.fatal(f"Triage process failed: {e}") diff --git a/.github/triage/jax_toolbox_triage/summary.py b/.github/triage/jax_toolbox_triage/summary.py new file mode 100644 index 000000000..8545d1447 --- /dev/null +++ b/.github/triage/jax_toolbox_triage/summary.py @@ -0,0 +1,79 @@ +import json +import logging +import pathlib +import typing +from .logic import TestResult + + +def add_summary_record( + output_prefix: pathlib.Path, + section: str, + record: typing.Union[typing.Dict[str, typing.Any], TestResult], + scalar=False, +): + """ + Add a record to the output JSON file. This is intended to provide a useful record + even in case of a fatal error. + + Args: + output_prefix (pathlib.Path): The prefix for the output directory. + section (str): The section of the summary to which the record belongs. + record (dict or TestResult): The record to be added, either as a dictionary or + as a TestResult object. + scalar (bool): If True, the record is a scalar value; if False, it is a list of + records. Defaults to False. 
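+
+    Illustrative shape of the resulting summary.json (section names as used by
+    the triage tool): {"container": [...], "versions": [...], "result": {...}}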
+
+    Returns:
+        dict: The accumulated summary data that was written to disk.
+    """
+    summary_filename = output_prefix / "summary.json"
+    try:
+        with open(summary_filename, "r") as ifile:
+            data = json.load(ifile)
+    except FileNotFoundError:
+        data = {}
+    if scalar:
+        if section in data:
+            logging.warning(f"Overwriting summary data in section {section}")
+        data[section] = record
+    else:
+        if section not in data:
+            data[section] = []
+        data[section].append(record)
+
+    with open(summary_filename, "w") as ofile:
+        json.dump(data, ofile)
+    return data
+
+
+def create_output_symlinks(
+    output_prefix: pathlib.Path,
+    last_known_good: typing.Optional[TestResult],
+    first_known_bad: typing.Optional[TestResult],
+):
+    """
+    Create symlinks to the last-known-good and first-known-bad output directories.
+
+    Args:
+        output_prefix (pathlib.Path): The prefix for the output directory.
+        last_known_good (TestResult): The last known good test result.
+        first_known_bad (TestResult): The first known bad test result.
+
+    Returns:
+        None
+    """
+
+    def symlink(result: typing.Optional[TestResult], symlink_name: str) -> None:
+        if result is None:
+            return
+        symlink_path = (output_prefix / symlink_name).resolve()
+        assert not symlink_path.exists(), symlink_path
+        assert symlink_path.parent == result.host_output_directory.parent, (
+            symlink_path,
+            result.host_output_directory,
+        )
+        symlink_path.symlink_to(result.host_output_directory)
+
+    symlink(last_known_good, "last-known-good")
+    symlink(first_known_bad, "first-known-bad")
diff --git a/.github/triage/jax_toolbox_triage/triage_tool.py b/.github/triage/jax_toolbox_triage/triage_tool.py
new file mode 100644
index 000000000..2917ec569
--- /dev/null
+++ b/.github/triage/jax_toolbox_triage/triage_tool.py
@@ -0,0 +1,575 @@
+import collections
+import datetime
+import functools
+import hashlib
+import logging
+import pathlib
+import time
+from typing import Dict, Tuple, Union, Any, Optional
+
+from .container import Container
+from .logic import container_search, TestResult, version_search
+from .versions import get_versions_dirs_env
+from .summary import add_summary_record, create_output_symlinks
+from .bisect import get_commit_history
+from .utils import (
+    container_url as container_url_base,
+    prepare_bazel_cache_mounts,
+)
+from .container_factory import make_container
+
+
+class TriageTool:
+    """
+    The main class that orchestrates the whole triage process.
+    """
+
+    def __init__(self, args, logger):
+        self.args = args
+        self.logger = logger
+        self.bisection_url = None
+        self.bisection_versions = {}
+        self.package_dirs = None
+        self.dynamic_packages = set()
+        self.packages_with_scripts = set()
+        self.bazel_cache_mounts = prepare_bazel_cache_mounts(self.args.bazel_cache)
+        # the cherry-pick ranges get populated only for non-linear cases
+        self.cherry_pick_commits = {}
+
+    def _test_output_directory(
+        self, url: str, versions: Union[Dict[str, str], None]
+    ) -> pathlib.Path:
+        """
+        Create a directory for test output based on the container URL and versions.
+
+        Args:
+            url (str): The URL of the container.
+            versions (dict): A dictionary of software versions.
+        Returns:
+            pathlib.Path: The path to the output directory.
+        """
+        hash_chars = 8
+        urlhash = f"container-{hashlib.sha1(url.encode()).hexdigest()[:hash_chars]}"
+        out_dirname = "-".join(
+            [urlhash]
+            + [f"{k}-{v[:hash_chars]}" for k, v in sorted((versions or {}).items())]
+        )
+
+        out_dir = self.args.output_prefix / out_dirname
+        assert not out_dir.exists(), (
+            f"{out_dir} should not already exist, maybe you are re-using {self.args.output_prefix}?" 
+ ) + out_dir.mkdir(mode=0o755) + return out_dir.resolve() + + def _make_container( + self, url: str, test_output_directory: Optional[pathlib.Path] = None + ) -> Container: + """ + Wrapper for make_container factory + + Args: + url: (str), the input url of the docker image + test_output_directory: (pathlib.Path), the path to the output directory + + Returns: + Container object + """ + mounts = self.bazel_cache_mounts + self.args.container_mount + if test_output_directory is not None: + mounts.append((test_output_directory, "/triage-tool-output")) + + return make_container(self.args.container_runtime, url, mounts, self.logger) + + def _get_versions( + self, + container_url: str, + explicit_versions: Dict[str, str], + versions_from_env: bool, + ): + """ + Get the versions of the software packages in the container. + + Args: + container_url (str): The URL of the container. + explicit_versions (str): Explicit versions to use. + versions_from_env (bool): Whether to get versions from environment variables. + Returns: + overriden_versions (dict): The versions with explicit overrides. + url_versions (dict): The versions from the container URL. + dirs (dict): The directories of the software packages. + env (dict): The environment variables in the container. + """ + if explicit_versions is not None and container_url is None: + return explicit_versions, None, None, None + assert container_url is not None, ( + "Container URL must be provided if explicit versions are not set." + ) + + with self._make_container(container_url) as worker: + url_versions, dirs, env = get_versions_dirs_env(worker, versions_from_env) + overriden_versions = url_versions.copy() + if explicit_versions is not None: + overriden_versions.update(explicit_versions) + + return overriden_versions, url_versions, dirs, env + + def _gather_histories( + self, + worker: Container, + passing_versions: Dict[str, str], + failing_versions: Dict[str, str], + ) -> collections.OrderedDict: + """ + Gather the commit histories for the passing and failing versions. + + Args: + worker (Container): The container in which to run the commands. + passing_versions (dict): The versions that passed. + failing_versions (dict): The versions that failed. + Returns: + collections.OrderDict: The commit histories for passing and failing versions. 
+ """ + packages = passing_versions.keys() + package_versions = collections.OrderedDict() + + for package in packages: + if package not in self.package_dirs: + continue + history, cherry_pick_range = get_commit_history( + worker, + package, + passing_versions[package], + failing_versions[package], + self.package_dirs[package], + main_branch=self.args.main_branch, + logger=self.logger, + args=self.args, + ) + package_versions[package] = history + if cherry_pick_range: + self.cherry_pick_commits[package] = cherry_pick_range[package] + + assert all( + b[1] >= a[1] + for a, b in zip( + package_versions[package], package_versions[package][1:] + ) + ) + + # this check works only for linera-case + if package not in self.cherry_pick_commits: + assert passing_versions[package] == package_versions[package][0][0] + assert failing_versions[package] == package_versions[package][-1][0] + + for package in packages: + if package in package_versions: + continue + package_versions[package] = [ + (passing_versions[package], package_versions["xla"][0][1]) + ] + if passing_versions[package] != failing_versions[package]: + package_versions[package].append( + (failing_versions[package], package_versions["xla"][-1][1]) + ) + + return package_versions + + def _log_environment_differences( + self, url1: str, url2: str, env1: Dict[str, str], env2: Dict[str, str] + ): + """ + If we have two containers, print the differences between their environments. This + can be useful in the case that rebuilding the good versions in the bad container, + or vice versa, does not reproduce the expected result. + + Args: + url1 (str): The URL of the first container. + url2 (str): The URL of the second container. + env1 (dict): The environment variables of the first container. + env2 (dict): The environment variables of the second container. + + Returns: + None + """ + if env1 is None or env2 is None: + return + self.logger.info(f"Environment differences between {url1} and {url2}:") + for key in env1.keys() - env2.keys(): + self.logger.info(f"\tOnly in {url1}: {key}={env1[key]}") + for key in env2.keys() - env1.keys(): + self.logger.info(f"\tOnly in {url2}: {key}={env2[key]}") + for key in env1.keys() & env2.keys(): + if env1[key] != env2[key]: + self.logger.info( + f"{key}: {env1[key]} ({url1}) vs. {env2[key]} ({url2})" + ) + + def _check_container_by_date( + self, date: datetime.date, *, test_output_log_level: int = logging.DEBUG + ) -> TestResult: + """ + See if the test passes in the given dated container. + + Args: + date (datetime.date): The date of the container to check. + test_output_log_level (int): The log level for test output. + Returns: + TestResult: The result of the test, including whether it passed and the output. 
+ """ + container_url = container_url_base( + date, + container=self.args.container, + template=self.args.container_url_template, + ) + + before = time.monotonic() + out_dir = self._test_output_directory(container_url, None) + + with self._make_container( + container_url, test_output_directory=out_dir + ) as worker: + versions, _, _ = get_versions_dirs_env( + worker, self.args.build_scripts_path is not None + ) + result = worker.exec( + self.args.test_command, log_level=test_output_log_level + ) + test_time = time.monotonic() - before + test_pass = result.returncode == 0 + self.logger.info( + f"Ran test case in {worker} in {test_time:.1f}s, pass={test_pass}" + ) + + add_summary_record( + self.args.output_prefix, + "container", + { + **{ + "container": container_url, + "output_directory": out_dir.as_posix(), + "result": test_pass, + "test_time": test_time, + }, + **versions, + }, + ) + return TestResult( + host_output_directory=out_dir, result=test_pass, stdouterr=result.stdout + ) + + def _check_installation_scripts(self, worker: Container): + """ + Check for special installation cases, like cuBLAS or cuDNN + + Args: + worker (Container): The container in which to run the commands. + """ + if self.args.build_scripts_path is None: + return + + known_scripts_result = worker.exec( + [ + "find", + self.args.build_scripts_path, + "-maxdepth", + "1", + "-executable", + "-print0", + ], + policy="once", + stderr="separate", + ) + if known_scripts_result.returncode != 0: + self.logger.warning( + f"Failed to find known installation scripts in {self.args.build_scripts_path}: {known_scripts_result.stderr}" + ) + known_scripts = [] + else: + known_scripts = known_scripts_result.stdout.split("\0") + + self.logger.debug(f"Found {known_scripts} inside {worker}") + + self.packages_with_scripts = { + script[len(self.args.build_scripts_path) + 8 : -3] + for script in known_scripts + if script.startswith(self.args.build_scripts_path + "/install") + and script.endswith(".sh") + } + self.logger.debug( + f"Found installation scripts for {self.packages_with_scripts}" + ) + packages_needing_scripts = self.dynamic_packages - self.package_dirs.keys() + packages_missing_scripts = packages_needing_scripts - self.packages_with_scripts + if packages_missing_scripts: + self.logger.warning( + "No installation scripts found for: " + f"{' '.join(packages_missing_scripts)}, whose version(s) change " + "across the bisection range. These will be excluded from the " + "bisection, which may cause it not to converge!" + ) + self.dynamic_packages -= packages_missing_scripts + + def _build_and_test( + self, + *, + versions: Dict[str, str], + test_output_log_level: int = logging.DEBUG, + ) -> TestResult: + """ + The main body of the bisection loop. Update JAX/XLA/... versions, rebuild, and + run the test command. Throws on error when checking out or building, and returns + the status of the test command. + + Args: + versions (dict): The versions of the software packages to use. + test_output_log_level (int): The log level for test output. + + Returns: + TestResult: The result of the test, including whether it passed and the output. + """ + # Amortise container startup overhead by batching together git commands + git_commands, changed, skipped = [], [], [] + for package in sorted(self.dynamic_packages): + version = versions[package] + if self.bisection_versions.get(package) == version: + # If the existing version is the desired one, do nothing. 
+ skipped.append(f"{package}@{version}") + continue + # Cache which version is now going to be checked out in the container + self.bisection_versions[package] = version + changed.append(f"{package}@{version}") + if package in self.package_dirs: + cherry_pick_range = self.cherry_pick_commits.get(package) + git_commands.append(f"cd {self.package_dirs[package]}") + git_commands.append("git stash") + # this is a checkout on the main branch + git_commands.append(f"git checkout {version}") + if cherry_pick_range: + self.logger.info("Working on a non-linear history") + git_commands.append( + f"(git cherry-pick {cherry_pick_range} || (echo 'Cherry-pick failed' && exit 1) )" + ) + + else: + # Another software package, `version` is probably a version number. + # Installation of this version is delegated to an installPACKAGE.sh + # script that is assumed to be available in `args.build_scripts_path`. + assert self.args.build_scripts_path is not None + assert package in self.packages_with_scripts, ( + package, + self.packages_with_scripts, + ) + extra_env = { + # Need the static part for the .bc library + "NVSHMEM": "DEVEL=1 STATIC=1", + }.get(package, "DEVEL=1") # Always need development headers to rebuild + git_commands += [ + f"{extra_env} {self.args.build_scripts_path}/install{package}.sh {version}" + ] + # Keep the pathnames shorter by only including packages that actually have + # multiple versions in the bisection range. + brief_versions = { + p: ver for p, ver in versions.items() if p in self.dynamic_packages + } + out_dir = self._test_output_directory( + self.bisection_url, versions=brief_versions + ) + + with self._make_container( + self.bisection_url, test_output_directory=out_dir + ) as worker: + change_str = " ".join(changed) if len(changed) else "" + info_str = f"Checking out {change_str} in {worker}" + if len(skipped): + info_str += f", leaving {' '.join(skipped)} unchanged" + self.logger.info(info_str) + worker.check_exec( + ["sh", "-c", " && ".join(git_commands)], + policy="once_per_container", + ) + # Build JAX + # TODO: teach the tool how to build TransformerEngine too + # TODO: do not build JAX/XLA/TransformerEngine if we know their versions did not change? + before = time.monotonic() + # Unfortunately the build system does not always seem to handle incremental + # rebuilds correctly, so clean the local cache and rely on the remote one. + build_cmds = [ + "bazel clean --expunge", + f"build-jax.sh --bazel-cache={self.args.bazel_cache}", + ] + worker.check_exec( + ["sh", "-c", " && ".join(build_cmds)], + policy="once_per_container", + workdir=self.package_dirs["jax"], + ) + middle = time.monotonic() + self.logger.info(f"Build completed in {middle - before:.1f}s") + # Run the test + test_result = worker.exec( + self.args.test_command, log_level=test_output_log_level + ) + test_time = time.monotonic() - middle + + add_summary_record( + self.args.output_prefix, + "versions", + { + **{ + "build_time": middle - before, + "container": self.bisection_url, + "output_directory": out_dir.as_posix(), + "result": test_result.returncode == 0, + "test_time": test_time, + }, + **versions, + }, + ) + result_str = "pass" if test_result.returncode == 0 else "fail" + self.logger.info(f"Test completed in {test_time:.1f}s ({result_str})") + return TestResult( + host_output_directory=out_dir, + result=test_result.returncode == 0, + stdouterr=test_result.stdout, + ) + + def find_container_range(self) -> Tuple[str, str]: + """ + Find the range from the passing and failing containers. 
Returns a tuple of the start and end container names.
+        """
+        self.logger.info("Finding container range...")
+        if self.args.container_runtime == "local":
+            return "local", "local"
+
+        container_url_func = functools.partial(
+            container_url_base,
+            container=self.args.container,
+            template=self.args.container_url_template,
+        )
+
+        if self.args.passing_container is None and self.args.failing_container is None:
+            range_start, range_end = container_search(
+                container_exists=lambda date: make_container(
+                    self.args.container_runtime,
+                    container_url_func(date),
+                    [],
+                    self.logger,
+                ).exists(),
+                container_passes=self._check_container_by_date,
+                start_date=self.args.start_date,
+                end_date=self.args.end_date,
+                logger=self.logger,
+                skip_precondition_checks=self.args.skip_precondition_checks,
+                threshold_days=self.args.threshold_days,
+            )
+
+            return container_url_func(range_start), container_url_func(range_end)
+        else:
+            return self.args.passing_container, self.args.failing_container
+
+    def gather_version_info(self, passing_url: str, failing_url: str):
+        """
+        Gather version information from the passing and failing containers.
+
+        Args:
+            passing_url (str): The URL of the passing container.
+            failing_url (str): The URL of the failing container.
+
+        Returns:
+            (passing_versions, failing_versions) dictionaries.
+        """
+        self.logger.info("Gathering version information...")
+        versions_from_env = self.args.build_scripts_path is not None
+        # Get the versions from the endpoint containers (if they exist), overridden by any
+        # explicitly passed versions.
+        (
+            passing_versions,
+            original_passing_versions,
+            passing_package_dirs,
+            passing_env,
+        ) = self._get_versions(
+            passing_url, self.args.passing_versions, versions_from_env
+        )
+        (
+            failing_versions,
+            original_failing_versions,
+            failing_package_dirs,
+            failing_env,
+        ) = self._get_versions(
+            failing_url, self.args.failing_versions, versions_from_env
+        )
+
+        self._log_environment_differences(
+            passing_url, failing_url, passing_env, failing_env
+        )
+
+        # We should have versions for all the same software packages at both
+        # ends of the range, one way or another. TODO: this could be relaxed.
+        assert passing_versions.keys() == failing_versions.keys(), (
+            passing_versions,
+            failing_versions,
+        )
+        # Which packages have versions that are not always the same?
+        self.dynamic_packages = {
+            pkg
+            for pkg, _ in set(passing_versions.items()) ^ set(failing_versions.items())
+        }
+        # Choose an environment to do the version-level bisection in; use directory names that
+        # match it, and track what the initial versions of the different packages are
+        if self.args.container_runtime == "local":
+            self.bisection_url = "local"
+            self.bisection_versions = original_failing_versions
+            self.package_dirs = failing_package_dirs
+        elif failing_url is not None:
+            self.bisection_url = failing_url
+            self.bisection_versions = original_failing_versions
+            self.package_dirs = failing_package_dirs
+        else:
+            assert passing_url is not None
+            self.bisection_url = passing_url
+            self.bisection_versions = original_passing_versions
+            self.package_dirs = passing_package_dirs
+        assert self.package_dirs is not None
+        # This is the set of versions that are already installed
+        assert self.bisection_versions is not None
+        self.logger.info(f"Using {self.bisection_url} for version-level bisection...")
+        return passing_versions, failing_versions
+
+    def run_version_bisection(
+        self,
+        passing_versions: Dict[str, str],
+        failing_versions: Dict[str, str],
+    ) -> Dict[str, Any]:
+        """
+        Run the version bisection process. 
+
+        Args:
+            passing_versions (dict): The versions that passed.
+            failing_versions (dict): The versions that failed.
+
+        Returns:
+            dict: The updated summary data, including the final "result" record.
+        """
+        self.logger.info("Running version-level bisection...")
+        # Prepare the container for the bisection
+        with self._make_container(self.bisection_url) as worker:
+            package_versions = self._gather_histories(
+                worker, passing_versions, failing_versions
+            )
+            self._check_installation_scripts(worker)
+
+        # Run the version-level bisection
+        result, last_known_good, first_known_bad = version_search(
+            versions=package_versions,
+            build_and_test=self._build_and_test,
+            logger=self.logger,
+            skip_precondition_checks=self.args.skip_precondition_checks,
+        )
+        # Write final summary
+        create_output_symlinks(
+            self.args.output_prefix, last_known_good, first_known_bad
+        )
+        result["container"] = self.bisection_url
+        self.logger.info("Version-level bisection completed")
+        return add_summary_record(
+            self.args.output_prefix, "result", result, scalar=True
+        )
diff --git a/.github/triage/jax_toolbox_triage/versions.py b/.github/triage/jax_toolbox_triage/versions.py
new file mode 100644
index 000000000..1c8c37cca
--- /dev/null
+++ b/.github/triage/jax_toolbox_triage/versions.py
@@ -0,0 +1,89 @@
+import typing
+from .container import Container
+from .args import compulsory_software, optional_software
+
+
+def get_env(worker: Container) -> typing.Dict[str, str]:
+    """
+    Get the runtime environment in the given container.
+
+    Returns: {env_var: value} dictionary, sorted by key.
+    """
+
+    def impl() -> typing.Dict[str, str]:
+        kvs = (
+            worker.check_exec(["env", "-0"], policy="once", stderr="separate")
+            .stdout[:-1]  # skip the trailing \0
+            .split("\0")
+        )
+        return dict(kv.split("=", 1) for kv in kvs)
+
+    # Remove any environment variables that differ between consecutive `env` calls, for
+    # example some step-specific Slurm variables.
+    env1, env2 = impl(), impl()
+    # sorted(...) for run-to-run determinism
+    return {k: env1[k] for k in sorted(env1.keys() & env2.keys()) if env1[k] == env2[k]}
+
+
+def get_commits_and_dirs(
+    worker: Container,
+) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]:
+    """
+    Get the git repository paths and current HEAD commits in the given environment of
+    the software packages named in `compulsory_software` and `optional_software`.
+
+    Returns: ({package: commit}, {package: directory})
+    """
+    # Formulated this way to avoid paying too many times for container startup.
+    cmds = []
+    for package in compulsory_software + optional_software:
+        bits = [
+            f"(cd /opt/{package} && git rev-parse HEAD && echo {package} && pwd)",
+            f"(cd /opt/{package}-source && git rev-parse HEAD && echo {package} && pwd)",
+        ]
+        if package in optional_software:
+            bits.append("true")
+        cmds.append(f"({' || '.join(bits)})")
+    result = worker.check_exec(
+        ["sh", "-c", " && ".join(cmds)], policy="once", stderr="separate"
+    )
+    versions, dirs = {}, {}
+    # Loop over triplets of output lines
+    for commit, package, dirname in zip(*([iter(result.stdout.splitlines())] * 3)):
+        dirs[package] = dirname
+        versions[package] = commit
+    return versions, dirs
+
+
+def get_versions_dirs_env(
+    worker: Container,
+    versions_from_env: bool,
+) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str], typing.Dict[str, str]]:
+    """
+    Get software versions in the given [container] environment, git repository paths
+    where relevant, and the runtime environment. 
+
+    The list of software versions is drawn from git repositories at known container
+    locations and, if `versions_from_env` is True, from the environment.
+
+    Returns:
+        versions: {package: version or commit},
+        dirs: {package: git_repository_dir}
+        env: {env_var: value}
+    """
+    # Get the git repository paths and commits from the container.
+    versions, dirs = get_commits_and_dirs(worker)
+
+    # Get the environment variables from the container.
+    env = get_env(worker)
+
+    if versions_from_env:
+        # Promote any XXX_VERSION environment variables into `versions` if `XXX` is
+        # not already there.
+        for k, v in env.items():
+            if not len(v) or not k.endswith("_VERSION"):
+                continue
+            package = k[:-8]
+            assert package not in versions, (versions, package)
+            versions[package] = v
+    return versions, dirs, env
diff --git a/.github/triage/tests/mock_scripts/bazel b/.github/triage/tests/mock_scripts/bazel
new file mode 100755
index 000000000..8c3cbfc39
--- /dev/null
+++ b/.github/triage/tests/mock_scripts/bazel
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+exit 0
diff --git a/.github/triage/tests/mock_scripts/build-jax.sh b/.github/triage/tests/mock_scripts/build-jax.sh
new file mode 100755
index 000000000..5d19da55f
--- /dev/null
+++ b/.github/triage/tests/mock_scripts/build-jax.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+if [ ! -f "feature_file.txt" ]; then
+    echo "Build FAILED: The feature commit was not applied (feature_file.txt is missing)."
+    exit 1
+fi
+
+echo "Mock build script: Build successful (feature commit is present)."
+exit 0
diff --git a/.github/triage/tests/mock_scripts/test-case.sh b/.github/triage/tests/mock_scripts/test-case.sh
new file mode 100755
index 000000000..355cf71dd
--- /dev/null
+++ b/.github/triage/tests/mock_scripts/test-case.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+
+REPO_PATH=$1
+BAD_COMMIT=$2
+
+if [ -z "$REPO_PATH" ] || [ -z "$BAD_COMMIT" ]; then
+    echo "Usage: $0 <repo_path> <bad_commit>"
+    exit 1
+fi
+
+cd ${REPO_PATH}
+
+if git merge-base --is-ancestor ${BAD_COMMIT} HEAD; then
+    echo "The commit ${BAD_COMMIT} is an ancestor of the current HEAD."
+    exit 1
+else
+    echo "The commit ${BAD_COMMIT} is NOT an ancestor of the current HEAD."
+    exit 0
+fi
diff --git a/.github/triage/tests/test_triage_history_bisection.py b/.github/triage/tests/test_triage_history_bisection.py
new file mode 100644
index 000000000..59e287e43
--- /dev/null
+++ b/.github/triage/tests/test_triage_history_bisection.py
@@ -0,0 +1,199 @@
+import subprocess
+import tempfile
+import pathlib
+import logging
+import pytest
+
+from jax_toolbox_triage.args import parse_args
+from jax_toolbox_triage.triage_tool import TriageTool
+
+
+def run_command(command, cwd=None, env=None):
+    """Run a command in a subprocess and return its standard output.
+
+    Args:
+        command (list): The command to run as a list of strings.
+        cwd (str, optional): The working directory to run the command in.
+        env (dict, optional): Environment variables to set for the command.
+    Returns:
+        str: The standard output of the command.
+    """
+    try:
+        result = subprocess.run(
+            command, cwd=cwd, env=env, check=True, capture_output=True, text=True
+        )
+        return result.stdout.strip()
+    except subprocess.CalledProcessError as e:
+        logging.error(f"Command '{' '.join(command)}' failed with error: {e}")
+        raise e
+
+
+@pytest.fixture
+def triage_env():
+    """
+    Fixture to set up the test environment for triage tests.
+
+    The fixture creates a temp directory and a git repo with a
+    defined linear and non-linear history. 
+
+    The fixture yields a dictionary of paths and commit hashes.
+    """
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = pathlib.Path(temp_dir)
+        repo_path = temp_path / "repos"
+        output_path = temp_path / "output"
+        source_scripts_dir = pathlib.Path(__file__).parent / "mock_scripts"
+        repo_path.mkdir()
+        output_path.mkdir()
+        jax_repo_path = repo_path / "jax"
+        jax_repo_path.mkdir()
+
+        def git_cmd(command, *args):
+            return run_command(["git", command, *args], cwd=jax_repo_path)
+
+        # NON-LINEAR HISTORY
+        git_cmd("init", "-b", "main")
+        git_cmd("config", "user.name", "Test User")
+        git_cmd("config", "user.email", "test@example.com")
+        # Create a linear commit history
+        git_cmd("commit", "--allow-empty", "-m", "M0_base")
+        git_cmd("commit", "--allow-empty", "-m", "M1")
+        m1 = git_cmd("rev-parse", "HEAD")
+        git_cmd("commit", "--allow-empty", "-m", "M2")  # good commit
+        m2 = git_cmd("rev-parse", "HEAD")
+        git_cmd("commit", "--allow-empty", "-m", "M3")  # bad commit
+        m3 = git_cmd("rev-parse", "HEAD")
+        # create a feature branch from M1
+        git_cmd("checkout", "-b", "feature", m1)
+        (jax_repo_path / "feature_file.txt").write_text("feature")
+        git_cmd("add", "feature_file.txt")
+        git_cmd("commit", "-m", "F1")
+        passing_nonlinear = git_cmd("rev-parse", "HEAD")  # F1
+        # and then we apply the feature to the bad commit;
+        # this simulates the rebase scenario
+        git_cmd("checkout", "-b", "failing_nonlinear", m3)
+        git_cmd("cherry-pick", passing_nonlinear)
+        failing_nonlinear = git_cmd("rev-parse", "HEAD")  # F1'
+        # so now we have:
+        # M0--M1 --- M2 --- M3
+        #      |             |
+        #      F1            F1'
+        # where F1 = passing
+        # and F1' = failing
+
+        # LINEAR HISTORY
+        git_cmd("checkout", "-b", "linear_feature_branch", passing_nonlinear)
+        git_cmd("commit", "--allow-empty", "-m", "L1")
+        l1_good_commit = git_cmd("rev-parse", "HEAD")  # L1
+        git_cmd("commit", "--allow-empty", "-m", "L2_BAD")  # L2 bad commit
+        l2_bad_linear_commit = git_cmd("rev-parse", "HEAD")
+        git_cmd("commit", "--allow-empty", "-m", "L3")  # L3
+        l3_failing_linear = git_cmd("rev-parse", "HEAD")
+        # so the linear history is
+        # M1 -- M2 -- M3   (main)
+        #  |
+        #  F1 -- L1 -- L2 -- L3   (linear feature branch)
+
+        # yield all the info
+        yield {
+            "paths": {
+                "repo": repo_path,
+                "output": output_path,
+                "scripts": source_scripts_dir,
+            },
+            "commits": {
+                "good_main": m2,  # last good commit
+                "bad_main": m3,
+                "passing_nonlinear": passing_nonlinear,
+                "failing_nonlinear": failing_nonlinear,
+                "good_linear": l1_good_commit,
+                "bad_commit_for_linear": l2_bad_linear_commit,
+                "failing_linear": l3_failing_linear,
+            },
+        }
+
+
+@pytest.mark.parametrize(
+    "scenario, passing_commit_key, failing_commit_key, bad_commit_key, expected_good_key, expected_bad_key",
+    [
+        (
+            "Non-Linear History",
+            "passing_nonlinear",
+            "failing_nonlinear",
+            "bad_main",
+            "good_main",
+            "bad_main",
+        ),
+        (
+            "Linear History",
+            "good_linear",
+            "failing_linear",
+            "bad_commit_for_linear",
+            "good_linear",
+            "bad_commit_for_linear",
+        ),
+    ],
+)
+def test_triage_scenarios(
+    triage_env,
+    monkeypatch,
+    scenario,
+    passing_commit_key,
+    failing_commit_key,
+    bad_commit_key,
+    expected_good_key,
+    expected_bad_key,
+):
+    """Test get_commit_history for linear and non-linear histories."""
+    paths = triage_env["paths"]
+    all_commits = triage_env["commits"]
+    jax_repo_path = paths["repo"] / "jax"
+    passing_versions = {"jax": all_commits[passing_commit_key]}
+    failing_versions = {"jax": all_commits[failing_commit_key]}
+    passing_versions_str = f"jax:{all_commits[passing_commit_key]}"
f"jax:{all_commits[passing_commit_key]}" + failing_versions_str = f"jax:{all_commits[failing_commit_key]}" + + bazel_cache_path = (paths["output"] / "bazel-cache").resolve() + bazel_cache_path.mkdir() + + arg_list = [ + "--main-branch", + "main", + "--output-prefix", + str(paths["output"]), + "--container-runtime", + "local", + "--passing-versions", + passing_versions_str, + "--failing-versions", + failing_versions_str, + "--bazel-cache", + str(bazel_cache_path), + "--", + str(paths["scripts"] / "test-case.sh"), + str(jax_repo_path), + all_commits[bad_commit_key], + ] + args = parse_args(arg_list) + logger = logging.getLogger(f"Scenario-{scenario}") + logging.basicConfig(level=logging.INFO) + # Ensure bazel and build-jax.sh can be found. + monkeypatch.setenv("PATH", paths["scripts"], prepend=":") + + tool = TriageTool(args, logger) + tool.package_dirs = {"jax": str(jax_repo_path)} + tool.dynamic_packages = {"jax"} + tool.bisection_url = "local" + + # run the bisection + summary_data = tool.run_version_bisection(passing_versions, failing_versions) + + assert "result" in summary_data, "No result section was created" + result = summary_data["result"] + + assert result.get("jax_good") == all_commits[expected_good_key] + assert result.get("jax_bad") == all_commits[expected_bad_key] diff --git a/docs/triage-tool.md b/docs/triage-tool.md index 349e8b7c8..ba8b9c6da 100644 --- a/docs/triage-tool.md +++ b/docs/triage-tool.md @@ -69,6 +69,88 @@ passed to `salloc` or set via `SLURM_` environment variables so that a bare `sru correctly launch the test case. If `--container-runtime=local` is used, the tool assumes it is already inside a JAX container and will execute all build and test commands directly. +## Triage tool logic scheme + ++--------------------------------+ +| main.py | +| main() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool = TriageTool(args) | +| tool.prepare() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.find_container_range() | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.gather_version_info() | +| # extract commit hashes, | +| # compare environments | ++--------------------------------+ + | + v ++--------------------------------+ +| tool.run_version_bisection() | +| # main part of the logic | +| # it runs all the steps below | ++--------------------------------+ + | + v ++------------------------------------------------+ +| tool._gather_histories() | +| (Calls bisect.py -> get_commit_history) | +| # get the list of all the commits between p&f | ++------------------------------------------------+ + | + | <--- Inside get_commit_history() + | ++----------------------------------------------------------------------------+ +| Is history linear? (git merge-base --is-ancestor) | ++--------------------------+-------------------------------------------------+ +| YES | NO | +| (Linear History Path) | (Non-Linear History Path) | ++--------------------------+-------------------------------------------------+ + | | + v v ++--------------------------+ +-------------------------------------------------+ +| get_commit_history() | | get_commit_history() | +| returns commits from | | - Uses 'git merge-base' to find linear range | +| main branch directly. 
|  | - records the cherry-pick range                 |
+|                          |  | - returns commits from the *main* branch        |
++--------------------------+  +-------------------------------------------------+
+             |                                  |
+             +----------------+-----------------+
+                              |
+                              v
++------------------------------------------------+
+| logic.py -> version_search() loop              |
+| (Repeatedly calls tool._build_and_test)        |
++------------------------------------------------+
+         |
+         | <--- Inside _build_and_test()
+         |
++----------------------------------------------------------------------------+
+| Is self.cherry_pick_commits non-empty for the package?                      |
++--------------------------+---------------------------------------------------+
+| NO                       |  YES                                              |
+| (Linear History Path)    |  (Non-Linear History Path)                        |
++--------------------------+---------------------------------------------------+
+             |                                  |
+             v                                  v
++--------------------------+  +-------------------------------------------------+
+| _build_and_test()        |  | _build_and_test()                               |
+| - git checkout           |  | - git checkout                                  |
+| - build & test           |  | - git cherry-pick                               |
+|                          |  | - build & test                                  |
++--------------------------+  +-------------------------------------------------+
+```
+
 ## Usage
 
 To use the tool, there are two compulsory inputs:
@@ -161,7 +243,7 @@ paths and `http`/`https`/`grpc` URLs.
 
 If `--skip-precondition-checks` is passed, a sanity check that the failure can be
 reproduced after rebuilding the JAX/XLA commits from the first-known-bad container
-inside that container will be skipped. 
+inside that container will be skipped.
 
 ## Example
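+
+A minimal sketch of an invocation exercising the new non-linear options. The
+entry-point name, container URLs and token below are illustrative placeholders;
+the trailing command is whatever script reproduces your failure:
+
+```sh
+jax-toolbox-triage \
+  --passing-container <passing-container-url> \
+  --failing-container <failing-container-url> \
+  --main-branch main \
+  --override-remotes "jax:https://<token>@host/jax.git,xla:https://<token>@host/xla.git" \
+  --output-prefix triage-output \
+  -- ./test-case.sh
+```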