21
21
from .utils import RestartAction , RetryWithBackoff , RevertAction
22
22
23
23
24
- DEFAULT_WORKFLOWS = ["Lint" , "trunk" , "pull" , "inductor" , "linux-aarch64" ]
25
- DEFAULT_REPO_FULL_NAME = "pytorch/pytorch"
26
- DEFAULT_HOURS = 16
27
- DEFAULT_COMMENT_ISSUE_NUMBER = (
28
- 163650 # https://github.com/pytorch/pytorch/issues/163650
29
- )
30
24
# Special constant to indicate --hud-html was passed as a flag (without a value)
31
25
HUD_HTML_NO_VALUE_FLAG = object ()
32
26
33
27
28
class DefaultConfig:
    """Defaults for the autorevert entry point, read from environment variables.

    Each attribute is read once, at construction time, from ``os.environ``.
    The values are used both as argparse defaults and as the parameters of
    the run that happens when no subcommand is given.

    Integer settings treat a missing, empty, or whitespace-only variable as
    "not set" and fall back to their default instead of raising. A non-blank
    value that is not a valid integer still raises ``ValueError``.
    """

    @staticmethod
    def _env_int(name, default=None):
        """Return ``int(os.environ[name])``, or ``default`` when unset or blank."""
        raw = os.environ.get(name, "").strip()
        return int(raw) if raw else default

    def __init__(self):
        # None means "no limit configured"; a blank BISECTION_LIMIT counts as unset.
        self.bisection_limit = self._env_int("BISECTION_LIMIT")
        self.clickhouse_database = os.environ.get("CLICKHOUSE_DATABASE", "default")
        self.clickhouse_host = os.environ.get("CLICKHOUSE_HOST", "localhost")
        self.clickhouse_password = os.environ.get("CLICKHOUSE_PASSWORD", "")
        self.clickhouse_port = self._env_int("CLICKHOUSE_PORT", 8443)
        self.clickhouse_username = os.environ.get("CLICKHOUSE_USERNAME", "")
        self.github_access_token = os.environ.get("GITHUB_TOKEN", "")
        self.github_app_id = os.environ.get("GITHUB_APP_ID", "")
        self.github_app_secret = os.environ.get("GITHUB_APP_SECRET", "")
        # Must be an int, not "": this value is used as the default of an
        # argparse option declared with type=int, and argparse applies the
        # type callable to string defaults, so "" would abort parsing.
        self.github_installation_id = self._env_int("GITHUB_INSTALLATION_ID", 0)
        self.hours = self._env_int("HOURS", 16)
        self.log_level = os.environ.get("LOG_LEVEL", "INFO")
        # https://github.com/pytorch/pytorch/issues/163650
        self.notify_issue_number = self._env_int("NOTIFY_ISSUE_NUMBER", 163650)
        self.repo_full_name = os.environ.get("REPO_FULL_NAME", "pytorch/pytorch")
        # The actions stay None when the variable is absent so that callers can
        # tell "explicitly configured" apart from "use the caller's default".
        self.restart_action = (
            RestartAction.from_str(os.environ["RESTART_ACTION"])
            if "RESTART_ACTION" in os.environ
            else None
        )
        self.revert_action = (
            RevertAction.from_str(os.environ["REVERT_ACTION"])
            if "REVERT_ACTION" in os.environ
            else None
        )
        self.secret_store_name = os.environ.get("SECRET_STORE_NAME", "")
        self.workflows = os.environ.get(
            "WORKFLOWS",
            ",".join(["Lint", "trunk", "pull", "inductor", "linux-aarch64"]),
        ).split(",")

    def to_autorevert_v2_params(
        self,
        *,
        default_restart_action: "RestartAction",
        default_revert_action: "RevertAction",
        dry_run: bool,
    ) -> dict:
        """Build the keyword arguments for an ``autorevert_v2`` call.

        Args:
            default_restart_action: Used when RESTART_ACTION was not set.
            default_revert_action: Used when REVERT_ACTION was not set.
            dry_run: When true, both actions are forced to their ``LOG``
                member regardless of environment or defaults.

        Returns:
            A dict with the keys ``workflows``, ``repo_full_name``, ``hours``,
            ``notify_issue_number``, ``restart_action``, ``revert_action`` and
            ``bisection_limit``.
        """
        if dry_run:
            restart_action = RestartAction.LOG
            revert_action = RevertAction.LOG
        else:
            # Compare against None rather than relying on truthiness so that a
            # configured action is never replaced just because it is falsy.
            restart_action = (
                self.restart_action
                if self.restart_action is not None
                else default_restart_action
            )
            revert_action = (
                self.revert_action
                if self.revert_action is not None
                else default_revert_action
            )
        return {
            "workflows": self.workflows,
            "repo_full_name": self.repo_full_name,
            "hours": self.hours,
            "notify_issue_number": self.notify_issue_number,
            "restart_action": restart_action,
            "revert_action": revert_action,
            "bisection_limit": self.bisection_limit,
        }
