Skip to content

Commit 2b3626c

Browse files
authored
Merge pull request #2 from github/ci-and-approval-checks
feat: implement PR combine
2 parents 0ef96ce + 094ca7a commit 2b3626c

File tree

4 files changed

+421
-22
lines changed

4 files changed

+421
-22
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ require github.com/briandowns/spinner v1.23.2
88

99
require (
1010
github.com/cli/go-gh/v2 v2.12.0
11+
github.com/cli/shurcooL-graphql v0.0.4
1112
github.com/spf13/cobra v1.9.1
1213
)
1314

1415
require (
1516
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
1617
github.com/cli/safeexec v1.0.0 // indirect
17-
github.com/cli/shurcooL-graphql v0.0.4 // indirect
1818
github.com/fatih/color v1.7.0 // indirect
1919
github.com/henvic/httpretty v0.0.6 // indirect
2020
github.com/inconshreveable/mousetrap v1.1.0 // indirect

internal/cmd/combine_prs.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package cmd
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
10+
"github.com/cli/go-gh/v2/pkg/api"
11+
)
12+
13+
func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClient *api.RESTClient, owner, repo string, matchedPRs []struct {
14+
Number int
15+
Title string
16+
Branch string
17+
Base string
18+
BaseSHA string
19+
}) error {
20+
// Define the combined branch name
21+
workingBranchName := combineBranchName + workingBranchSuffix
22+
23+
baseBranchSHA, err := getBranchSHA(ctx, restClient, owner, repo, baseBranch)
24+
if err != nil {
25+
return fmt.Errorf("failed to get SHA of main branch: %w", err)
26+
}
27+
28+
// Delete any pre-existing working branch
29+
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
30+
if err != nil {
31+
Logger.Debug("Working branch not found, continuing", "branch", workingBranchName)
32+
}
33+
34+
// Delete any pre-existing combined branch
35+
err = deleteBranch(ctx, restClient, owner, repo, combineBranchName)
36+
if err != nil {
37+
Logger.Debug("Combined branch not found, continuing", "branch", combineBranchName)
38+
}
39+
40+
// Create the combined branch
41+
err = createBranch(ctx, restClient, owner, repo, combineBranchName, baseBranchSHA)
42+
if err != nil {
43+
return fmt.Errorf("failed to create combined branch: %w", err)
44+
}
45+
46+
// Create the working branch
47+
err = createBranch(ctx, restClient, owner, repo, workingBranchName, baseBranchSHA)
48+
if err != nil {
49+
return fmt.Errorf("failed to create working branch: %w", err)
50+
}
51+
52+
// Merge all PR branches into the working branch
53+
var combinedPRs []string
54+
var mergeFailedPRs []string
55+
for _, pr := range matchedPRs {
56+
err := mergeBranch(ctx, restClient, owner, repo, workingBranchName, pr.Branch)
57+
if err != nil {
58+
Logger.Warn("Failed to merge branch", "branch", pr.Branch, "error", err)
59+
mergeFailedPRs = append(mergeFailedPRs, fmt.Sprintf("#%d", pr.Number))
60+
} else {
61+
Logger.Info("Merged branch", "branch", pr.Branch)
62+
combinedPRs = append(combinedPRs, fmt.Sprintf("#%d - %s", pr.Number, pr.Title))
63+
}
64+
}
65+
66+
// Update the combined branch to the latest commit of the working branch
67+
err = updateRef(ctx, restClient, owner, repo, combineBranchName, workingBranchName)
68+
if err != nil {
69+
return fmt.Errorf("failed to update combined branch: %w", err)
70+
}
71+
72+
// Delete the temporary working branch
73+
err = deleteBranch(ctx, restClient, owner, repo, workingBranchName)
74+
if err != nil {
75+
Logger.Warn("Failed to delete working branch", "branch", workingBranchName, "error", err)
76+
}
77+
78+
// Create the combined PR
79+
prBody := generatePRBody(combinedPRs, mergeFailedPRs)
80+
prTitle := "Combined PRs"
81+
err = createPullRequest(ctx, restClient, owner, repo, prTitle, combineBranchName, baseBranch, prBody)
82+
if err != nil {
83+
return fmt.Errorf("failed to create combined PR: %w", err)
84+
}
85+
86+
return nil
87+
}
88+
89+
// Get the SHA of a given branch
90+
func getBranchSHA(ctx context.Context, client *api.RESTClient, owner, repo, branch string) (string, error) {
91+
var ref struct {
92+
Object struct {
93+
SHA string `json:"sha"`
94+
} `json:"object"`
95+
}
96+
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, branch)
97+
err := client.Get(endpoint, &ref)
98+
if err != nil {
99+
return "", fmt.Errorf("failed to get SHA of branch %s: %w", branch, err)
100+
}
101+
return ref.Object.SHA, nil
102+
}
103+
104+
// generatePRBody generates the body for the combined PR
105+
func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
106+
body := "✅ The following pull requests have been successfully combined:\n"
107+
for _, pr := range combinedPRs {
108+
body += "- " + pr + "\n"
109+
}
110+
if len(mergeFailedPRs) > 0 {
111+
body += "\n⚠️ The following pull requests could not be merged due to conflicts:\n"
112+
for _, pr := range mergeFailedPRs {
113+
body += "- " + pr + "\n"
114+
}
115+
}
116+
return body
117+
}
118+
119+
// deleteBranch deletes a branch in the repository
120+
func deleteBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch string) error {
121+
endpoint := fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
122+
return client.Delete(endpoint, nil)
123+
}
124+
125+
// createBranch creates a new branch in the repository
126+
func createBranch(ctx context.Context, client *api.RESTClient, owner, repo, branch, sha string) error {
127+
endpoint := fmt.Sprintf("repos/%s/%s/git/refs", owner, repo)
128+
payload := map[string]string{
129+
"ref": "refs/heads/" + branch,
130+
"sha": sha,
131+
}
132+
body, err := encodePayload(payload)
133+
if err != nil {
134+
return fmt.Errorf("failed to encode payload: %w", err)
135+
}
136+
return client.Post(endpoint, body, nil)
137+
}
138+
139+
// mergeBranch merges a branch into the base branch
140+
func mergeBranch(ctx context.Context, client *api.RESTClient, owner, repo, base, head string) error {
141+
endpoint := fmt.Sprintf("repos/%s/%s/merges", owner, repo)
142+
payload := map[string]string{
143+
"base": base,
144+
"head": head,
145+
}
146+
body, err := encodePayload(payload)
147+
if err != nil {
148+
return fmt.Errorf("failed to encode payload: %w", err)
149+
}
150+
return client.Post(endpoint, body, nil)
151+
}
152+
153+
// updateRef updates a branch to point to the latest commit of another branch
154+
func updateRef(ctx context.Context, client *api.RESTClient, owner, repo, branch, sourceBranch string) error {
155+
// Get the SHA of the source branch
156+
var ref struct {
157+
Object struct {
158+
SHA string `json:"sha"`
159+
} `json:"object"`
160+
}
161+
endpoint := fmt.Sprintf("repos/%s/%s/git/ref/heads/%s", owner, repo, sourceBranch)
162+
err := client.Get(endpoint, &ref)
163+
if err != nil {
164+
return fmt.Errorf("failed to get SHA of source branch: %w", err)
165+
}
166+
167+
// Update the branch to point to the new SHA
168+
endpoint = fmt.Sprintf("repos/%s/%s/git/refs/heads/%s", owner, repo, branch)
169+
payload := map[string]interface{}{
170+
"sha": ref.Object.SHA,
171+
"force": true,
172+
}
173+
body, err := encodePayload(payload)
174+
if err != nil {
175+
return fmt.Errorf("failed to encode payload: %w", err)
176+
}
177+
return client.Patch(endpoint, body, nil)
178+
}
179+
180+
// createPullRequest creates a new pull request
181+
func createPullRequest(ctx context.Context, client *api.RESTClient, owner, repo, title, head, base, body string) error {
182+
endpoint := fmt.Sprintf("repos/%s/%s/pulls", owner, repo)
183+
payload := map[string]string{
184+
"title": title,
185+
"head": head,
186+
"base": base,
187+
"body": body,
188+
}
189+
requestBody, err := encodePayload(payload)
190+
if err != nil {
191+
return fmt.Errorf("failed to encode payload: %w", err)
192+
}
193+
return client.Post(endpoint, requestBody, nil)
194+
}
195+
196+
// encodePayload encodes a payload as JSON and returns an io.Reader
197+
func encodePayload(payload interface{}) (io.Reader, error) {
198+
data, err := json.Marshal(payload)
199+
if err != nil {
200+
return nil, err
201+
}
202+
return bytes.NewReader(data), nil
203+
}

