@@ -106,27 +106,61 @@ def validate_commits(self, data):
106106 return False , f"More than 1 commit! { len (data ['commits' ])} "
107107 return True
108108
109- def validate_pr (self ):
109+ def _normalize_pr (self , parg : str ):
110+ if parg .isdigit ():
111+ return parg
112+ elif parg .startswith ("https://github.com/llvm/llvm-project/pull" ):
113+ # try to parse the following url https://github.com/llvm/llvm-project/pull/114089
114+ i = parg [parg .rfind ("/" ) + 1 :]
115+ if not i .isdigit ():
116+ raise RuntimeError (f"{ i } is not a number, malformatted input." )
117+ return i
118+ else :
119+ raise RuntimeError (
120+ f"PR argument must be PR ID or pull request URL - { parg } is wrong."
121+ )
122+
123+ def load_pr_data (self ):
124+ self .args .pr = self ._normalize_pr (self .args .pr )
110125 fields_to_fetch = [
111126 "baseRefName" ,
127+ "commits" ,
128+ "headRefName" ,
129+ "headRepository" ,
130+ "headRepositoryOwner" ,
112131 "reviewDecision" ,
113- "title " ,
132+ "state " ,
114133 "statusCheckRollup" ,
134+ "title" ,
115135 "url" ,
116- "state" ,
117- "commits" ,
118136 ]
137+ print (f"> Loading PR { self .args .pr } ..." )
119138 o = self .run_gh (
120139 "pr" ,
121140 ["view" , self .args .pr , "--json" , "," .join (fields_to_fetch )],
122141 )
123- prdata = json .loads (o )
142+ self . prdata = json .loads (o )
124143
125144 # save the baseRefName (target branch) so that we know where to push
126- self .target_branch = prdata ["baseRefName" ]
145+ self .target_branch = self .prdata ["baseRefName" ]
146+ srepo = self .prdata ["headRepository" ]["name" ]
147+ sowner = self .prdata ["headRepositoryOwner" ]["login" ]
148+ self .source_url = f"https://github.com/{ sowner } /{ srepo } "
149+ self .source_branch = self .prdata ["headRefName" ]
150+
151+ if srepo != "llvm-project" :
152+ print ("The target repo is NOT llvm-project, check the PR!" )
153+ sys .exit (1 )
154+
155+ if sowner == "llvm" :
156+ print (
157+ "The source owner should never be github.com/llvm, double check the PR!"
158+ )
159+ sys .exit (1 )
127160
128- print (f"> Handling PR { self .args .pr } - { prdata ['title' ]} " )
129- print (f"> { prdata ['url' ]} " )
161+ def validate_pr (self ):
162+ print (f"> Handling PR { self .args .pr } - { self .prdata ['title' ]} " )
163+ print (f"> { self .prdata ['url' ]} " )
130164
131165 VALIDATIONS = {
132166 "state" : self .validate_state ,
@@ -141,7 +175,7 @@ def validate_pr(self):
141175 total_ok = True
142176 for val_name , val_func in VALIDATIONS .items ():
143177 try :
144- validation_data = val_func (prdata )
178+ validation_data = val_func (self . prdata )
145179 except :
146180 validation_data = False
147181 ok = None
@@ -166,24 +200,42 @@ def validate_pr(self):
166200 return total_ok
167201
168202 def rebase_pr (self ):
169- print ("> Rebasing" )
170- self .run_gh ("pr" , ["update-branch" , "--rebase" , self .args .pr ])
171- print ("> Waiting for GitHub to update PR" )
172- time .sleep (4 )
203+ print ("> Fetching upstream" )
204+ subprocess .run (["git" , "fetch" , "--all" ], check = True )
205+ print ("> Rebasing..." )
206+ subprocess .run (
207+ ["git" , "rebase" , self .args .upstream + "/" + self .target_branch ], check = True
208+ )
209+ print ("> Publish rebase..." )
210+ subprocess .run (
211+ ["git" , "push" , "--force" , self .source_url , f"HEAD:{ self .source_branch } " ]
212+ )
173213
174214 def checkout_pr (self ):
175215 print ("> Fetching PR changes..." )
216+ self .merge_branch = "llvm_merger_" + self .args .pr
176217 self .run_gh (
177218 "pr" ,
178219 [
179220 "checkout" ,
180221 self .args .pr ,
181222 "--force" ,
182223 "--branch" ,
183- "llvm_merger_" + self .args . pr ,
224+ self .merge_branch ,
184225 ],
185226 )
186227
228+ # get the branch information so that we can use it for
229+ # pushing later.
230+ p = subprocess .run (
231+ ["git" , "config" , f"branch.{ self .merge_branch } .merge" ],
232+ check = True ,
233+ capture_output = True ,
234+ text = True ,
235+ )
236+ upstream_branch = p .stdout .strip ().replace ("refs/heads/" , "" )
237+ print (upstream_branch )
238+
187239 def push_upstream (self ):
188240 print ("> Pushing changes..." )
189241 subprocess .run (
@@ -201,7 +253,7 @@ def delete_local_branch(self):
201253 parser = argparse .ArgumentParser ()
202254 parser .add_argument (
203255 "pr" ,
204- help = "The Pull Request ID that should be merged into a release." ,
256+ help = "The Pull Request ID that should be merged into a release. Can be number or URL " ,
205257 )
206258 parser .add_argument (
207259 "--skip-validation" ,
@@ -224,9 +276,20 @@ def delete_local_branch(self):
224276 parser .add_argument (
225277 "--validate-only" , action = "store_true" , help = "Only run the validations."
226278 )
279+ parser .add_argument (
280+ "--rebase-only" , action = "store_true" , help = "Only rebase and exit"
281+ )
227282 args = parser .parse_args ()
228283
229284 merger = PRMerger (args )
285+ merger .load_pr_data ()
286+
287+ if args .rebase_only :
288+ merger .checkout_pr ()
289+ merger .rebase_pr ()
290+ merger .delete_local_branch ()
291+ sys .exit (0 )
292+
230293 if not merger .validate_pr ():
231294 print ()
232295 print (
@@ -239,8 +302,8 @@ def delete_local_branch(self):
239302 print ("! --validate-only passed, will exit here" )
240303 sys .exit (0 )
241304
242- merger .rebase_pr ()
243305 merger .checkout_pr ()
306+ merger .rebase_pr ()
244307
245308 if args .no_push :
246309 print ()
0 commit comments