87
+
88
+
89
def validate_actions_dry_run(
    opts: argparse.Namespace, default_config: DefaultConfig
) -> None:
    """Reject ``--dry-run`` when a restart or revert action was also set explicitly.

    Two sources of explicit actions are checked, in this order:

    1. ``default_config.restart_action`` / ``default_config.revert_action``
       being anything other than ``None``.
    2. For the ``autorevert-checker`` subcommand only, ``opts.restart_action``
       / ``opts.revert_action`` being anything other than ``None``.

    Raises:
        ValueError: If ``opts.dry_run`` is truthy together with either source.
            The same condition is logged at error level before raising.
    """
    configured_via_env = (
        default_config.restart_action is not None
        or default_config.revert_action is not None
    )
    if configured_via_env and opts.dry_run:
        logging.error(
            "Dry run mode: using dry-run flag with environment variables is not allowed."
        )
        raise ValueError(
            "Conflicting options: --dry-run with explicit actions via environment variables"
        )

    # The per-invocation action options are only consulted for this subcommand.
    if opts.subcommand != "autorevert-checker":
        return

    configured_via_cli = (
        opts.restart_action is not None or opts.revert_action is not None
    )
    if configured_via_cli and opts.dry_run:
        logging.error(
            "Dry run mode: using dry-run flag with explicit actions is not allowed."
        )
        raise ValueError("Conflicting options: --dry-run with explicit actions")
