Skip to content

Commit 09abb53

Browse files
authored
[autorevert] Fix state logging on revert (#7295)
fixes: ``` [ERROR] AttributeError: 'AutorevertPattern' object has no attribute 'job_base_name' Traceback (most recent call last): File "/var/task/pytorch_auto_revert/__main__.py", line 311, in main autorevert_v2( File "/var/task/pytorch_auto_revert/testers/autorevert_v2.py", line 84, in autorevert_v2 state_json = RunStateLogger().insert_state(ctx=run_ctx, pairs=pairs) File "/var/task/pytorch_auto_revert/run_state_logger.py", line 159, in insert_state doc = self._build_state_json(repo=ctx.repo_full_name, ctx=ctx, pairs=pairs) File "/var/task/pytorch_auto_revert/run_state_logger.py", line 62, in _build_state_json if outcome.job_base_name: ``` Optional field was accessed unconditionally on the outcome. Removed accessor, as field is not used. Added more explicit logging & tests.
1 parent b33428d commit 09abb53

File tree

3 files changed

+223
-4
lines changed

3 files changed

+223
-4
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/run_state_logger.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def _build_state_json(
5959
"older_successful_commit": outcome.older_successful_commit,
6060
"newer_failing_commits": list(outcome.newer_failing_commits),
6161
}
62-
if outcome.job_base_name:
63-
data["job_base_name"] = outcome.job_base_name
6462
if outcome.wf_run_id is not None:
6563
data["wf_run_id"] = outcome.wf_run_id
6664
if outcome.job_id is not None:

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/autorevert_v2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ def autorevert_v2(
8585
logging.info("[v2] Executed action groups: %d", executed_count)
8686

8787
# Persist full run state via separate logger
88-
state_json = RunStateLogger().insert_state(ctx=run_ctx, pairs=pairs)
89-
logging.info("[v2] State logged")
88+
try:
89+
state_json = RunStateLogger().insert_state(ctx=run_ctx, pairs=pairs)
90+
logging.info("[v2] State logged")
91+
except Exception:
92+
logging.exception("[v2] State logging failed") # capture full stack
93+
# Keep returning a JSON payload for downstream consumers
94+
state_json = "{}"
9095

9196
return signals, pairs, state_json
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import json
2+
import unittest
3+
from datetime import datetime, timedelta
4+
from unittest.mock import patch
5+
6+
from pytorch_auto_revert.run_state_logger import RunStateLogger
7+
from pytorch_auto_revert.signal import (
8+
AutorevertPattern,
9+
Ineligible,
10+
IneligibleReason,
11+
RestartCommits,
12+
Signal,
13+
SignalCommit,
14+
SignalEvent,
15+
SignalStatus,
16+
)
17+
from pytorch_auto_revert.signal_extraction_types import RunContext
18+
from pytorch_auto_revert.utils import RestartAction, RevertAction
19+
20+
21+
def ts(base: datetime, minutes: int) -> datetime:
22+
return base + timedelta(minutes=minutes)
23+
24+
25+
class TestRunStateLogger(unittest.TestCase):
26+
def setUp(self) -> None:
27+
self.t0 = datetime(2025, 9, 22, 18, 59, 14)
28+
29+
def _ev(
30+
self,
31+
name: str,
32+
status: SignalStatus,
33+
minute: int,
34+
*,
35+
wf_run_id: int = 1,
36+
job_id: int | None = None,
37+
run_attempt: int | None = None,
38+
) -> SignalEvent:
39+
return SignalEvent(
40+
name=name,
41+
status=status,
42+
started_at=ts(self.t0, minute),
43+
wf_run_id=wf_run_id,
44+
job_id=job_id,
45+
run_attempt=run_attempt,
46+
)
47+
48+
def _ctx(self, *, restart: RestartAction, revert: RevertAction) -> RunContext:
49+
return RunContext(
50+
lookback_hours=8,
51+
notify_issue_number=123,
52+
repo_full_name="owner/repo",
53+
restart_action=restart,
54+
revert_action=revert,
55+
ts=self.t0,
56+
workflows=["wf-a", "wf-b"],
57+
)
58+
59+
@patch("pytorch_auto_revert.run_state_logger.CHCliFactory")
60+
def test_build_and_insert_state_mixed_outcomes_calls_clickhouse_correctly(
61+
self, mock_factory
62+
) -> None:
63+
# Build three signals with events across commits
64+
# Revert-signal (job level)
65+
c1 = SignalCommit(
66+
head_sha="sha_new",
67+
timestamp=ts(self.t0, 12),
68+
events=[
69+
self._ev("job-a", SignalStatus.FAILURE, 12, wf_run_id=111, job_id=999)
70+
],
71+
)
72+
c2 = SignalCommit(
73+
head_sha="sha_old",
74+
timestamp=ts(self.t0, 5),
75+
events=[self._ev("job-a", SignalStatus.SUCCESS, 5)],
76+
)
77+
sig_revert = Signal(
78+
key="job-a",
79+
workflow_name="wf-a",
80+
commits=[c1, c2],
81+
job_base_name="job-a-base",
82+
)
83+
outcome_revert = AutorevertPattern(
84+
workflow_name="wf-a",
85+
newer_failing_commits=["sha_new"],
86+
suspected_commit="sha_old",
87+
older_successful_commit="sha_old",
88+
wf_run_id=111,
89+
job_id=999,
90+
)
91+
92+
# Restart-signal (e.g., uncertainty)
93+
r1 = SignalCommit(
94+
head_sha="R_sha1",
95+
timestamp=ts(self.t0, 20),
96+
events=[self._ev("job-b", SignalStatus.PENDING, 20)],
97+
)
98+
r2 = SignalCommit(
99+
head_sha="R_sha0",
100+
timestamp=ts(self.t0, 10),
101+
events=[self._ev("job-b", SignalStatus.SUCCESS, 10)],
102+
)
103+
sig_restart = Signal(
104+
key="job-b",
105+
workflow_name="wf-b",
106+
commits=[r1, r2],
107+
)
108+
outcome_restart = RestartCommits(commit_shas={"R_sha0"})
109+
110+
# Ineligible-signal (e.g., flaky)
111+
i1 = SignalCommit(
112+
head_sha="I_sha1",
113+
timestamp=ts(self.t0, 30),
114+
events=[self._ev("job-c", SignalStatus.SUCCESS, 30)],
115+
)
116+
i2 = SignalCommit(
117+
head_sha="I_sha0",
118+
timestamp=ts(self.t0, 15),
119+
events=[self._ev("job-c", SignalStatus.FAILURE, 15)],
120+
)
121+
sig_ineligible = Signal(
122+
key="job-c",
123+
workflow_name="wf-c",
124+
commits=[i1, i2],
125+
)
126+
outcome_ineligible = Ineligible(IneligibleReason.FLAKY, "flaky")
127+
128+
pairs = [
129+
(sig_revert, outcome_revert),
130+
(sig_restart, outcome_restart),
131+
(sig_ineligible, outcome_ineligible),
132+
]
133+
134+
ctx = self._ctx(restart=RestartAction.LOG, revert=RevertAction.RUN_LOG)
135+
136+
rsl = RunStateLogger()
137+
state_json = rsl.insert_state(ctx=ctx, pairs=pairs, params="k=v")
138+
139+
# Verify ClickHouse insert was called once with expected arguments
140+
self.assertTrue(mock_factory.return_value.client.insert.called)
141+
args, kwargs = mock_factory.return_value.client.insert.call_args
142+
self.assertEqual(kwargs.get("table"), "autorevert_state")
143+
self.assertEqual(kwargs.get("database"), "misc")
144+
self.assertEqual(
145+
kwargs.get("column_names"),
146+
[
147+
"ts",
148+
"repo",
149+
"state",
150+
"dry_run",
151+
"workflows",
152+
"lookback_hours",
153+
"params",
154+
],
155+
)
156+
# One row inserted
157+
data = kwargs.get("data")
158+
self.assertIsInstance(data, list)
159+
self.assertEqual(len(data), 1)
160+
row = data[0]
161+
# dry_run should be 0 because revert_action has side effects (RUN_LOG)
162+
self.assertEqual(row[3], 0)
163+
self.assertEqual(row[4], ctx.workflows)
164+
self.assertEqual(row[5], ctx.lookback_hours)
165+
self.assertEqual(row[6], "k=v")
166+
167+
# Validate state JSON contents
168+
state = json.loads(state_json)
169+
# Outcomes include all three signal types
170+
outcomes = state.get("outcomes", {})
171+
self.assertIn("wf-a:job-a", outcomes)
172+
self.assertIn("wf-b:job-b", outcomes)
173+
self.assertIn("wf-c:job-c", outcomes)
174+
self.assertEqual(outcomes["wf-a:job-a"]["type"], "AutorevertPattern")
175+
self.assertEqual(outcomes["wf-b:job-b"]["type"], "RestartCommits")
176+
self.assertEqual(outcomes["wf-c:job-c"]["type"], "Ineligible")
177+
178+
# Columns reflect per-signal outcomes
179+
cols = state.get("columns", [])
180+
self.assertEqual(len(cols), 3)
181+
# Find revert column
182+
col_revert = next(c for c in cols if c["outcome"] == "revert")
183+
self.assertEqual(col_revert["workflow"], "wf-a")
184+
self.assertEqual(col_revert["key"], "job-a")
185+
# Cells include entries for commits we provided
186+
self.assertIn("sha_new", col_revert["cells"]) # failure event on newest commit
187+
self.assertIn("sha_old", col_revert["cells"]) # success event on older commit
188+
189+
@patch("pytorch_auto_revert.run_state_logger.CHCliFactory")
190+
def test_insert_state_sets_dry_run_based_on_actions(self, mock_factory) -> None:
191+
# No side effects → dry_run=1
192+
c = SignalCommit(
193+
head_sha="S",
194+
timestamp=ts(self.t0, 0),
195+
events=[self._ev("job", SignalStatus.PENDING, 0)],
196+
)
197+
sig = Signal(key="job", workflow_name="wf", commits=[c])
198+
pairs = [(sig, Ineligible(IneligibleReason.NO_SUCCESSES, ""))]
199+
200+
ctx_dry = self._ctx(restart=RestartAction.LOG, revert=RevertAction.LOG)
201+
rsl = RunStateLogger()
202+
rsl.insert_state(ctx=ctx_dry, pairs=pairs)
203+
_, kwargs1 = mock_factory.return_value.client.insert.call_args
204+
row1 = kwargs1["data"][0]
205+
self.assertEqual(row1[3], 1) # dry_run=1
206+
207+
# With side effects → dry_run=0
208+
ctx_wet = self._ctx(restart=RestartAction.RUN, revert=RevertAction.SKIP)
209+
rsl.insert_state(ctx=ctx_wet, pairs=pairs)
210+
_, kwargs2 = mock_factory.return_value.client.insert.call_args
211+
row2 = kwargs2["data"][0]
212+
self.assertEqual(row2[3], 0) # dry_run=0
213+
214+
215+
if __name__ == "__main__":
216+
unittest.main()

0 commit comments

Comments
 (0)