Skip to content

Commit 7a3e404

Browse files
authored
feat: switch to original head after running update-image (#2696)
Fixes #2648
1 parent 50b95f2 commit 7a3e404

File tree

2 files changed

+35
-27
lines changed

2 files changed

+35
-27
lines changed

internal/librarian/update_image.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ func (r *updateImageRunner) run(ctx context.Context) error {
113113
// For each library, run generation at the previous commit
114114
var failedGenerations []*config.LibraryState
115115
var successfulGenerations []*config.LibraryState
116+
sourceHead, err := r.sourceRepo.HeadHash()
117+
if err != nil {
118+
return err
119+
}
116120
outputDir := filepath.Join(r.workRoot, "output")
117121
for _, libraryState := range r.state.Libraries {
118122
err := r.regenerateSingleLibrary(ctx, libraryState, outputDir)
@@ -132,6 +136,10 @@ func (r *updateImageRunner) run(ctx context.Context) error {
132136
prBodyBuilder := func() (string, error) {
133137
return formatUpdateImagePRBody(r.image, failedGenerations)
134138
}
139+
// Restore api source repo
140+
if err := r.sourceRepo.Checkout(sourceHead); err != nil {
141+
slog.Error(err.Error(), "repository", r.sourceRepo, "HEAD", sourceHead)
142+
}
135143
commitMessage := fmt.Sprintf("feat: update image to %s", r.image)
136144
return commitAndPush(ctx, &commitInfo{
137145
branch: r.branch,

internal/librarian/update_image_test.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
231231
wantFindLatestCalls: 0,
232232
wantGenerateCalls: 1,
233233
wantBuildCalls: 0, // no -build flag
234-
wantCheckoutCalls: 1,
234+
wantCheckoutCalls: 2,
235235
},
236236
{
237237
name: "no change image",
@@ -280,7 +280,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
280280
wantFindLatestCalls: 1,
281281
wantGenerateCalls: 1,
282282
wantBuildCalls: 0, // no -build flag
283-
wantCheckoutCalls: 1,
283+
wantCheckoutCalls: 2,
284284
},
285285
{
286286
name: "finds image error",
@@ -333,7 +333,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
333333
wantFindLatestCalls: 1,
334334
wantGenerateCalls: 1,
335335
wantBuildCalls: 1,
336-
wantCheckoutCalls: 1,
336+
wantCheckoutCalls: 2,
337337
},
338338
{
339339
name: "updates multiple",
@@ -367,7 +367,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
367367
wantFindLatestCalls: 1,
368368
wantGenerateCalls: 2,
369369
wantBuildCalls: 2,
370-
wantCheckoutCalls: 2,
370+
wantCheckoutCalls: 3,
371371
},
372372
{
373373
name: "skips libraries without APIs",
@@ -401,7 +401,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
401401
wantFindLatestCalls: 1,
402402
wantGenerateCalls: 1,
403403
wantBuildCalls: 1,
404-
wantCheckoutCalls: 1,
404+
wantCheckoutCalls: 2,
405405
},
406406
{
407407
name: "partial generate success",
@@ -438,7 +438,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
438438
wantFindLatestCalls: 1,
439439
wantGenerateCalls: 2,
440440
wantBuildCalls: 1, // build for failed generate should not run
441-
wantCheckoutCalls: 2,
441+
wantCheckoutCalls: 3,
442442
},
443443
{
444444
name: "partial build success",
@@ -475,7 +475,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
475475
wantFindLatestCalls: 1,
476476
wantGenerateCalls: 2,
477477
wantBuildCalls: 2,
478-
wantCheckoutCalls: 2,
478+
wantCheckoutCalls: 3,
479479
},
480480
{
481481
name: "checkout error",
@@ -509,7 +509,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
509509
wantFindLatestCalls: 1,
510510
wantGenerateCalls: 0,
511511
wantBuildCalls: 0,
512-
wantCheckoutCalls: 2,
512+
wantCheckoutCalls: 3,
513513
checkoutError: fmt.Errorf("some checkout error"),
514514
},
515515
{
@@ -545,7 +545,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
545545
wantFindLatestCalls: 1,
546546
wantGenerateCalls: 2,
547547
wantBuildCalls: 2,
548-
wantCheckoutCalls: 2,
548+
wantCheckoutCalls: 3,
549549
wantCommitMsg: "feat: update image to gcr.io/test/image@sha256:abc123",
550550
},
551551
{
@@ -584,7 +584,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
584584
wantFindLatestCalls: 1,
585585
wantGenerateCalls: 2,
586586
wantBuildCalls: 2,
587-
wantCheckoutCalls: 2,
587+
wantCheckoutCalls: 3,
588588
wantCreatePullRequestCalls: 1,
589589
wantCommitMsg: "feat: update image to gcr.io/test/image@sha256:abc123",
590590
wantErr: true,
@@ -632,7 +632,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
632632
wantFindLatestCalls: 1,
633633
wantGenerateCalls: 2,
634634
wantBuildCalls: 2,
635-
wantCheckoutCalls: 2,
635+
wantCheckoutCalls: 3,
636636
wantCreatePullRequestCalls: 1,
637637
wantCommitMsg: "feat: update image to gcr.io/test/image@sha256:abc123",
638638
},
@@ -681,7 +681,7 @@ func TestUpdateImageRunnerRun(t *testing.T) {
681681
wantFindLatestCalls: 1,
682682
wantGenerateCalls: 2,
683683
wantBuildCalls: 2,
684-
wantCheckoutCalls: 2,
684+
wantCheckoutCalls: 3,
685685
wantCreatePullRequestCalls: 1,
686686
wantCreateIssueCalls: 1,
687687
wantCommitMsg: "feat: update image to gcr.io/test/image@sha256:abc123",
@@ -718,21 +718,6 @@ func TestUpdateImageRunnerRun(t *testing.T) {
718718

719719
err := r.run(t.Context())
720720

721-
if test.wantErr {
722-
if err == nil {
723-
t.Fatalf("%s should return error", test.name)
724-
}
725-
726-
if !strings.Contains(err.Error(), test.wantErrMsg) {
727-
t.Errorf("want error message %s, got %s", test.wantErrMsg, err.Error())
728-
}
729-
return
730-
} else {
731-
if err != nil {
732-
t.Fatal(err)
733-
}
734-
}
735-
736721
if diff := cmp.Diff(test.wantGenerateCalls, test.containerClient.generateCalls); diff != "" {
737722
t.Errorf("%s: run() generateCalls mismatch (-want +got):%s", test.name, diff)
738723
}
@@ -752,6 +737,21 @@ func TestUpdateImageRunnerRun(t *testing.T) {
752737
t.Errorf("%s: run() createIssueCalls mismatch (-want +got):%s", test.name, diff)
753738
}
754739

740+
if test.wantErr {
741+
if err == nil {
742+
t.Fatalf("%s should return error", test.name)
743+
}
744+
745+
if !strings.Contains(err.Error(), test.wantErrMsg) {
746+
t.Errorf("want error message %s, got %s", test.wantErrMsg, err.Error())
747+
}
748+
return
749+
} else {
750+
if err != nil {
751+
t.Fatal(err)
752+
}
753+
}
754+
755755
if test.wantCommitMsg != "" {
756756
if diff := cmp.Diff(test.wantCommitMsg, repo.LastCommitMessage); diff != "" {
757757
t.Errorf("%s: run() commit message mismatch (-want +got):%s", test.name, diff)

0 commit comments

Comments
 (0)