112
+
113
+
34
114
def setup_logging (log_level : str ) -> None :
35
115
"""Set up logging configuration."""
36
116
numeric_level = getattr (logging , log_level .upper (), None )
@@ -54,45 +134,41 @@ def setup_logging(log_level: str) -> None:
54
134
handler .setLevel (numeric_level )
55
135
56
136
57
- def get_opts () -> argparse .Namespace :
137
+ def get_opts (default_config : DefaultConfig ) -> argparse .Namespace :
58
138
parser = argparse .ArgumentParser ()
59
139
60
140
# General options and configurations
61
141
parser .add_argument (
62
142
"--log-level" ,
63
- default = os . environ . get ( "LOG_LEVEL" , "INFO" ) ,
143
+ default = default_config . log_level ,
64
144
choices = ["NOTSET" , "DEBUG" , "INFO" , "WARNING" , "ERROR" , "CRITICAL" ],
65
145
help = "Set the logging level for the application." ,
66
146
)
67
- parser .add_argument (
68
- "--clickhouse-host" , default = os .environ .get ("CLICKHOUSE_HOST" , "" )
69
- )
147
+ parser .add_argument ("--clickhouse-host" , default = default_config .clickhouse_host )
70
148
parser .add_argument (
71
149
"--clickhouse-port" ,
72
150
type = int ,
73
- default = int ( os . environ . get ( "CLICKHOUSE_PORT" , "8443" )) ,
151
+ default = default_config . clickhouse_port ,
74
152
)
75
153
parser .add_argument (
76
- "--clickhouse-username" , default = os . environ . get ( "CLICKHOUSE_USERNAME" , "" )
154
+ "--clickhouse-username" , default = default_config . clickhouse_username
77
155
)
78
156
parser .add_argument (
79
- "--clickhouse-password" , default = os . environ . get ( "CLICKHOUSE_PASSWORD" , "" )
157
+ "--clickhouse-password" , default = default_config . clickhouse_password
80
158
)
81
159
parser .add_argument (
82
160
"--clickhouse-database" ,
83
- default = os .environ .get ("CLICKHOUSE_DATABASE" , "default" ),
84
- )
85
- parser .add_argument (
86
- "--github-access-token" , default = os .environ .get ("GITHUB_TOKEN" , "" )
161
+ default = default_config .clickhouse_database ,
87
162
)
88
- parser .add_argument ("--github-app-id" , default = os .environ .get ("GITHUB_APP_ID" , "" ))
89
163
parser .add_argument (
90
- "--github-app-secret " , default = os . environ . get ( "GITHUB_APP_SECRET" , "" )
164
+ "--github-access-token " , default = default_config . github_access_token
91
165
)
166
+ parser .add_argument ("--github-app-id" , default = default_config .github_app_id )
167
+ parser .add_argument ("--github-app-secret" , default = default_config .github_app_secret )
92
168
parser .add_argument (
93
169
"--github-installation-id" ,
94
170
type = int ,
95
- default = int ( os . environ . get ( "GITHUB_INSTALLATION_ID" , "0" )) ,
171
+ default = default_config . github_installation_id ,
96
172
)
97
173
parser .add_argument (
98
174
"--dry-run" ,
@@ -102,7 +178,7 @@ def get_opts() -> argparse.Namespace:
102
178
parser .add_argument (
103
179
"--secret-store-name" ,
104
180
action = "store" ,
105
- default = os . environ . get ( "SECRET_STORE_NAME" , "" ) ,
181
+ default = default_config . secret_store_name ,
106
182
help = "Name of the secret in AWS Secrets Manager to fetch GitHub App secret from" ,
107
183
)
108
184
@@ -117,41 +193,38 @@ def get_opts() -> argparse.Namespace:
117
193
workflow_parser .add_argument (
118
194
"workflows" ,
119
195
nargs = "+" ,
120
- default = DEFAULT_WORKFLOWS ,
196
+ default = default_config . workflows ,
121
197
help = "Workflow name(s) to analyze - single name or comma/space separated"
122
198
+ ' list (e.g., "pull" or "pull,trunk,inductor")' ,
123
199
)
124
200
workflow_parser .add_argument (
125
201
"--hours" ,
126
202
type = int ,
127
- default = DEFAULT_HOURS ,
128
- help = f"Lookback window in hours (default: { DEFAULT_HOURS } )" ,
203
+ default = default_config . hours ,
204
+ help = f"Lookback window in hours (default: { default_config . hours } )" ,
129
205
)
130
206
workflow_parser .add_argument (
131
207
"--repo-full-name" ,
132
- default = os . environ . get ( "REPO_FULL_NAME" , DEFAULT_REPO_FULL_NAME ) ,
208
+ default = default_config . repo_full_name ,
133
209
help = "Full repo name to filter by (owner/repo)." ,
134
210
)
135
211
workflow_parser .add_argument (
136
212
"--restart-action" ,
137
213
type = RestartAction .from_str ,
138
- default = RestartAction .from_str (
139
- os .environ .get ("RESTART_ACTION" , RestartAction .RUN )
140
- ),
214
+ default = default_config .restart_action ,
141
215
choices = list (RestartAction ),
142
216
help = (
143
- "Restart mode: skip (no logging), log (no side effects), or run (dispatch)."
217
+ "Restart mode: skip (no logging), log (no side effects), or run (dispatch). Default is run. "
144
218
),
145
219
)
146
220
workflow_parser .add_argument (
147
221
"--revert-action" ,
148
222
type = RevertAction .from_str ,
149
- default = RevertAction .from_str (
150
- os .environ .get ("REVERT_ACTION" , RevertAction .LOG )
151
- ),
223
+ default = default_config .revert_action ,
152
224
choices = list (RevertAction ),
153
225
help = (
154
- "Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or run-revert."
226
+ "Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or "
227
+ "run-revert. Default is log."
155
228
),
156
229
)
157
230
workflow_parser .add_argument (
@@ -166,22 +239,16 @@ def get_opts() -> argparse.Namespace:
166
239
workflow_parser .add_argument (
167
240
"--bisection-limit" ,
168
241
type = int ,
169
- default = (
170
- int (os .environ ["BISECTION_LIMIT" ])
171
- if os .environ .get ("BISECTION_LIMIT" , "" ).strip ()
172
- else None
173
- ),
242
+ default = default_config .bisection_limit ,
174
243
help = (
175
244
"Max new pending jobs to schedule per signal to cover gaps (None = unlimited)."
176
245
),
177
246
)
178
247
workflow_parser .add_argument (
179
248
"--notify-issue-number" ,
180
249
type = int ,
181
- default = int (
182
- os .environ .get ("NOTIFY_ISSUE_NUMBER" , DEFAULT_COMMENT_ISSUE_NUMBER )
183
- ),
184
- help = f"Issue number to notify (default: { DEFAULT_COMMENT_ISSUE_NUMBER } )" ,
250
+ default = default_config .notify_issue_number ,
251
+ help = "Issue number to notify" ,
185
252
)
186
253
187
254
# workflow-restart-checker subcommand
@@ -265,7 +332,8 @@ def get_secret_from_aws(secret_store_name: str) -> AWSSecretsFromStore:
265
332
266
333
def main (* args , ** kwargs ) -> None :
267
334
load_dotenv ()
268
- opts = get_opts ()
335
+ default_config = DefaultConfig ()
336
+ opts = get_opts (default_config )
269
337
270
338
gh_app_secret = ""
271
339
if opts .github_app_secret :
@@ -299,51 +367,39 @@ def main(*args, **kwargs) -> None:
299
367
)
300
368
301
369
if opts .subcommand is None :
302
- repo_name = os .environ .get ("REPO_FULL_NAME" , DEFAULT_REPO_FULL_NAME )
303
-
304
- if check_autorevert_disabled (repo_name ):
370
+ if check_autorevert_disabled (default_config .repo_full_name ):
305
371
logging .error (
306
372
"Autorevert is disabled via circuit breaker (ci: disable-autorevert issue found). "
307
373
"Exiting successfully."
308
374
)
309
375
return
310
376
311
- # Read env-driven defaults for the lambda path
312
- _bis_env = os .environ .get ("BISECTION_LIMIT" , "" ).strip ()
313
- _bis_limit = int (_bis_env ) if _bis_env else None
377
+ validate_actions_dry_run (opts , default_config )
314
378
315
379
autorevert_v2 (
316
- os .environ .get ("WORKFLOWS" , "," .join (DEFAULT_WORKFLOWS )).split ("," ),
317
- hours = int (os .environ .get ("HOURS" , DEFAULT_HOURS )),
318
- notify_issue_number = int (
319
- os .environ .get ("NOTIFY_ISSUE_NUMBER" , DEFAULT_COMMENT_ISSUE_NUMBER )
320
- ),
321
- repo_full_name = repo_name ,
380
+ ** default_config .to_autorevert_v2_params (
381
+ default_restart_action = RestartAction .RUN ,
382
+ default_revert_action = RevertAction .RUN_NOTIFY ,
383
+ dry_run = opts .dry_run ,
384
+ )
385
+ )
386
+ elif opts .subcommand == "autorevert-checker" :
387
+ validate_actions_dry_run (opts , default_config )
388
+ _ , _ , state_json = autorevert_v2 (
389
+ opts .workflows ,
390
+ hours = opts .hours ,
391
+ notify_issue_number = opts .notify_issue_number ,
392
+ repo_full_name = opts .repo_full_name ,
322
393
restart_action = (
323
394
RestartAction .LOG
324
395
if opts .dry_run
325
- else RestartAction .from_str (
326
- os .environ .get ("RESTART_ACTION" , RestartAction .RUN )
327
- )
396
+ else (opts .restart_action or RestartAction .RUN )
328
397
),
329
398
revert_action = (
330
399
RevertAction .LOG
331
400
if opts .dry_run
332
- else RevertAction .from_str (
333
- os .environ .get ("REVERT_ACTION" , RevertAction .RUN_NOTIFY )
334
- )
401
+ else (opts .revert_action or RevertAction .LOG )
335
402
),
336
- bisection_limit = _bis_limit ,
337
- )
338
- elif opts .subcommand == "autorevert-checker" :
339
- # New default behavior under the same subcommand
340
- _ , _ , state_json = autorevert_v2 (
341
- opts .workflows ,
342
- hours = opts .hours ,
343
- notify_issue_number = opts .notify_issue_number ,
344
- repo_full_name = opts .repo_full_name ,
345
- restart_action = (RestartAction .LOG if opts .dry_run else opts .restart_action ),
346
- revert_action = (RevertAction .LOG if opts .dry_run else opts .revert_action ),
347
403
bisection_limit = opts .bisection_limit ,
348
404
)
349
405
write_hud_html_from_cli (opts .hud_html , HUD_HTML_NO_VALUE_FLAG , state_json )
0 commit comments