Skip to content

Commit e5a86a1

Browse files
authored
feat: add Restore function (#2028)
Updates #1682
1 parent 2807adf commit e5a86a1

File tree

5 files changed

+101
-3
lines changed

5 files changed

+101
-3
lines changed

internal/gitrepo/gitrepo.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"log/slog"
2222
"os"
23+
"os/exec"
2324
"strings"
2425
"time"
2526

@@ -44,6 +45,7 @@ type Repository interface {
4445
GetCommitsForPathsSinceCommit(paths []string, sinceCommit string) ([]*Commit, error)
4546
CreateBranchAndCheckout(name string) error
4647
Push(branchName string) error
48+
Restore(paths []string) error
4749
}
4850

4951
// LocalRepository represents a git repository.
@@ -449,3 +451,19 @@ func (r *LocalRepository) Push(branchName string) error {
449451
slog.Info("Successfully pushed branch to remote 'origin", "branch", branchName)
450452
return nil
451453
}
454+
455+
// Restore restores changes in the working tree, leaving staged area untouched.
456+
// Note that untracked files, if any, are not touched.
457+
//
458+
// Wrap git operations in exec, because [git.Worktree.Restore] does not support
459+
// this operation.
460+
func (r *LocalRepository) Restore(paths []string) error {
461+
args := []string{"restore"}
462+
args = append(args, paths...)
463+
slog.Info("Restoring uncommitted changes", "paths", strings.Join(paths, ","))
464+
cmd := exec.Command("git", args...)
465+
cmd.Stderr = os.Stderr
466+
cmd.Stdout = os.Stdout
467+
cmd.Dir = r.Dir
468+
return cmd.Run()
469+
}

internal/gitrepo/gitrepo_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package gitrepo
1616

1717
import (
18+
"fmt"
1819
"os"
1920
"path/filepath"
2021
"strings"
@@ -1117,6 +1118,80 @@ func TestCreateBranchAndCheckout(t *testing.T) {
11171118
}
11181119
}
11191120

1121+
func TestRestore(t *testing.T) {
1122+
for _, test := range []struct {
1123+
name string
1124+
paths []string
1125+
wantErr bool
1126+
wantErrPhrase string
1127+
}{
1128+
{
1129+
name: "restore files in paths",
1130+
paths: []string{
1131+
"first/path",
1132+
"second/path",
1133+
},
1134+
},
1135+
} {
1136+
t.Run(test.name, func(t *testing.T) {
1137+
repo, dir := initTestRepo(t)
1138+
localRepo := &LocalRepository{
1139+
Dir: dir,
1140+
repo: repo,
1141+
}
1142+
// Create files in test.paths and commit the change.
1143+
for _, path := range test.paths {
1144+
file := filepath.Join(path, "example.txt")
1145+
createAndCommit(t, repo, file, []byte("old content"), fmt.Sprintf("commit path, %s", path))
1146+
}
1147+
1148+
// Change file contents.
1149+
for _, path := range test.paths {
1150+
file := filepath.Join(dir, path, "example.txt")
1151+
if err := os.WriteFile(file, []byte("new content"), 0755); err != nil {
1152+
t.Fatal(err)
1153+
}
1154+
// Create untracked files.
1155+
untrackedFile := filepath.Join(dir, path, "untracked.txt")
1156+
if err := os.WriteFile(untrackedFile, []byte("new content"), 0755); err != nil {
1157+
t.Fatal(err)
1158+
}
1159+
}
1160+
1161+
err := localRepo.Restore(test.paths)
1162+
if test.wantErr {
1163+
if err == nil {
1164+
t.Fatalf("%s should return error", test.name)
1165+
}
1166+
if !strings.Contains(err.Error(), test.wantErrPhrase) {
1167+
t.Errorf("Restore() returned error %q, want to contain %q", err.Error(), test.wantErrPhrase)
1168+
}
1169+
return
1170+
}
1171+
1172+
if err != nil {
1173+
t.Fatal(err)
1174+
}
1175+
1176+
for _, path := range test.paths {
1177+
got, err := os.ReadFile(filepath.Join(dir, path, "example.txt"))
1178+
if err != nil {
1179+
t.Fatal(err)
1180+
}
1181+
// Verify file contents are restored.
1182+
if diff := cmp.Diff("old content", string(got)); diff != "" {
1183+
t.Errorf("Restore() mismatch (-want +got):\n%s", diff)
1184+
}
1185+
// Verify the untracked files are untouched.
1186+
untrackedFile := filepath.Join(dir, path, "untracked.txt")
1187+
if _, err := os.Stat(untrackedFile); err != nil {
1188+
t.Errorf("untracked file, %s should not be removed", untrackedFile)
1189+
}
1190+
}
1191+
})
1192+
}
1193+
}
1194+
11201195
// initTestRepo creates a new git repository in a temporary directory.
11211196
func initTestRepo(t *testing.T) (*git.Repository, string) {
11221197
t.Helper()

internal/librarian/command.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ func findSubDirRelPaths(dir, subDir string) ([]string, error) {
555555
return nil, fmt.Errorf("subDir is not nested within the dir: %s, %s", subDir, dir)
556556
}
557557

558-
paths := []string{}
558+
var paths []string
559559
err = filepath.WalkDir(subDir, func(path string, d fs.DirEntry, err error) error {
560560
if err != nil {
561561
return err

internal/librarian/librarian_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ func TestRun(t *testing.T) {
4545
func TestGenerate_DefaultBehavior(t *testing.T) {
4646
ctx := context.Background()
4747

48-
// 1. Setup a mock repository with a state file
48+
// 1. Set up a mock repository with a state file
4949
repo := newTestGitRepo(t)
5050
repoDir := repo.GetDir()
5151

52-
// Setup a dummy API Source repo to prevent cloning googleapis/googleapis
52+
// Set up a dummy API Source repo to prevent cloning googleapis/googleapis
5353
apiSourceDir := t.TempDir()
5454
runGit(t, apiSourceDir, "init")
5555
runGit(t, apiSourceDir, "config", "user.email", "[email protected]")

internal/librarian/mocks_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ type MockRepository struct {
317317
ChangedFilesInCommitError error
318318
CreateBranchAndCheckoutError error
319319
PushError error
320+
RestoreError error
320321
}
321322

322323
func (m *MockRepository) IsClean() (bool, error) {
@@ -424,3 +425,7 @@ func (m *MockRepository) Push(name string) error {
424425
}
425426
return nil
426427
}
428+
429+
func (m *MockRepository) Restore(paths []string) error {
430+
return m.RestoreError
431+
}

0 commit comments

Comments
 (0)