
Commit 2891ee8

[autorevert] implement autorevert and fix detection logic
1 parent 4eb9d0c commit 2891ee8

4 files changed, +403 -53 lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 6 additions & 0 deletions
@@ -90,6 +90,11 @@ def get_opts() -> argparse.Namespace:
         action="store_true",
         help="Actually restart workflows for detected autorevert patterns",
     )
+    workflow_parser.add_argument(
+        "--do-revert",
+        action="store_true",
+        help="When restarts complete and secondary pattern matches, log REVERT",
+    )
     workflow_parser.add_argument(
         "--dry-run",
         action="store_true",

@@ -185,6 +190,7 @@ def main(*args, **kwargs) -> None:
         hours=opts.hours,
         verbose=opts.verbose,
         do_restart=opts.do_restart,
+        do_revert=opts.do_revert,
         dry_run=opts.dry_run,
         ignore_common_errors=opts.ignore_common_errors,
     )
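
A quick sanity check of the new flag with a stand-alone parser mirroring the two options above (prog name and subcommand wiring are simplified here; the real options hang off workflow_parser):

    import argparse

    parser = argparse.ArgumentParser(prog="pytorch-auto-revert")
    parser.add_argument("--do-restart", action="store_true")
    parser.add_argument("--do-revert", action="store_true")
    parser.add_argument("--dry-run", action="store_true")

    opts = parser.parse_args(["--do-restart", "--do-revert"])
    # argparse maps "--do-revert" to the attribute do_revert, which is what
    # main() forwards as do_revert=opts.do_revert
    assert opts.do_restart and opts.do_revert and not opts.dry_run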

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/autorevert_checker.py

Lines changed: 247 additions & 22 deletions
@@ -51,9 +51,19 @@ def job_base_names(self) -> Set[str]:
         return self._job_base_names

     def normalize_job_name(self, name: str) -> str:
-        """Strip shard suffix from job name for matching."""
+        """Normalize job name to a stable base for matching across commits.
+
+        - Drop any trailing parenthetical qualifiers (e.g., "(rocm)", shard notes)
+        - Strip common shard suffixes like ", 1, 1, " used in CI naming
+        - Collapse redundant whitespace
+        """
+        # Drop any trailing parenthetical qualifier
+        base = re.sub(r"\s*\(.*\)$", "", name)
         # Remove patterns like ", 1, 1, " or ", 2, 3, " from job names
-        return re.sub(r", \d+, \d+, ", ", ", name)
+        base = re.sub(r", \d+, \d+, ", ", ", base)
+        # Collapse multiple spaces
+        base = re.sub(r"\s+", " ", base).strip()
+        return base

     def get_job_base_names(self) -> Set[str]:
         """Get normalized job names (without shard info)."""
@@ -74,13 +84,24 @@ def __init__(
         self._workflow_commits_cache: Dict[str, List[CommitJobs]] = {}
         self._commit_history = None
         self._ignore_classification_rules = ignore_classification_rules or set()
+        # Controls whether queries target restarted runs only (workflow_dispatch/tagged trunk/<sha>)
+        self._use_restarted_runs_only = False

     def get_workflow_commits(self, workflow_name: str) -> List[CommitJobs]:
         """Get workflow commits for a specific workflow, fetching if needed. From newer to older"""
         if workflow_name not in self._workflow_commits_cache:
             self._fetch_workflow_data()
         return self._workflow_commits_cache.get(workflow_name, [])

+    def get_workflow_commit_by_sha(
+        self, workflow_name: str, sha: str
+    ) -> Optional[CommitJobs]:
+        """Return CommitJobs for a workflow and head_sha if present in cache."""
+        for cj in self.get_workflow_commits(workflow_name):
+            if cj.head_sha.startswith(sha):
+                return cj
+        return None
+
     @property
     def workflow_commits(self) -> List[CommitJobs]:
         """Get workflow commits for the first workflow (backward compatibility)."""
@@ -106,7 +127,13 @@ def _fetch_workflow_data(self):
             f"Fetching workflow data for {len(self.workflow_names)} workflows since {lookback_time.isoformat()}..."
         )

-        query = """
+        base_where = (
+            "workflow_event = 'workflow_dispatch' AND head_branch LIKE 'trunk/%'"
+            if self._use_restarted_runs_only
+            else "workflow_event != 'workflow_dispatch' AND head_branch = 'main'"
+        )
+
+        query = f"""
         SELECT
             workflow_name,
             head_sha,

@@ -118,11 +145,10 @@ def _fetch_workflow_data(self):
         FROM
             workflow_job FINAL
         WHERE
-            workflow_name IN {workflow_names:Array(String)}
-            AND head_branch = 'main'
-            AND created_at >= {lookback_time:DateTime}
+            workflow_name IN {{workflow_names:Array(String)}}
+            AND {base_where}
+            AND created_at >= {{lookback_time:DateTime}}
             AND dynamoKey LIKE 'pytorch/pytorch/%'
-            AND workflow_event != 'workflow_dispatch' -- Exclude restart jobs
         ORDER BY
             workflow_name, workflow_created_at DESC, head_sha, name
         """
@@ -221,7 +247,7 @@ def _find_last_commit_with_job(
         job_results = []
         for commit in commits:
             for job in commit.jobs:
-                if job.name.split("(")[0] == job_name:  # Normalize job name
+                if commit.normalize_job_name(job.name) == job_name:
                     job_results.append(job)
             if job_results:
                 return (
@@ -230,6 +256,25 @@
             )
         return None, None

+    def _find_last_commit_with_rule(
+        self, commits: Iterable[CommitJobs], rule: str, failures_only: bool = True
+    ) -> Optional[Tuple[CommitJobs, List[JobResult]]]:
+        """Find the first commit (per iteration order) that has any job with the given rule.
+
+        If failures_only is True, only consider jobs with conclusion == 'failure'.
+        """
+        job_results = []
+        for commit in commits:
+            job_results = [
+                j
+                for j in commit.jobs
+                if j.classification_rule == rule
+                and (j.conclusion == "failure" if failures_only else True)
+            ]
+            if job_results:
+                return commit, job_results
+        return None, None
+
     def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
         """
         Detect all autorevert patterns in commit job data for a specific workflow.
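
A minimal usage sketch of the new helper with stub objects (the stub classes are hypothetical; the real CommitJobs/JobResult live in this module):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Job:  # stand-in for JobResult
        classification_rule: str
        conclusion: str

    @dataclass
    class Commit:  # stand-in for CommitJobs
        head_sha: str
        jobs: List[Job]

    def find_last_commit_with_rule(commits, rule, failures_only=True):
        # same logic: first commit in iteration order with any job matching the rule
        for commit in commits:
            hits = [
                j for j in commit.jobs
                if j.classification_rule == rule
                and (j.conclusion == "failure" if failures_only else True)
            ]
            if hits:
                return commit, hits
        return None, None

    commits = [
        Commit("aaa", [Job("gha_timeout", "failure")]),
        Commit("bbb", [Job("sccache_error", "failure")]),
    ]
    commit, hits = find_last_commit_with_rule(commits, "sccache_error")
    assert commit.head_sha == "bbb" and len(hits) == 1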
@@ -257,10 +302,11 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
             if suspected_commit1.has_pending_jobs:
                 continue

+            # Primary path: use classified failures (highest precision)
             suspected_failures = {
                 (
                     j.classification_rule,
-                    j.name.split("(")[0],
+                    suspected_commit1.normalize_job_name(j.name),
                 )
                 for j in suspected_commit1.failed_jobs
             }
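
In effect, shards of one job now collapse to a single (rule, base name) failure key. A tiny illustration (names hypothetical, with a simplified stand-in for normalize_job_name):

    failed = [
        ("gha_timeout", "trunk / test (default, 1, 3, linux.4xlarge)"),
        ("gha_timeout", "trunk / test (default, 2, 3, linux.4xlarge)"),
    ]
    # simplified normalization: drop the trailing "(...)" qualifier
    suspected_failures = {(rule, name.split(" (")[0]) for rule, name in failed}
    assert suspected_failures == {("gha_timeout", "trunk / test")}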
@@ -282,15 +328,35 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
                         suspected_failure_job_name,
                     )
                 )
-                if not newer_commit_same_job or not newer_same_jobs:
-                    # No newer commit with the same job found
-                    continue

-                if any(
-                    j.classification_rule == suspected_failure_class_rule
-                    for j in newer_same_jobs
+                # Fallback: allow rule match across any job (failure can jump jobs)
+                if (
+                    not newer_commit_same_job
+                    or not newer_same_jobs
+                    or not any(
+                        j.classification_rule == suspected_failure_class_rule
+                        and j.conclusion == "failure"
+                        for j in newer_same_jobs
+                    )
                 ):
-                    # The newer commit has the same job failing
+                    newer_commit_same_job, newer_same_jobs = (
+                        self._find_last_commit_with_rule(
+                            (commits[j] for j in range(i - 1, -1, -1)),
+                            suspected_failure_class_rule,
+                            failures_only=True,
+                        )
+                    )
+
+                if (
+                    newer_commit_same_job
+                    and newer_same_jobs
+                    and any(
+                        j.classification_rule == suspected_failure_class_rule
+                        and j.conclusion == "failure"
+                        for j in newer_same_jobs
+                    )
+                ):
+                    # The newer commit has the same failure (job may differ)
                     failure_key = (
                         suspected_failure_class_rule,
                         suspected_failure_job_name,
@@ -310,26 +376,72 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
                 )
             )

+            # Fallback: allow rule match across any job when searching older commit's coverage
+            if not last_commit_with_same_job or not last_same_jobs:
+                last_commit_with_same_job, last_same_jobs = (
+                    self._find_last_commit_with_rule(
+                        (commits[j] for j in range(i + 1, len(commits))),
+                        failure_rule,
+                        failures_only=True,
+                    )
+                )
+
             if not last_commit_with_same_job or not last_same_jobs:
-                # No older commit with the same job found
+                # No older commit with any jobs found
+                continue
+
+            if any(
+                j.classification_rule == failure_rule and j.conclusion == "failure"
+                for j in last_same_jobs
+            ):
+                # The older commit already has the same failure (regardless of job)
                 continue

-            if any(j.classification_rule == failure_rule for j in last_same_jobs):
-                # The older commit has the same job failing with same rule
+            # Ensure there is some overlap in job coverage between suspected and older commit
+            older_coverage = list(
+                set(suspected_commit1.job_base_names)
+                & set(last_commit_with_same_job.job_base_names)
+            )
+            if not older_coverage:
+                # No overlapping jobs -> insufficient comparable signal
+                continue
+
+            # Cross-workflow baseline check: if multiple workflows were provided,
+            # ensure the older_commit does NOT have the same failure in any sibling workflow.
+            older_sha = last_commit_with_same_job.head_sha
+            conflict_in_other_wf = False
+            if len(self.workflow_names) > 1:
+                for other_wf in self.workflow_names:
+                    if other_wf == workflow_name:
+                        continue
+                    other_cj = self.get_workflow_commit_by_sha(other_wf, older_sha)
+                    if other_cj and any(
+                        j.classification_rule == failure_rule for j in other_cj.jobs
+                    ):
+                        conflict_in_other_wf = True
+                        break
+
+            if conflict_in_other_wf:
+                # Skip this pattern; baseline is not clean across workflows
                 continue

             patterns.append(
                 {
                     "pattern_detected": True,
                     "workflow_name": workflow_name,
                     "failure_rule": failure_rule,
+                    "job_name_base": job_name,
                     "newer_commits": [
                         newer_commit.head_sha,
                         suspected_commit1.head_sha,
                     ],
-                    "older_commit": last_commit_with_same_job.head_sha,
-                    "failed_job_names": [j.name for j in last_same_jobs],
-                    "older_job_coverage": [],
+                    "older_commit": older_sha,
+                    "failed_job_names": [
+                        j.name
+                        for j in suspected_commit1.failed_jobs
+                        if j.classification_rule == failure_rule
+                    ][:10],
+                    "older_job_coverage": older_coverage[:10],
                 }
             )
             break
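
For reference, a detected pattern entry now carries the normalized base job name and capped lists; a hypothetical example of the emitted shape (SHAs and names made up):

    pattern = {
        "pattern_detected": True,
        "workflow_name": "trunk",
        "failure_rule": "gha_timeout",
        "job_name_base": "trunk / test",                # new: normalized base name
        "newer_commits": ["def4567...", "abc1234..."],  # [newer, first failing]
        "older_commit": "9876fed...",
        "failed_job_names": ["trunk / test (default, 1, 3, linux.4xlarge)"],  # capped at 10
        "older_job_coverage": ["trunk / test"],                               # capped at 10
    }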
@@ -378,6 +490,119 @@ def detect_autorevert_pattern(self) -> List[Dict]:

         return all_patterns

+    def _fetch_single_commit_jobs(
+        self,
+        workflow_name: str,
+        head_sha: str,
+        restarted_only: bool = False,
+    ) -> Optional[CommitJobs]:
+        """Fetch jobs for a single workflow+commit, optionally only restarted runs.
+
+        Groups all jobs by head_sha (assumes at most one restart dispatch of interest).
+        Returns CommitJobs or None if no jobs found in lookback window.
+        """
+        lookback_time = datetime.now() - timedelta(hours=self.lookback_hours)
+
+        where_event = (
+            "workflow_event = {we:String} AND head_branch LIKE {hb:String}"
+            if restarted_only
+            else "workflow_event != {we:String} AND head_branch = {hb:String}"
+        )
+
+        query = f"""
+        SELECT
+            head_sha,
+            name,
+            conclusion,
+            status,
+            torchci_classification.rule AS classification_rule,
+            created_at AS workflow_created_at
+        FROM workflow_job FINAL
+        WHERE workflow_name = {{workflow_name:String}}
+          AND head_sha = {{head_sha:String}}
+          AND {where_event}
+          AND created_at >= {{lookback_time:DateTime}}
+          AND dynamoKey LIKE 'pytorch/pytorch/%'
+        ORDER BY workflow_created_at DESC, name
+        """
+
+        hb = "trunk/%" if restarted_only else "main"
+        we = "workflow_dispatch" if restarted_only else "workflow_dispatch"
+        # Note: for non-restarted we exclude workflow_dispatch via != in WHERE above
+
+        result = CHCliFactory().client.query(
+            query,
+            parameters={
+                "workflow_name": workflow_name,
+                "head_sha": head_sha,
+                "we": we,
+                "hb": hb,
+                "lookback_time": lookback_time,
+            },
+        )
+
+        rows = list(result.result_rows)
+        if not rows:
+            return None
+
+        # Use the newest created_at among returned rows as the commit's created_at marker
+        latest_created = max(r[5] for r in rows)
+        cj = CommitJobs(head_sha=head_sha, created_at=latest_created, jobs=[])
+        for row in rows:
+            _, name, conclusion, status, classification_rule, created_at = row
+            cj.jobs.append(
+                JobResult(
+                    head_sha=head_sha,
+                    name=name,
+                    conclusion=conclusion,
+                    status=status,
+                    classification_rule=classification_rule or "",
+                    workflow_created_at=created_at,
+                )
+            )
+        return cj
+
+    def confirm_commit_caused_failure_on_restarted(self, pattern: Dict) -> bool:
+        """Confirm commit-caused failure using restarted runs.
+
+        Requires that:
+        - first failing commit's restarted run has the same failure classification for the job
+        - previous commit's restarted run does NOT have that failure classification for the job
+        - both restarted runs have no pending jobs
+        """
+        workflow_name = pattern["workflow_name"]
+        job_base = pattern.get("job_name_base")
+        failure_rule = pattern["failure_rule"]
+        first_failing = pattern["newer_commits"][1]
+        previous_commit = pattern["older_commit"]
+
+        # Fetch restarted jobs for first failing and previous commits
+        failing_jobs = self._fetch_single_commit_jobs(
+            workflow_name, first_failing, restarted_only=True
+        )
+        prev_jobs = self._fetch_single_commit_jobs(
+            workflow_name, previous_commit, restarted_only=True
+        )
+        if not failing_jobs or not prev_jobs:
+            return False
+
+        # Pending check
+        if failing_jobs.has_pending_jobs or prev_jobs.has_pending_jobs:
+            return False
+
+        def has_rule(cj: CommitJobs, rule: str) -> bool:
+            return any(
+                cj.normalize_job_name(j.name) == job_base
+                and j.classification_rule == rule
+                and j.conclusion == "failure"
+                for j in cj.jobs
+            )
+
+        # Commit-caused if failing commit reproduces, previous does not
+        return has_rule(failing_jobs, failure_rule) and not has_rule(
+            prev_jobs, failure_rule
+        )
+
     def get_commits_reverted(self) -> Set[str]:
         """
         Get all commits that were reverted within the lookback window.
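
How this is meant to compose with --do-revert (a sketch, assuming the caller wires detection, restart confirmation, and logging together roughly like this; the driver function itself is hypothetical):

    def process_patterns(checker, patterns, do_revert: bool):
        for pattern in patterns:
            # secondary check against the restarted (workflow_dispatch) runs
            if checker.confirm_commit_caused_failure_on_restarted(pattern):
                first_failing = pattern["newer_commits"][1]
                if do_revert:
                    print(f"REVERT: {first_failing} ({pattern['failure_rule']})")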
