Commit 1508f12

[autorevert] implement autorevert and fix detection logic

Parent: 4eb9d0c

File tree

4 files changed: +375 −53 lines

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 6 additions & 0 deletions
@@ -90,6 +90,11 @@ def get_opts() -> argparse.Namespace:
         action="store_true",
         help="Actually restart workflows for detected autorevert patterns",
     )
+    workflow_parser.add_argument(
+        "--do-revert",
+        action="store_true",
+        help="When restarts complete and secondary pattern matches, log REVERT",
+    )
     workflow_parser.add_argument(
         "--dry-run",
         action="store_true",
@@ -185,6 +190,7 @@ def main(*args, **kwargs) -> None:
         hours=opts.hours,
         verbose=opts.verbose,
         do_restart=opts.do_restart,
+        do_revert=opts.do_revert,
         dry_run=opts.dry_run,
         ignore_common_errors=opts.ignore_common_errors,
     )
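The new flag is a plain store_true option forwarded from the parsed namespace into the checker. A minimal standalone sketch of the wiring (parser name and invocation are illustrative, not the repo's exact setup):

import argparse

parser = argparse.ArgumentParser(prog="pytorch-auto-revert")
parser.add_argument("--do-restart", action="store_true")
parser.add_argument("--do-revert", action="store_true")
parser.add_argument("--dry-run", action="store_true")

# store_true flags default to False, so revert logging stays opt-in
opts = parser.parse_args(["--do-restart", "--do-revert"])
assert opts.do_restart and opts.do_revert and not opts.dry_run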

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/autorevert_checker.py

Lines changed: 219 additions & 22 deletions
@@ -51,9 +51,19 @@ def job_base_names(self) -> Set[str]:
         return self._job_base_names
 
     def normalize_job_name(self, name: str) -> str:
-        """Strip shard suffix from job name for matching."""
+        """Normalize job name to a stable base for matching across commits.
+
+        - Drop any trailing parenthetical qualifiers (e.g., "(rocm)", shard notes)
+        - Strip common shard suffixes like ", 1, 1, " used in CI naming
+        - Collapse redundant whitespace
+        """
+        # Drop any trailing parenthetical qualifier
+        base = re.sub(r"\s*\(.*\)$", "", name)
         # Remove patterns like ", 1, 1, " or ", 2, 3, " from job names
-        return re.sub(r", \d+, \d+, ", ", ", name)
+        base = re.sub(r", \d+, \d+, ", ", ", base)
+        # Collapse multiple spaces
+        base = re.sub(r"\s+", " ", base).strip()
+        return base
 
     def get_job_base_names(self) -> Set[str]:
         """Get normalized job names (without shard info)."""
@@ -74,13 +84,24 @@ def __init__(
         self._workflow_commits_cache: Dict[str, List[CommitJobs]] = {}
         self._commit_history = None
         self._ignore_classification_rules = ignore_classification_rules or set()
+        # Controls whether queries target restarted runs only (workflow_dispatch/tagged trunk/<sha>)
+        self._use_restarted_runs_only = False
 
     def get_workflow_commits(self, workflow_name: str) -> List[CommitJobs]:
         """Get workflow commits for a specific workflow, fetching if needed. From newer to older"""
         if workflow_name not in self._workflow_commits_cache:
             self._fetch_workflow_data()
         return self._workflow_commits_cache.get(workflow_name, [])
 
+    def get_workflow_commit_by_sha(
+        self, workflow_name: str, sha: str
+    ) -> Optional[CommitJobs]:
+        """Return CommitJobs for a workflow and head_sha if present in cache."""
+        for cj in self.get_workflow_commits(workflow_name):
+            if cj.head_sha.startswith(sha):
+                return cj
+        return None
+
     @property
     def workflow_commits(self) -> List[CommitJobs]:
         """Get workflow commits for the first workflow (backward compatibility)."""
@@ -106,7 +127,13 @@ def _fetch_workflow_data(self):
             f"Fetching workflow data for {len(self.workflow_names)} workflows since {lookback_time.isoformat()}..."
         )
 
-        query = """
+        base_where = (
+            "workflow_event = 'workflow_dispatch' AND head_branch LIKE 'trunk/%'"
+            if self._use_restarted_runs_only
+            else "workflow_event != 'workflow_dispatch' AND head_branch = 'main'"
+        )
+
+        query = f"""
         SELECT
             workflow_name,
             head_sha,
@@ -118,11 +145,10 @@ def _fetch_workflow_data(self):
         FROM
             workflow_job FINAL
         WHERE
-            workflow_name IN {workflow_names:Array(String)}
-            AND head_branch = 'main'
-            AND created_at >= {lookback_time:DateTime}
+            workflow_name IN {{workflow_names:Array(String)}}
+            AND {base_where}
+            AND created_at >= {{lookback_time:DateTime}}
             AND dynamoKey LIKE 'pytorch/pytorch/%'
-            AND workflow_event != 'workflow_dispatch' -- Exclude restart jobs
         ORDER BY
             workflow_name, workflow_created_at DESC, head_sha, name
         """
@@ -221,7 +247,7 @@ def _find_last_commit_with_job(
         job_results = []
         for commit in commits:
             for job in commit.jobs:
-                if job.name.split("(")[0] == job_name:  # Normalize job name
+                if commit.normalize_job_name(job.name) == job_name:
                     job_results.append(job)
             if job_results:
                 return (
@@ -230,6 +256,25 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
             )
         return None, None
 
+    def _find_last_commit_with_rule(
+        self, commits: Iterable[CommitJobs], rule: str, failures_only: bool = True
+    ) -> Optional[Tuple[CommitJobs, List[JobResult]]]:
+        """Find the first commit (per iteration order) that has any job with the given rule.
+
+        If failures_only is True, only consider jobs with conclusion == 'failure'.
+        """
+        job_results = []
+        for commit in commits:
+            job_results = [
+                j
+                for j in commit.jobs
+                if j.classification_rule == rule
+                and (j.conclusion == "failure" if failures_only else True)
+            ]
+            if job_results:
+                return commit, job_results
+        return None, None
+
     def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
         """
         Detect all autorevert patterns in commit job data for a specific workflow.
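A hedged usage sketch of the new helper. Since it returns on the first match in iteration order, passing commits newest to oldest yields the nearest commit exhibiting the rule (inputs below are illustrative):

# Hypothetical inputs; `older_commits` is assumed to be an iterable of
# CommitJobs ordered newest -> oldest, and the rule string is made up.
commit, jobs = checker._find_last_commit_with_rule(
    older_commits, rule="pytest failure", failures_only=True
)
if commit is not None:
    print(commit.head_sha, [j.name for j in jobs])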
@@ -257,10 +302,11 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
             if suspected_commit1.has_pending_jobs:
                 continue
 
+            # Primary path: use classified failures (highest precision)
             suspected_failures = {
                 (
                     j.classification_rule,
-                    j.name.split("(")[0],
+                    suspected_commit1.normalize_job_name(j.name),
                 )
                 for j in suspected_commit1.failed_jobs
             }
@@ -282,15 +328,17 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
                     suspected_failure_job_name,
                 )
             )
-            if not newer_commit_same_job or not newer_same_jobs:
-                # No newer commit with the same job found
-                continue
 
-            if any(
-                j.classification_rule == suspected_failure_class_rule
-                for j in newer_same_jobs
+            if (
+                newer_commit_same_job
+                and newer_same_jobs
+                and any(
+                    j.classification_rule == suspected_failure_class_rule
+                    and j.conclusion == "failure"
+                    for j in newer_same_jobs
+                )
             ):
-                # The newer commit has the same job failing
+                # The newer commit has the same failure (job may differ)
                 failure_key = (
                     suspected_failure_class_rule,
                     suspected_failure_job_name,
@@ -311,25 +359,61 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
             )
 
             if not last_commit_with_same_job or not last_same_jobs:
-                # No older commit with the same job found
+                # No older commit with any jobs found
                 continue
 
-            if any(j.classification_rule == failure_rule for j in last_same_jobs):
-                # The older commit has the same job failing with same rule
+            if any(
+                j.classification_rule == failure_rule and j.conclusion == "failure"
+                for j in last_same_jobs
+            ):
+                # The older commit already has the same failure (regardless of job)
+                continue
+
+            # Ensure there is some overlap in job coverage between suspected and older commit
+            older_coverage = list(
+                set(suspected_commit1.job_base_names)
+                & set(last_commit_with_same_job.job_base_names)
+            )
+            if not older_coverage:
+                # No overlapping jobs -> insufficient comparable signal
+                continue
+
+            # Cross-workflow baseline check: if multiple workflows were provided,
+            # ensure the older_commit does NOT have the same failure in any sibling workflow.
+            older_sha = last_commit_with_same_job.head_sha
+            conflict_in_other_wf = False
+            if len(self.workflow_names) > 1:
+                for other_wf in self.workflow_names:
+                    if other_wf == workflow_name:
+                        continue
+                    other_cj = self.get_workflow_commit_by_sha(other_wf, older_sha)
+                    if other_cj and any(
+                        j.classification_rule == failure_rule for j in other_cj.jobs
+                    ):
+                        conflict_in_other_wf = True
+                        break
+
+            if conflict_in_other_wf:
+                # Skip this pattern; baseline is not clean across workflows
                 continue
 
             patterns.append(
                 {
                     "pattern_detected": True,
                     "workflow_name": workflow_name,
                     "failure_rule": failure_rule,
+                    "job_name_base": job_name,
                     "newer_commits": [
                         newer_commit.head_sha,
                         suspected_commit1.head_sha,
                     ],
-                    "older_commit": last_commit_with_same_job.head_sha,
-                    "failed_job_names": [j.name for j in last_same_jobs],
-                    "older_job_coverage": [],
+                    "older_commit": older_sha,
+                    "failed_job_names": [
+                        j.name
+                        for j in suspected_commit1.failed_jobs
+                        if j.classification_rule == failure_rule
+                    ][:10],
+                    "older_job_coverage": older_coverage[:10],
                 }
             )
             break
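For reference, a sketch of the dict shape this branch now appends; every field value below is fabricated for illustration:

pattern = {
    "pattern_detected": True,
    "workflow_name": "trunk",                      # illustrative
    "failure_rule": "pytest failure",              # illustrative
    "job_name_base": "linux-jammy-py3.10 / test",  # normalized base name
    "newer_commits": ["<newer_sha>", "<first_failing_sha>"],
    "older_commit": "<older_sha>",
    "failed_job_names": ["linux-jammy-py3.10 / test (default, 1, 3)"],  # capped at 10
    "older_job_coverage": ["linux-jammy-py3.10 / test"],                # capped at 10
}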
@@ -378,6 +462,119 @@ def detect_autorevert_pattern(self) -> List[Dict]:
 
         return all_patterns
 
+    def _fetch_single_commit_jobs(
+        self,
+        workflow_name: str,
+        head_sha: str,
+        restarted_only: bool = False,
+    ) -> Optional[CommitJobs]:
+        """Fetch jobs for a single workflow+commit, optionally only restarted runs.
+
+        Groups all jobs by head_sha (assumes at most one restart dispatch of interest).
+        Returns CommitJobs or None if no jobs found in lookback window.
+        """
+        lookback_time = datetime.now() - timedelta(hours=self.lookback_hours)
+
+        where_event = (
+            "workflow_event = {we:String} AND head_branch LIKE {hb:String}"
+            if restarted_only
+            else "workflow_event != {we:String} AND head_branch = {hb:String}"
+        )
+
+        query = f"""
+        SELECT
+            head_sha,
+            name,
+            conclusion,
+            status,
+            torchci_classification.rule AS classification_rule,
+            created_at AS workflow_created_at
+        FROM workflow_job FINAL
+        WHERE workflow_name = {{workflow_name:String}}
+            AND head_sha = {{head_sha:String}}
+            AND {where_event}
+            AND created_at >= {{lookback_time:DateTime}}
+            AND dynamoKey LIKE 'pytorch/pytorch/%'
+        ORDER BY workflow_created_at DESC, name
+        """
+
+        hb = "trunk/%" if restarted_only else "main"
+        we = "workflow_dispatch" if restarted_only else "workflow_dispatch"
+        # Note: for non-restarted we exclude workflow_dispatch via != in WHERE above
+
+        result = CHCliFactory().client.query(
+            query,
+            parameters={
+                "workflow_name": workflow_name,
+                "head_sha": head_sha,
+                "we": we,
+                "hb": hb,
+                "lookback_time": lookback_time,
+            },
+        )
+
+        rows = list(result.result_rows)
+        if not rows:
+            return None
+
+        # Use the newest created_at among returned rows as the commit's created_at marker
+        latest_created = max(r[5] for r in rows)
+        cj = CommitJobs(head_sha=head_sha, created_at=latest_created, jobs=[])
+        for row in rows:
+            _, name, conclusion, status, classification_rule, created_at = row
+            cj.jobs.append(
+                JobResult(
+                    head_sha=head_sha,
+                    name=name,
+                    conclusion=conclusion,
+                    status=status,
+                    classification_rule=classification_rule or "",
+                    workflow_created_at=created_at,
+                )
+            )
+        return cj
+
+    def confirm_commit_caused_failure_on_restarted(self, pattern: Dict) -> bool:
+        """Confirm commit-caused failure using restarted runs.
+
+        Requires that:
+        - first failing commit's restarted run has the same failure classification for the job
+        - previous commit's restarted run does NOT have that failure classification for the job
+        - both restarted runs have no pending jobs
+        """
+        workflow_name = pattern["workflow_name"]
+        job_base = pattern.get("job_name_base")
+        failure_rule = pattern["failure_rule"]
+        first_failing = pattern["newer_commits"][1]
+        previous_commit = pattern["older_commit"]
+
+        # Fetch restarted jobs for first failing and previous commits
+        failing_jobs = self._fetch_single_commit_jobs(
+            workflow_name, first_failing, restarted_only=True
+        )
+        prev_jobs = self._fetch_single_commit_jobs(
+            workflow_name, previous_commit, restarted_only=True
+        )
+        if not failing_jobs or not prev_jobs:
+            return False
+
+        # Pending check
+        if failing_jobs.has_pending_jobs or prev_jobs.has_pending_jobs:
+            return False
+
+        def has_rule(cj: CommitJobs, rule: str) -> bool:
+            return any(
+                cj.normalize_job_name(j.name) == job_base
+                and j.classification_rule == rule
+                and j.conclusion == "failure"
+                for j in cj.jobs
+            )
+
+        # Commit-caused if failing commit reproduces, previous does not
+        return has_rule(failing_jobs, failure_rule) and not has_rule(
+            prev_jobs, failure_rule
+        )
+
     def get_commits_reverted(self) -> Set[str]:
         """
         Get all commits that were reverted within the lookback window.
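Taken together with --do-revert, the intended flow appears to be: detect a pattern, restart the suspect and its parent, then confirm against the restarted runs. A hedged sketch of that orchestration (the two checker methods are from this commit; the surrounding loop is an assumption):

patterns = checker.detect_autorevert_pattern()
for p in patterns:
    if checker.confirm_commit_caused_failure_on_restarted(p):
        # Restarted runs reproduce the failure on the suspect but not on its
        # parent: strong evidence the commit itself caused the breakage.
        print(f"REVERT: {p['newer_commits'][1]} ({p['failure_rule']})")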
