Skip to content

Commit b7da517

Browse files
committed
add combine PRs logic
1 parent 1bea40f commit b7da517

File tree

2 files changed

+238
-18
lines changed

2 files changed

+238
-18
lines changed

internal/cmd/combine_prs.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package cmd
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
10+
"github.com/cli/go-gh/v2/pkg/api"
11+
)
12+
13+
func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient *api.RESTClient, owner, repo string, matchedPRs []struct {
14+
Number int
15+
Title string
16+
Branch string
17+
Base string
18+
BaseSHA string
19+
}) error {
20+
// Define the combined branch name
21+
combineBranchName := "combined-prs"
22+
workingBranchName := combineBranchName + "-working"
23+
24+
baseBranchSHA, err := getBranchSHA(ctx, restClient, owner, repo, baseBranch)
25+
if err != nil {
26+
return fmt.Errorf("failed to get SHA of main branch: %w", err)
27+
}
28+
29+
// Delete any pre-existing working branch
30+
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
31+
if err != nil {
32+
Logger.Debug("Working branch not found, continuing", "branch", workingBranchName)
33+
}
34+
35+
// Delete any pre-existing combined branch
36+
err = deleteBranch(ctx, restClient, owner, repo, combineBranchName)
37+
if err != nil {
38+
Logger.Debug("Combined branch not found, continuing", "branch", combineBranchName)
39+
}
40+
41+
// Create the combined branch
42+
err = createBranch(ctx, restClient, owner, repo, combineBranchName, baseBranchSHA)
43+
if err != nil {
44+
return fmt.Errorf("failed to create combined branch: %w", err)
45+
}
46+
47+
// Create the working branch
48+
err = createBranch(ctx, restClient, owner, repo, workingBranchName, baseBranchSHA)
49+
if err != nil {
50+
return fmt.Errorf("failed to create working branch: %w", err)
51+
}
52+
53+
// Merge all PR branches into the working branch
54+
var combinedPRs []string
55+
var mergeFailedPRs []string
56+
for _, pr := range matchedPRs {
57+
err := mergeBranch(ctx, restClient, owner, repo, workingBranchName, pr.Branch)
58+
if err != nil {
59+
Logger.Warn("Failed to merge branch", "branch", pr.Branch, "error", err)
60+
mergeFailedPRs = append(mergeFailedPRs, fmt.Sprintf("#%d", pr.Number))
61+
} else {
62+
Logger.Info("Merged branch", "branch", pr.Branch)
63+
combinedPRs = append(combinedPRs, fmt.Sprintf("#%d - %s", pr.Number, pr.Title))
64+
}
65+
}
66+
67+
// Update the combined branch to the latest commit of the working branch
68+
err = updateRef(ctx, restClient, owner, repo, combineBranchName, workingBranchName)
69+
if err != nil {
70+
return fmt.Errorf("failed to update combined branch: %w", err)
71+
}
72+
73+
// Delete the temporary working branch
74+
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
75+
if err != nil {
76+
Logger.Warn("Failed to delete working branch", "branch", workingBranchName, "error", err)
77+
}
78+
79+
// Create the combined PR
80+
prBody := generatePRBody(combinedPRs, mergeFailedPRs)
81+
prTitle := "Combined PRs"
82+
baseBranch := matchedPRs[0].Base // Use the base branch of the first PR
83+
err = createPullRequest(ctx, restClient, owner, repo, prTitle, combineBranchName, baseBranch, prBody)
84+
if err != nil {
85+
return fmt.Errorf("failed to create combined PR: %w", err)
86+
}
87+
88+
return nil
89+
}
90+
91+
// Get the SHA of a given branch
92+
func getBranchSHA(ctx context.Context, client *api.RESTClient, owner, repo, branch string) (string, error) {
93+
var ref struct {
94+
Object struct {
95+
SHA string `json:"sha"`
96+
} `json:"object"`
97+
}
98+
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, branch)
99+
err := client.Get(endpoint, &ref)
100+
if err != nil {
101+
return "", fmt.Errorf("failed to get SHA of branch %s: %w", branch, err)
102+
}
103+
return ref.Object.SHA, nil
104+
}
105+
106+
// generatePRBody generates the body for the combined PR
107+
func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
108+
body := "✅ The following pull requests have been successfully combined:\n"
109+
for _, pr := range combinedPRs {
110+
body += "- " + pr + "\n"
111+
}
112+
if len(mergeFailedPRs) > 0 {
113+
body += "\n⚠️ The following pull requests could not be merged due to conflicts:\n"
114+
for _, pr := range mergeFailedPRs {
115+
body += "- " + pr + "\n"
116+
}
117+
}
118+
return body
119+
}
120+
121+
// deleteBranch deletes a branch in the repository
122+
func deleteBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch string) error {
123+
endpoint := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
124+
return client.Delete(endpoint, nil)
125+
}
126+
127+
// createBranch creates a new branch in the repository
128+
func createBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch, sha string) error {
129+
endpoint := fmt.Sprintf("repos/%s/%s/git/refs", owner, repo)
130+
payload := map[string]string{
131+
"ref": "refs/heads/" + branch,
132+
"sha": sha,
133+
}
134+
body, err := encodePayload(payload)
135+
if err != nil {
136+
return fmt.Errorf("failed to encode payload: %w", err)
137+
}
138+
return client.Post(endpoint, body, nil)
139+
}
140+
141+
// mergeBranch merges a branch into the base branch
142+
func mergeBranch(ctx context.Context, client *api.RESTClient, owner, repo, base, head string) error {
143+
endpoint := fmt.Sprintf("repos/%s/%s/merges", owner, repo)
144+
payload := map[string]string{
145+
"base": base,
146+
"head": head,
147+
}
148+
body, err := encodePayload(payload)
149+
if err != nil {
150+
return fmt.Errorf("failed to encode payload: %w", err)
151+
}
152+
return client.Post(endpoint, body, nil)
153+
}
154+
155+
// updateRef updates a branch to point to the latest commit of another branch
156+
func updateRef(ctx context.Context, client *api.RESTClient, owner, repo, branch, sourceBranch string) error {
157+
// Get the SHA of the source branch
158+
var ref struct {
159+
Object struct {
160+
SHA string `json:"sha"`
161+
} `json:"object"`
162+
}
163+
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, sourceBranch)
164+
err := client.Get(endpoint, &ref)
165+
if err != nil {
166+
return fmt.Errorf("failed to get SHA of source branch: %w", err)
167+
}
168+
169+
// Update the branch to point to the new SHA
170+
endpoint = fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
171+
payload := map[string]interface{}{
172+
"sha": ref.Object.SHA,
173+
"force": true,
174+
}
175+
body, err := encodePayload(payload)
176+
if err != nil {
177+
return fmt.Errorf("failed to encode payload: %w", err)
178+
}
179+
return client.Patch(endpoint, body, nil)
180+
}
181+
182+
// createPullRequest creates a new pull request
183+
func createPullRequest(ctx context.Context, client *api.RESTClient, owner, repo, title, head, base, body string) error {
184+
endpoint := fmt.Sprintf("repos/%s/%s/pulls", owner, repo)
185+
payload := map[string]string{
186+
"title": title,
187+
"head": head,
188+
"base": base,
189+
"body": body,
190+
}
191+
requestBody, err := encodePayload(payload)
192+
if err != nil {
193+
return fmt.Errorf("failed to encode payload: %w", err)
194+
}
195+
return client.Post(endpoint, requestBody, nil)
196+
}
197+
198+
// encodePayload encodes a payload as JSON and returns an io.Reader
199+
func encodePayload(payload interface{}) (io.Reader, error) {
200+
data, err := json.Marshal(payload)
201+
if err != nil {
202+
return nil, err
203+
}
204+
return bytes.NewReader(data), nil
205+
}

