@@ -12,7 +12,15 @@ import (
12
12
"github.com/github/gh-combine/internal/github"
13
13
)
14
14
15
- func CombinePRs (ctx context.Context , graphQlClient * api.GraphQLClient , restClient * api.RESTClient , repo github.Repo , pulls github.Pulls ) error {
15
+ // Updated RESTClientInterface to match the method signatures of api.RESTClient
16
+ type RESTClientInterface interface {
17
+ Post (endpoint string , body io.Reader , response interface {}) error
18
+ Get (endpoint string , response interface {}) error
19
+ Delete (endpoint string , response interface {}) error
20
+ Patch (endpoint string , body io.Reader , response interface {}) error
21
+ }
22
+
23
+ func CombinePRs (ctx context.Context , graphQlClient * api.GraphQLClient , restClient RESTClientInterface , repo github.Repo , pulls github.Pulls ) error {
16
24
// Define the combined branch name
17
25
workingBranchName := combineBranchName + workingBranchSuffix
18
26
@@ -87,7 +95,7 @@ func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClien
87
95
// Create the combined PR
88
96
prBody := generatePRBody (combinedPRs , mergeFailedPRs )
89
97
prTitle := "Combined PRs"
90
- err = createPullRequest (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody )
98
+ err = createPullRequest (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody , addLabels , addAssignees )
91
99
if err != nil {
92
100
return fmt .Errorf ("failed to create combined PR: %w" , err )
93
101
}
@@ -102,7 +110,7 @@ func isMergeConflictError(err error) bool {
102
110
}
103
111
104
112
// Find the default branch of a repository
105
- func getDefaultBranch (ctx context.Context , client * api. RESTClient , repo github.Repo ) (string , error ) {
113
+ func getDefaultBranch (ctx context.Context , client RESTClientInterface , repo github.Repo ) (string , error ) {
106
114
var repoInfo struct {
107
115
DefaultBranch string `json:"default_branch"`
108
116
}
@@ -115,7 +123,7 @@ func getDefaultBranch(ctx context.Context, client *api.RESTClient, repo github.R
115
123
}
116
124
117
125
// Get the SHA of a given branch
118
- func getBranchSHA (ctx context.Context , client * api. RESTClient , repo github.Repo , branch string ) (string , error ) {
126
+ func getBranchSHA (ctx context.Context , client RESTClientInterface , repo github.Repo , branch string ) (string , error ) {
119
127
var ref struct {
120
128
Object struct {
121
129
SHA string `json:"sha"`
@@ -148,13 +156,13 @@ func generatePRBody(combinedPRs, mergeFailedPRs []string) string {
148
156
}
149
157
150
158
// deleteBranch deletes a branch in the repository
151
- func deleteBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , branch string ) error {
159
+ func deleteBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , branch string ) error {
152
160
endpoint := fmt .Sprintf ("repos/%s/%s/git/refs/heads/%s" , repo .Owner , repo .Repo , branch )
153
161
return client .Delete (endpoint , nil )
154
162
}
155
163
156
164
// createBranch creates a new branch in the repository
157
- func createBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , branch , sha string ) error {
165
+ func createBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , branch , sha string ) error {
158
166
endpoint := fmt .Sprintf ("repos/%s/%s/git/refs" , repo .Owner , repo .Repo )
159
167
payload := map [string ]string {
160
168
"ref" : "refs/heads/" + branch ,
@@ -168,7 +176,7 @@ func createBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
168
176
}
169
177
170
178
// mergeBranch merges a branch into the base branch
171
- func mergeBranch (ctx context.Context , client * api. RESTClient , repo github.Repo , base , head string ) error {
179
+ func mergeBranch (ctx context.Context , client RESTClientInterface , repo github.Repo , base , head string ) error {
172
180
endpoint := fmt .Sprintf ("repos/%s/%s/merges" , repo .Owner , repo .Repo )
173
181
payload := map [string ]string {
174
182
"base" : base ,
@@ -182,7 +190,7 @@ func mergeBranch(ctx context.Context, client *api.RESTClient, repo github.Repo,
182
190
}
183
191
184
192
// updateRef updates a branch to point to the latest commit of another branch
185
- func updateRef (ctx context.Context , client * api. RESTClient , repo github.Repo , branch , sourceBranch string ) error {
193
+ func updateRef (ctx context.Context , client RESTClientInterface , repo github.Repo , branch , sourceBranch string ) error {
186
194
// Get the SHA of the source branch
187
195
var ref struct {
188
196
Object struct {
@@ -208,20 +216,56 @@ func updateRef(ctx context.Context, client *api.RESTClient, repo github.Repo, br
208
216
return client .Patch (endpoint , body , nil )
209
217
}
210
218
211
- // createPullRequest creates a new pull request
212
- func createPullRequest (ctx context.Context , client * api.RESTClient , repo github.Repo , title , head , base , body string ) error {
219
+ func createPullRequest (ctx context.Context , client RESTClientInterface , repo github.Repo , title , head , base , body string , labels , assignees []string ) error {
213
220
endpoint := fmt .Sprintf ("repos/%s/%s/pulls" , repo .Owner , repo .Repo )
214
- payload := map [string ]string {
221
+ payload := map [string ]interface {} {
215
222
"title" : title ,
216
223
"head" : head ,
217
224
"base" : base ,
218
225
"body" : body ,
219
226
}
227
+
220
228
requestBody , err := encodePayload (payload )
221
229
if err != nil {
222
230
return fmt .Errorf ("failed to encode payload: %w" , err )
223
231
}
224
- return client .Post (endpoint , requestBody , nil )
232
+
233
+ // Create the pull request
234
+ var prResponse struct {
235
+ Number int `json:"number"`
236
+ }
237
+ err = client .Post (endpoint , requestBody , & prResponse )
238
+ if err != nil {
239
+ return fmt .Errorf ("failed to create pull request: %w" , err )
240
+ }
241
+
242
+ // Add labels if provided
243
+ if len (labels ) > 0 {
244
+ labelsEndpoint := fmt .Sprintf ("repos/%s/%s/issues/%d/labels" , repo .Owner , repo .Repo , prResponse .Number )
245
+ labelsPayload , err := encodePayload (map [string ][]string {"labels" : labels })
246
+ if err != nil {
247
+ return fmt .Errorf ("failed to encode labels payload: %w" , err )
248
+ }
249
+ err = client .Post (labelsEndpoint , labelsPayload , nil )
250
+ if err != nil {
251
+ return fmt .Errorf ("failed to add labels: %w" , err )
252
+ }
253
+ }
254
+
255
+ // Add assignees if provided
256
+ if len (assignees ) > 0 {
257
+ assigneesEndpoint := fmt .Sprintf ("repos/%s/%s/issues/%d/assignees" , repo .Owner , repo .Repo , prResponse .Number )
258
+ assigneesPayload , err := encodePayload (map [string ][]string {"assignees" : assignees })
259
+ if err != nil {
260
+ return fmt .Errorf ("failed to encode assignees payload: %w" , err )
261
+ }
262
+ err = client .Post (assigneesEndpoint , assigneesPayload , nil )
263
+ if err != nil {
264
+ return fmt .Errorf ("failed to add assignees: %w" , err )
265
+ }
266
+ }
267
+
268
+ return nil
225
269
}
226
270
227
271
// encodePayload encodes a payload as JSON and returns an io.Reader
0 commit comments