diff --git a/llvm/utils/revert_checker.py b/llvm/utils/revert_checker.py index da80bdff86857..b1c6e228e4d41 100755 --- a/llvm/utils/revert_checker.py +++ b/llvm/utils/revert_checker.py @@ -45,35 +45,78 @@ import re import subprocess import sys -from typing import Generator, List, NamedTuple, Iterable +from typing import Dict, Generator, Iterable, List, NamedTuple, Optional, Tuple assert sys.version_info >= (3, 6), "Only Python 3.6+ is supported." # People are creative with their reverts, and heuristics are a bit difficult. -# Like 90% of of reverts have "This reverts commit ${full_sha}". -# Some lack that entirely, while others have many of them specified in ad-hoc -# ways, while others use short SHAs and whatever. +# At a glance, most reverts have "This reverts commit ${full_sha}". Many others +# have `Reverts llvm/llvm-project#${PR_NUMBER}`. # -# The 90% case is trivial to handle (and 100% free + automatic). The extra 10% -# starts involving human intervention, which is probably not worth it for now. +# By their powers combined, we should be able to automatically catch something +# like 80% of reverts with reasonable confidence. At some point, human +# intervention will always be required (e.g., I saw +# ``` +# This reverts commit ${commit_sha_1} and +# also ${commit_sha_2_shorthand} +# ``` +# during my sample) + +_CommitMessageReverts = NamedTuple( + "_CommitMessageReverts", + [ + ("potential_shas", List[str]), + ("potential_pr_numbers", List[int]), + ], +) + +def _try_parse_reverts_from_commit_message( + commit_message: str, +) -> _CommitMessageReverts: + """Tries to parse revert SHAs and LLVM PR numbers form the commit message. -def _try_parse_reverts_from_commit_message(commit_message: str) -> List[str]: + Returns: + A namedtuple containing: + - A list of potentially reverted SHAs + - A list of potentially reverted LLVM PR numbers + """ if not commit_message: - return [] + return _CommitMessageReverts([], []) - results = re.findall(r"This reverts commit ([a-f0-9]{40})\b", commit_message) + sha_reverts = re.findall( + r"This reverts commit ([a-f0-9]{40})\b", + commit_message, + ) first_line = commit_message.splitlines()[0] initial_revert = re.match(r'Revert ([a-f0-9]{6,}) "', first_line) if initial_revert: - results.append(initial_revert.group(1)) - return results + sha_reverts.append(initial_revert.group(1)) + pr_numbers = [ + int(x) + for x in re.findall( + r"Reverts llvm/llvm-project#(\d+)", + commit_message, + ) + ] + + return _CommitMessageReverts( + potential_shas=sha_reverts, + potential_pr_numbers=pr_numbers, + ) -def _stream_stdout(command: List[str]) -> Generator[str, None, None]: + +def _stream_stdout( + command: List[str], cwd: Optional[str] = None +) -> Generator[str, None, None]: with subprocess.Popen( - command, stdout=subprocess.PIPE, encoding="utf-8", errors="replace" + command, + cwd=cwd, + stdout=subprocess.PIPE, + encoding="utf-8", + errors="replace", ) as p: assert p.stdout is not None # for mypy's happiness. yield from p.stdout @@ -175,10 +218,43 @@ def _find_common_parent_commit(git_dir: str, ref_a: str, ref_b: str) -> str: ).strip() -def find_reverts(git_dir: str, across_ref: str, root: str) -> List[Revert]: +def _load_pr_commit_mappings( + git_dir: str, root: str, min_ref: str +) -> Dict[int, List[str]]: + git_log = ["git", "log", "--format=%H %s", f"{min_ref}..{root}"] + results = collections.defaultdict(list) + pr_regex = re.compile(r"\s\(#(\d+)\)$") + for line in _stream_stdout(git_log, cwd=git_dir): + m = pr_regex.search(line) + if not m: + continue + + pr_number = int(m.group(1)) + sha = line.split(None, 1)[0] + # N.B., these are kept in log (read: reverse chronological) order, + # which is what's expected by `find_reverts`. + results[pr_number].append(sha) + return results + + +# N.B., max_pr_lookback's default of 20K commits is arbitrary, but should be +# enough for the 99% case of reverts: rarely should someone land a cleanish +# revert of a >6 month old change... +def find_reverts( + git_dir: str, across_ref: str, root: str, max_pr_lookback: int = 20000 +) -> List[Revert]: """Finds reverts across `across_ref` in `git_dir`, starting from `root`. These reverts are returned in order of oldest reverts first. + + Args: + git_dir: git directory to find reverts in. + across_ref: the ref to find reverts across. + root: the 'main' ref to look for reverts on. + max_pr_lookback: this function uses heuristics to map PR numbers to + SHAs. These heuristics require that commit history from `root` to + `some_parent_of_root` is loaded in memory. `max_pr_lookback` is how + many commits behind `across_ref` should be loaded in memory. """ across_sha = _rev_parse(git_dir, across_ref) root_sha = _rev_parse(git_dir, root) @@ -201,8 +277,41 @@ def find_reverts(git_dir: str, across_ref: str, root: str) -> List[Revert]: ) all_reverts = [] + # Lazily load PR <-> commit mappings, since it can be expensive. + pr_commit_mappings = None for sha, commit_message in _log_stream(git_dir, root_sha, across_sha): - reverts = _try_parse_reverts_from_commit_message(commit_message) + reverts, pr_reverts = _try_parse_reverts_from_commit_message( + commit_message, + ) + if pr_reverts: + if pr_commit_mappings is None: + logging.info( + "Loading PR <-> commit mappings. This may take a moment..." + ) + pr_commit_mappings = _load_pr_commit_mappings( + git_dir, root_sha, f"{across_sha}~{max_pr_lookback}" + ) + logging.info( + "Loaded %d PR <-> commit mappings", len(pr_commit_mappings) + ) + + for reverted_pr_number in pr_reverts: + reverted_shas = pr_commit_mappings.get(reverted_pr_number) + if not reverted_shas: + logging.warning( + "No SHAs for reverted PR %d (commit %s)", + reverted_pr_number, + sha, + ) + continue + logging.debug( + "Inferred SHAs %s for reverted PR %d (commit %s)", + reverted_shas, + reverted_pr_number, + sha, + ) + reverts.extend(reverted_shas) + if not reverts: continue diff --git a/llvm/utils/revert_checker_test.py b/llvm/utils/revert_checker_test.py index 9d992663c5be8..c149be8dc0dd1 100755 --- a/llvm/utils/revert_checker_test.py +++ b/llvm/utils/revert_checker_test.py @@ -96,6 +96,7 @@ def test_reverted_noncommit_object_is_a_nop(self) -> None: git_dir=get_llvm_project_path(), across_ref="c9944df916e41b1014dff5f6f75d52297b48ecdc~", root="c9944df916e41b1014dff5f6f75d52297b48ecdc", + max_pr_lookback=50, ) self.assertEqual(reverts, []) @@ -113,6 +114,7 @@ def test_known_reverts_across_arbitrary_llvm_rev(self) -> None: git_dir=get_llvm_project_path(), across_ref="c47f971694be0159ffddfee8a75ae515eba91439", root="9f981e9adf9c8d29bb80306daf08d2770263ade6", + max_pr_lookback=50, ) self.assertEqual( reverts, @@ -128,6 +130,27 @@ def test_known_reverts_across_arbitrary_llvm_rev(self) -> None: ], ) + def test_pr_based_revert_works(self) -> None: + reverts = revert_checker.find_reverts( + git_dir=get_llvm_project_path(), + # This SHA is a direct child of the reverted SHA expected below. + across_ref="2d5f3b0a61fb171617012a2c3ba05fd31fb3bb1d", + # This SHA is a direct child of the revert SHA listed below. + root="2c01b278580212914ec037bb5dd9b73702dfe7f1", + max_pr_lookback=50, + ) + self.assertEqual( + reverts, + [ + revert_checker.Revert( + # This SHA is a `Reverts ${PR}` for #111004. + sha="50866e84d1da8462aeb96607bf6d9e5bbd5869c5", + # ...And this was the commit for #111004. + reverted_sha="67160c5ab5f5b7fd5fa7851abcfde367c8a9f91b", + ), + ], + ) + if __name__ == "__main__": unittest.main()