internal/cmd/match_criteria.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package cmd
22

33
import (
4+
"context"
5+
"fmt"
46
"regexp"
57
"strings"
8+
9+
"github.com/cli/go-gh/v2/pkg/api"
10+
graphql "github.com/cli/shurcooL-graphql"
611
)
712

813
// checks if a PR matches all filtering criteria
@@ -113,3 +118,165 @@ func labelsMatchCriteria(prLabels []struct{ Name string }) bool {
113118

114119
return true
115120
}
121+
122+
// GraphQL response structure for PR status info
123+
type prStatusResponse struct {
124+
Data struct {
125+
Repository struct {
126+
PullRequest struct {
127+
ReviewDecision string `json:"reviewDecision"`
128+
Commits struct {
129+
Nodes []struct {
130+
Commit struct {
131+
StatusCheckRollup *struct {
132+
State string `json:"state"`
133+
} `json:"statusCheckRollup"`
134+
} `json:"commit"`
135+
} `json:"nodes"`
136+
} `json:"commits"`
137+
} `json:"pullRequest"`
138+
} `json:"repository"`
139+
} `json:"data"`
140+
Errors []struct {
141+
Message string `json:"message"`
142+
} `json:"errors,omitempty"`
143+
}
144+
145+
// GetPRStatusInfo fetches both CI status and approval status using GitHub's GraphQL API
146+
func GetPRStatusInfo(ctx context.Context, graphQlClient *api.GraphQLClient, owner, repo string, prNumber int) (*prStatusResponse, error) {
147+
// Check for context cancellation
148+
select {
149+
case <-ctx.Done():
150+
return nil, ctx.Err()
151+
default:
152+
// Continue processing
153+
}
154+
155+
// Define a struct with embedded graphql query
156+
var query struct {
157+
Repository struct {
158+
PullRequest struct {
159+
ReviewDecision string
160+
Commits struct {
161+
Nodes []struct {
162+
Commit struct {
163+
StatusCheckRollup *struct {
164+
State string
165+
}
166+
}
167+
}
168+
} `graphql:"commits(last: 1)"`
169+
} `graphql:"pullRequest(number: $prNumber)"`
170+
} `graphql:"repository(owner: $owner, name: $repo)"`
171+
}
172+
173+
// Prepare GraphQL query variables
174+
variables := map[string]interface{}{
175+
"owner": graphql.String(owner),
176+
"repo": graphql.String(repo),
177+
"prNumber": graphql.Int(prNumber),
178+
}
179+
180+
// Execute GraphQL query
181+
err := graphQlClient.Query("PullRequestStatus", &query, variables)
182+
if err != nil {
183+
return nil, fmt.Errorf("GraphQL query failed: %w", err)
184+
}
185+
186+
// Convert to our response format
187+
response := &prStatusResponse{}
188+
response.Data.Repository.PullRequest.ReviewDecision = query.Repository.PullRequest.ReviewDecision
189+
190+
if len(query.Repository.PullRequest.Commits.Nodes) > 0 {
191+
response.Data.Repository.PullRequest.Commits.Nodes = make([]struct {
192+
Commit struct {
193+
StatusCheckRollup *struct {
194+
State string `json:"state"`
195+
} `json:"statusCheckRollup"`
196+
} `json:"commit"`
197+
}, len(query.Repository.PullRequest.Commits.Nodes))
198+
199+
for i, node := range query.Repository.PullRequest.Commits.Nodes {
200+
if node.Commit.StatusCheckRollup != nil {
201+
response.Data.Repository.PullRequest.Commits.Nodes[i].Commit.StatusCheckRollup = &struct {
202+
State string `json:"state"`
203+
}{
204+
State: node.Commit.StatusCheckRollup.State,
205+
}
206+
}
207+
}
208+
}
209+
210+
return response, nil
211+
}
212+
213+
// PrMeetsRequirements checks if a PR meets additional requirements beyond basic criteria
214+
func PrMeetsRequirements(ctx context.Context, graphQlClient *api.GraphQLClient, owner, repo string, prNumber int) (bool, error) {
215+
// If no additional requirements are specified, the PR meets requirements
216+
if !requireCI && !mustBeApproved {
217+
return true, nil
218+
}
219+
220+
// Fetch PR status info once
221+
response, err := GetPRStatusInfo(ctx, graphQlClient, owner, repo, prNumber)
222+
if err != nil {
223+
return false, err
224+
}
225+
226+
// Check CI status if required
227+
if requireCI {
228+
passing := isCIPassing(response)
229+
if !passing {
230+
return false, nil
231+
}
232+
}
233+
234+
// Check approval status if required
235+
if mustBeApproved {
236+
approved := isPRApproved(response)
237+
if !approved {
238+
return false, nil
239+
}
240+
}
241+
242+
return true, nil
243+
}
244+
245+
// isCIPassing checks if the CI status is passing based on the response
246+
func isCIPassing(response *prStatusResponse) bool {
247+
commits := response.Data.Repository.PullRequest.Commits.Nodes
248+
if len(commits) == 0 {
249+
Logger.Debug("No commits found for PR")
250+
return false
251+
}
252+
253+
statusCheckRollup := commits[0].Commit.StatusCheckRollup
254+
if statusCheckRollup == nil {
255+
Logger.Debug("No status checks found for PR")
256+
return true // If no checks defined, consider it passing
257+
}
258+
259+
if statusCheckRollup.State != "SUCCESS" {
260+
Logger.Debug("PR failed CI check", "status", statusCheckRollup.State)
261+
return false
262+
}
263+
264+
return true
265+
}
266+
267+
// isPRApproved checks if the PR is approved based on the response
268+
func isPRApproved(response *prStatusResponse) bool {
269+
reviewDecision := response.Data.Repository.PullRequest.ReviewDecision
270+
Logger.Debug("PR review decision", "decision", reviewDecision)
271+
272+
switch reviewDecision {
273+
case "APPROVED":
274+
return true
275+
case "": // When no reviews are required
276+
Logger.Debug("PR has no required reviewers")
277+
return true // If no reviews required, consider it approved
278+
default:
279+
Logger.Debug("PR not approved", "decision", reviewDecision)
280+
return false
281+
}
282+
}

0 commit comments

Comments
 (0)