Skip to content

Commit 210855a

Browse files
authored
[AUTOREVERT] Improvements on CLI interface (#7297)
In order to avoid any confusion, and adding a few improvements after discussion, I made significant changes to the CLI parsing for autorevert: 1 - Adding the proper sequence parsing for CLI parameters: env variables < explicit cli options; 2 - To clear out any possible misunderstanding on how to use the CLI, --dry-run option is now *explicitly incompatible* with providing actions specifications via either environment variables or cli options; 3 - As much as possible (in between reason) the code for parsing environment variables and making decisions have been centralised and reused; --------- Signed-off-by: Jean Schmidt <[email protected]>
1 parent 38bb33b commit 210855a

File tree

4 files changed

+135
-75
lines changed

4 files changed

+135
-75
lines changed

aws/lambda/pytorch-auto-revert/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ run-local: venv/bin/python
2323

2424
.PHONY: run-local-workflows
2525
run-local-workflows: venv/bin/python
26-
NOTIFY_ISSUE_NUMBER=265 REPO_FULL_NAME=pytorch/pytorch-canary venv/bin/python -m pytorch_auto_revert --dry-run autorevert-checker Lint trunk pull inductor linux-binary-manywheel --hours 8 --revert-action log
26+
NOTIFY_ISSUE_NUMBER=265 REPO_FULL_NAME=pytorch/pytorch-canary venv/bin/python -m pytorch_auto_revert --dry-run autorevert-checker Lint trunk pull inductor linux-binary-manywheel --hours 8
2727

2828
.PHONY: run-local-hud
2929
run-local-hud: venv/bin/python

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 129 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,96 @@
2121
from .utils import RestartAction, RetryWithBackoff, RevertAction
2222

2323

24-
DEFAULT_WORKFLOWS = ["Lint", "trunk", "pull", "inductor", "linux-aarch64"]
25-
DEFAULT_REPO_FULL_NAME = "pytorch/pytorch"
26-
DEFAULT_HOURS = 16
27-
DEFAULT_COMMENT_ISSUE_NUMBER = (
28-
163650 # https://github.com/pytorch/pytorch/issues/163650
29-
)
3024
# Special constant to indicate --hud-html was passed as a flag (without a value)
3125
HUD_HTML_NO_VALUE_FLAG = object()
3226

3327

28+
class DefaultConfig:
29+
def __init__(self):
30+
self.bisection_limit = (
31+
int(os.environ["BISECTION_LIMIT"])
32+
if "BISECTION_LIMIT" in os.environ
33+
else None
34+
)
35+
self.clickhouse_database = os.environ.get("CLICKHOUSE_DATABASE", "default")
36+
self.clickhouse_host = os.environ.get("CLICKHOUSE_HOST", "localhost")
37+
self.clickhouse_password = os.environ.get("CLICKHOUSE_PASSWORD", "")
38+
self.clickhouse_port = int(os.environ.get("CLICKHOUSE_PORT", 8443))
39+
self.clickhouse_username = os.environ.get("CLICKHOUSE_USERNAME", "")
40+
self.github_access_token = os.environ.get("GITHUB_TOKEN", "")
41+
self.github_app_id = os.environ.get("GITHUB_APP_ID", "")
42+
self.github_app_secret = os.environ.get("GITHUB_APP_SECRET", "")
43+
self.github_installation_id = os.environ.get("GITHUB_INSTALLATION_ID", "")
44+
self.hours = int(os.environ.get("HOURS", 16))
45+
self.log_level = os.environ.get("LOG_LEVEL", "INFO")
46+
self.notify_issue_number = int(
47+
os.environ.get("NOTIFY_ISSUE_NUMBER", 163650)
48+
) # https://github.com/pytorch/pytorch/issues/163650
49+
self.repo_full_name = os.environ.get("REPO_FULL_NAME", "pytorch/pytorch")
50+
self.restart_action = (
51+
RestartAction.from_str(os.environ["RESTART_ACTION"])
52+
if "RESTART_ACTION" in os.environ
53+
else None
54+
)
55+
self.revert_action = (
56+
RevertAction.from_str(os.environ["REVERT_ACTION"])
57+
if "REVERT_ACTION" in os.environ
58+
else None
59+
)
60+
self.secret_store_name = os.environ.get("SECRET_STORE_NAME", "")
61+
self.workflows = os.environ.get(
62+
"WORKFLOWS",
63+
",".join(["Lint", "trunk", "pull", "inductor", "linux-aarch64"]),
64+
).split(",")
65+
66+
def to_autorevert_v2_params(
67+
self,
68+
*,
69+
default_restart_action: RestartAction,
70+
default_revert_action: RevertAction,
71+
dry_run: bool,
72+
) -> dict:
73+
"""Convert the configuration to a dictionary."""
74+
return {
75+
"workflows": self.workflows,
76+
"repo_full_name": self.repo_full_name,
77+
"hours": self.hours,
78+
"notify_issue_number": self.notify_issue_number,
79+
"restart_action": RestartAction.LOG
80+
if dry_run
81+
else (self.restart_action or default_restart_action),
82+
"revert_action": RevertAction.LOG
83+
if dry_run
84+
else (self.revert_action or default_revert_action),
85+
"bisection_limit": self.bisection_limit,
86+
}
87+
88+
89+
def validate_actions_dry_run(
90+
opts: argparse.Namespace, default_config: DefaultConfig
91+
) -> None:
92+
"""Validate the actions to be taken in dry run mode."""
93+
if (
94+
default_config.restart_action is not None
95+
or default_config.revert_action is not None
96+
) and opts.dry_run:
97+
logging.error(
98+
"Dry run mode: using dry-run flag with environment variables is not allowed."
99+
)
100+
raise ValueError(
101+
"Conflicting options: --dry-run with explicit actions via environment variables"
102+
)
103+
if (
104+
opts.subcommand == "autorevert-checker"
105+
and (opts.restart_action is not None or opts.revert_action is not None)
106+
and opts.dry_run
107+
):
108+
logging.error(
109+
"Dry run mode: using dry-run flag with explicit actions is not allowed."
110+
)
111+
raise ValueError("Conflicting options: --dry-run with explicit actions")
112+
113+
34114
def setup_logging(log_level: str) -> None:
35115
"""Set up logging configuration."""
36116
numeric_level = getattr(logging, log_level.upper(), None)
@@ -54,45 +134,41 @@ def setup_logging(log_level: str) -> None:
54134
handler.setLevel(numeric_level)
55135

56136

57-
def get_opts() -> argparse.Namespace:
137+
def get_opts(default_config: DefaultConfig) -> argparse.Namespace:
58138
parser = argparse.ArgumentParser()
59139

60140
# General options and configurations
61141
parser.add_argument(
62142
"--log-level",
63-
default=os.environ.get("LOG_LEVEL", "INFO"),
143+
default=default_config.log_level,
64144
choices=["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
65145
help="Set the logging level for the application.",
66146
)
67-
parser.add_argument(
68-
"--clickhouse-host", default=os.environ.get("CLICKHOUSE_HOST", "")
69-
)
147+
parser.add_argument("--clickhouse-host", default=default_config.clickhouse_host)
70148
parser.add_argument(
71149
"--clickhouse-port",
72150
type=int,
73-
default=int(os.environ.get("CLICKHOUSE_PORT", "8443")),
151+
default=default_config.clickhouse_port,
74152
)
75153
parser.add_argument(
76-
"--clickhouse-username", default=os.environ.get("CLICKHOUSE_USERNAME", "")
154+
"--clickhouse-username", default=default_config.clickhouse_username
77155
)
78156
parser.add_argument(
79-
"--clickhouse-password", default=os.environ.get("CLICKHOUSE_PASSWORD", "")
157+
"--clickhouse-password", default=default_config.clickhouse_password
80158
)
81159
parser.add_argument(
82160
"--clickhouse-database",
83-
default=os.environ.get("CLICKHOUSE_DATABASE", "default"),
84-
)
85-
parser.add_argument(
86-
"--github-access-token", default=os.environ.get("GITHUB_TOKEN", "")
161+
default=default_config.clickhouse_database,
87162
)
88-
parser.add_argument("--github-app-id", default=os.environ.get("GITHUB_APP_ID", ""))
89163
parser.add_argument(
90-
"--github-app-secret", default=os.environ.get("GITHUB_APP_SECRET", "")
164+
"--github-access-token", default=default_config.github_access_token
91165
)
166+
parser.add_argument("--github-app-id", default=default_config.github_app_id)
167+
parser.add_argument("--github-app-secret", default=default_config.github_app_secret)
92168
parser.add_argument(
93169
"--github-installation-id",
94170
type=int,
95-
default=int(os.environ.get("GITHUB_INSTALLATION_ID", "0")),
171+
default=default_config.github_installation_id,
96172
)
97173
parser.add_argument(
98174
"--dry-run",
@@ -102,7 +178,7 @@ def get_opts() -> argparse.Namespace:
102178
parser.add_argument(
103179
"--secret-store-name",
104180
action="store",
105-
default=os.environ.get("SECRET_STORE_NAME", ""),
181+
default=default_config.secret_store_name,
106182
help="Name of the secret in AWS Secrets Manager to fetch GitHub App secret from",
107183
)
108184

@@ -117,41 +193,38 @@ def get_opts() -> argparse.Namespace:
117193
workflow_parser.add_argument(
118194
"workflows",
119195
nargs="+",
120-
default=DEFAULT_WORKFLOWS,
196+
default=default_config.workflows,
121197
help="Workflow name(s) to analyze - single name or comma/space separated"
122198
+ ' list (e.g., "pull" or "pull,trunk,inductor")',
123199
)
124200
workflow_parser.add_argument(
125201
"--hours",
126202
type=int,
127-
default=DEFAULT_HOURS,
128-
help=f"Lookback window in hours (default: {DEFAULT_HOURS})",
203+
default=default_config.hours,
204+
help=f"Lookback window in hours (default: {default_config.hours})",
129205
)
130206
workflow_parser.add_argument(
131207
"--repo-full-name",
132-
default=os.environ.get("REPO_FULL_NAME", DEFAULT_REPO_FULL_NAME),
208+
default=default_config.repo_full_name,
133209
help="Full repo name to filter by (owner/repo).",
134210
)
135211
workflow_parser.add_argument(
136212
"--restart-action",
137213
type=RestartAction.from_str,
138-
default=RestartAction.from_str(
139-
os.environ.get("RESTART_ACTION", RestartAction.RUN)
140-
),
214+
default=default_config.restart_action,
141215
choices=list(RestartAction),
142216
help=(
143-
"Restart mode: skip (no logging), log (no side effects), or run (dispatch)."
217+
"Restart mode: skip (no logging), log (no side effects), or run (dispatch). Default is run."
144218
),
145219
)
146220
workflow_parser.add_argument(
147221
"--revert-action",
148222
type=RevertAction.from_str,
149-
default=RevertAction.from_str(
150-
os.environ.get("REVERT_ACTION", RevertAction.LOG)
151-
),
223+
default=default_config.revert_action,
152224
choices=list(RevertAction),
153225
help=(
154-
"Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or run-revert."
226+
"Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or "
227+
"run-revert. Default is log."
155228
),
156229
)
157230
workflow_parser.add_argument(
@@ -166,22 +239,16 @@ def get_opts() -> argparse.Namespace:
166239
workflow_parser.add_argument(
167240
"--bisection-limit",
168241
type=int,
169-
default=(
170-
int(os.environ["BISECTION_LIMIT"])
171-
if os.environ.get("BISECTION_LIMIT", "").strip()
172-
else None
173-
),
242+
default=default_config.bisection_limit,
174243
help=(
175244
"Max new pending jobs to schedule per signal to cover gaps (None = unlimited)."
176245
),
177246
)
178247
workflow_parser.add_argument(
179248
"--notify-issue-number",
180249
type=int,
181-
default=int(
182-
os.environ.get("NOTIFY_ISSUE_NUMBER", DEFAULT_COMMENT_ISSUE_NUMBER)
183-
),
184-
help=f"Issue number to notify (default: {DEFAULT_COMMENT_ISSUE_NUMBER})",
250+
default=default_config.notify_issue_number,
251+
help="Issue number to notify",
185252
)
186253

187254
# workflow-restart-checker subcommand
@@ -265,7 +332,8 @@ def get_secret_from_aws(secret_store_name: str) -> AWSSecretsFromStore:
265332

266333
def main(*args, **kwargs) -> None:
267334
load_dotenv()
268-
opts = get_opts()
335+
default_config = DefaultConfig()
336+
opts = get_opts(default_config)
269337

270338
gh_app_secret = ""
271339
if opts.github_app_secret:
@@ -299,51 +367,39 @@ def main(*args, **kwargs) -> None:
299367
)
300368

301369
if opts.subcommand is None:
302-
repo_name = os.environ.get("REPO_FULL_NAME", DEFAULT_REPO_FULL_NAME)
303-
304-
if check_autorevert_disabled(repo_name):
370+
if check_autorevert_disabled(default_config.repo_full_name):
305371
logging.error(
306372
"Autorevert is disabled via circuit breaker (ci: disable-autorevert issue found). "
307373
"Exiting successfully."
308374
)
309375
return
310376

311-
# Read env-driven defaults for the lambda path
312-
_bis_env = os.environ.get("BISECTION_LIMIT", "").strip()
313-
_bis_limit = int(_bis_env) if _bis_env else None
377+
validate_actions_dry_run(opts, default_config)
314378

315379
autorevert_v2(
316-
os.environ.get("WORKFLOWS", ",".join(DEFAULT_WORKFLOWS)).split(","),
317-
hours=int(os.environ.get("HOURS", DEFAULT_HOURS)),
318-
notify_issue_number=int(
319-
os.environ.get("NOTIFY_ISSUE_NUMBER", DEFAULT_COMMENT_ISSUE_NUMBER)
320-
),
321-
repo_full_name=repo_name,
380+
**default_config.to_autorevert_v2_params(
381+
default_restart_action=RestartAction.RUN,
382+
default_revert_action=RevertAction.RUN_NOTIFY,
383+
dry_run=opts.dry_run,
384+
)
385+
)
386+
elif opts.subcommand == "autorevert-checker":
387+
validate_actions_dry_run(opts, default_config)
388+
_, _, state_json = autorevert_v2(
389+
opts.workflows,
390+
hours=opts.hours,
391+
notify_issue_number=opts.notify_issue_number,
392+
repo_full_name=opts.repo_full_name,
322393
restart_action=(
323394
RestartAction.LOG
324395
if opts.dry_run
325-
else RestartAction.from_str(
326-
os.environ.get("RESTART_ACTION", RestartAction.RUN)
327-
)
396+
else (opts.restart_action or RestartAction.RUN)
328397
),
329398
revert_action=(
330399
RevertAction.LOG
331400
if opts.dry_run
332-
else RevertAction.from_str(
333-
os.environ.get("REVERT_ACTION", RevertAction.RUN_NOTIFY)
334-
)
401+
else (opts.revert_action or RevertAction.LOG)
335402
),
336-
bisection_limit=_bis_limit,
337-
)
338-
elif opts.subcommand == "autorevert-checker":
339-
# New default behavior under the same subcommand
340-
_, _, state_json = autorevert_v2(
341-
opts.workflows,
342-
hours=opts.hours,
343-
notify_issue_number=opts.notify_issue_number,
344-
repo_full_name=opts.repo_full_name,
345-
restart_action=(RestartAction.LOG if opts.dry_run else opts.restart_action),
346-
revert_action=(RevertAction.LOG if opts.dry_run else opts.revert_action),
347403
bisection_limit=opts.bisection_limit,
348404
)
349405
write_hud_html_from_cli(opts.hud_html, HUD_HTML_NO_VALUE_FLAG, state_json)

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/autorevert_v2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def autorevert_v2(
4646
)
4747
logging.info("[v2] Run timestamp (CH log ts) = %s", ts.isoformat())
4848

49+
import sys
50+
51+
sys.exit(1)
52+
4953
extractor = SignalExtractor(
5054
workflows=workflows, lookback_hours=hours, repo_full_name=repo_full_name
5155
)

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def side_effects(self) -> bool:
1717
raise NotImplementedError("Subclasses must implement this method")
1818

1919
@classmethod
20-
def from_str(cls, label: any):
20+
def from_str(cls, label: any) -> "AbstractExecAction":
2121
lower_label = str(label).lower()
2222
for member in cls:
2323
if member.value.lower() == lower_label:

0 commit comments

Comments
 (0)