|
15 | 15 | package gitrepo |
16 | 16 |
|
17 | 17 | import ( |
| 18 | + "fmt" |
18 | 19 | "os" |
19 | 20 | "path/filepath" |
20 | 21 | "strings" |
@@ -1117,6 +1118,80 @@ func TestCreateBranchAndCheckout(t *testing.T) { |
1117 | 1118 | } |
1118 | 1119 | } |
1119 | 1120 |
|
| 1121 | +func TestRestore(t *testing.T) { |
| 1122 | + for _, test := range []struct { |
| 1123 | + name string |
| 1124 | + paths []string |
| 1125 | + wantErr bool |
| 1126 | + wantErrPhrase string |
| 1127 | + }{ |
| 1128 | + { |
| 1129 | + name: "restore files in paths", |
| 1130 | + paths: []string{ |
| 1131 | + "first/path", |
| 1132 | + "second/path", |
| 1133 | + }, |
| 1134 | + }, |
| 1135 | + } { |
| 1136 | + t.Run(test.name, func(t *testing.T) { |
| 1137 | + repo, dir := initTestRepo(t) |
| 1138 | + localRepo := &LocalRepository{ |
| 1139 | + Dir: dir, |
| 1140 | + repo: repo, |
| 1141 | + } |
| 1142 | + // Create files in test.paths and commit the change. |
| 1143 | + for _, path := range test.paths { |
| 1144 | + file := filepath.Join(path, "example.txt") |
| 1145 | + createAndCommit(t, repo, file, []byte("old content"), fmt.Sprintf("commit path, %s", path)) |
| 1146 | + } |
| 1147 | + |
| 1148 | + // Change file contents. |
| 1149 | + for _, path := range test.paths { |
| 1150 | + file := filepath.Join(dir, path, "example.txt") |
| 1151 | + if err := os.WriteFile(file, []byte("new content"), 0755); err != nil { |
| 1152 | + t.Fatal(err) |
| 1153 | + } |
| 1154 | + // Create untracked files. |
| 1155 | + untrackedFile := filepath.Join(dir, path, "untracked.txt") |
| 1156 | + if err := os.WriteFile(untrackedFile, []byte("new content"), 0755); err != nil { |
| 1157 | + t.Fatal(err) |
| 1158 | + } |
| 1159 | + } |
| 1160 | + |
| 1161 | + err := localRepo.Restore(test.paths) |
| 1162 | + if test.wantErr { |
| 1163 | + if err == nil { |
| 1164 | + t.Fatalf("%s should return error", test.name) |
| 1165 | + } |
| 1166 | + if !strings.Contains(err.Error(), test.wantErrPhrase) { |
| 1167 | + t.Errorf("Restore() returned error %q, want to contain %q", err.Error(), test.wantErrPhrase) |
| 1168 | + } |
| 1169 | + return |
| 1170 | + } |
| 1171 | + |
| 1172 | + if err != nil { |
| 1173 | + t.Fatal(err) |
| 1174 | + } |
| 1175 | + |
| 1176 | + for _, path := range test.paths { |
| 1177 | + got, err := os.ReadFile(filepath.Join(dir, path, "example.txt")) |
| 1178 | + if err != nil { |
| 1179 | + t.Fatal(err) |
| 1180 | + } |
| 1181 | + // Verify file contents are restored. |
| 1182 | + if diff := cmp.Diff("old content", string(got)); diff != "" { |
| 1183 | + t.Errorf("Restore() mismatch (-want +got):\n%s", diff) |
| 1184 | + } |
| 1185 | + // Verify the untracked files are untouched. |
| 1186 | + untrackedFile := filepath.Join(dir, path, "untracked.txt") |
| 1187 | + if _, err := os.Stat(untrackedFile); err != nil { |
| 1188 | + t.Errorf("untracked file, %s should not be removed", untrackedFile) |
| 1189 | + } |
| 1190 | + } |
| 1191 | + }) |
| 1192 | + } |
| 1193 | +} |
| 1194 | + |
1120 | 1195 | // initTestRepo creates a new git repository in a temporary directory. |
1121 | 1196 | func initTestRepo(t *testing.T) (*git.Repository, string) { |
1122 | 1197 | t.Helper() |
|
0 commit comments