internal/cmd/root.go

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,24 @@ import (
1414
)
1515

1616
var (
17-
branchPrefix string
18-
branchSuffix string
19-
branchRegex string
20-
selectLabel string
21-
selectLabels []string
22-
addLabels []string
23-
addAssignees []string
24-
requireCI bool
25-
mustBeApproved bool
26-
autoclose bool
27-
updateBranch bool
28-
ignoreLabel string
29-
ignoreLabels []string
30-
reposFile string
31-
minimum int
32-
defaultOwner string
17+
branchPrefix string
18+
branchSuffix string
19+
branchRegex string
20+
selectLabel string
21+
selectLabels []string
22+
addLabels []string
23+
addAssignees []string
24+
requireCI bool
25+
mustBeApproved bool
26+
autoclose bool
27+
updateBranch bool
28+
ignoreLabel string
29+
ignoreLabels []string
30+
reposFile string
31+
minimum int
32+
defaultOwner string
33+
doNotCombineFromScratch bool
34+
baseBranch string
3335
)
3436

3537
// NewRootCmd creates the root command for the gh-combine CLI
@@ -80,8 +82,10 @@ func NewRootCmd() *cobra.Command {
8082
gh combine octocat/hello-world --add-assignees octocat,hubot # Assign users to the new PR
8183
8284
# Additional options
83-
gh combine octocat/hello-world --autoclose # Close source PRs when combined PR is merged
84-
gh combine octocat/hello-world --update-branch # Update the branch of the combined PR`,
85+
gh combine octocat/hello-world --autoclose # Close source PRs when combined PR is merged
86+
gh combine octocat/hello-world --base-branch main # Use a different base branch for the combined PR
87+
gh combine octocat/hello-world --do-not-combine-from-scratch # Do not combine the PRs from scratch
88+
gh combine octocat/hello-world --update-branch # Update the branch of the combined PR`,
8589
RunE: runCombine,
8690
}
8791

@@ -107,6 +111,8 @@ func NewRootCmd() *cobra.Command {
107111
rootCmd.Flags().BoolVar(&mustBeApproved, "require-approved", false, "Only include PRs that have been approved")
108112
rootCmd.Flags().BoolVar(&autoclose, "autoclose", false, "Close source PRs when combined PR is merged")
109113
rootCmd.Flags().BoolVar(&updateBranch, "update-branch", false, "Update the branch of the combined PR if possible")
114+
rootCmd.Flags().BoolVar(&doNotCombineFromScratch, "do-not-combine-from-scratch", false, "Do not combine the PRs from scratch (clean)")
115+
rootCmd.Flags().StringVar(&baseBranch, "base-branch", "main", "Base branch for the combined PR (default: main)")
110116
rootCmd.Flags().StringVar(&reposFile, "file", "", "File containing repository names, one per line")
111117
rootCmd.Flags().IntVar(&minimum, "minimum", 2, "Minimum number of PRs to combine")
112118
rootCmd.Flags().StringVar(&defaultOwner, "owner", "", "Default owner for repositories (if not specified in repo name or missing from file inputs)")
@@ -294,5 +300,14 @@ func processRepository(ctx context.Context, client *api.RESTClient, graphQlClien
294300
}
295301

296302
Logger.Debug("Matched PRs", "repo", repo, "count", len(matchedPRs))
303+
304+
// If we get here, we have enough PRs to combine
305+
306+
// Combine the PRs
307+
err := CombinePRs(ctx, graphQlClient, client, owner, repoName, matchedPRs)
308+
if err != nil {
309+
return fmt.Errorf("failed to combine PRs: %w", err)
310+
}
311+
297312
return nil
298313
}

0 commit comments

Comments
 (0)