@@ -60,9 +60,12 @@ def setup(self, repo: git.Repo) -> None:
6060 def _path (self , name : str ) -> Path :
6161 return Path (self ._repo .working_dir , name )
6262
63- def _read (self , name : str ) -> str :
64- with open (self ._path (name )) as f :
65- return f .read ()
63+ def _read (self , name : str ) -> str | None :
64+ try :
65+ with open (self ._path (name )) as f :
66+ return f .read ()
67+ except FileNotFoundError :
68+ return None
6669
6770 def _write (self , name : str , contents = "" ) -> None :
6871 with open (self ._path (name ), "w" ) as f :
@@ -95,9 +98,9 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
9598 self ._drafter .generate_draft ("hello" , CustomBot ())
9699 assert self ._commit_files ("HEAD" ) == set (["p2" , "p3" ])
97100
98- def test_generate_then_discard_draft (self ) -> None :
101+ def test_generate_then_revert_draft (self ) -> None :
99102 self ._drafter .generate_draft ("hello" , FakeBot ())
100- self ._drafter .discard_draft ()
103+ self ._drafter .revert_draft ()
101104 assert len (self ._commits ()) == 1
102105
103106 def test_generate_outside_branch (self ) -> None :
@@ -129,7 +132,7 @@ def test_generate_clean_index_sync(self) -> None:
129132 prompt = TemplatedPrompt ("add-test" , {"symbol" : "abc" })
130133 self ._drafter .generate_draft (prompt , FakeBot (), sync = True )
131134 self ._repo .git .checkout ("." )
132- assert "abc" in self ._read ("PROMPT" )
135+ assert "abc" in ( self ._read ("PROMPT" ) or " " )
133136 assert len (self ._commits ()) == 2 # init, prompt
134137
135138 def test_generate_reuse_branch (self ) -> None :
@@ -157,29 +160,53 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action:
157160 assert len (self ._commits ()) == 2 # init, prompt
158161 assert not self ._commit_files ("HEAD" )
159162
160- def test_discard_outside_draft (self ) -> None :
163+ def test_revert_outside_draft (self ) -> None :
161164 with pytest .raises (RuntimeError ):
162- self ._drafter .discard_draft ()
165+ self ._drafter .revert_draft ()
163166
164- def test_discard_after_branch_move (self ) -> None :
167+ def test_revert_after_branch_move (self ) -> None :
165168 self ._write ("log" , "11" )
166169 self ._drafter .generate_draft ("hi" , FakeBot (), sync = True )
167170 branch = self ._repo .active_branch
168171 self ._repo .git .checkout ("main" )
169172 self ._repo .index .commit ("advance" )
170173 self ._repo .git .checkout (branch )
171174 with pytest .raises (RuntimeError ):
172- self ._drafter .discard_draft ()
175+ self ._drafter .revert_draft ()
173176
174- def test_discard_restores_worktree (self ) -> None :
177+ def test_revert_restores_worktree (self ) -> None :
175178 self ._write ("p1.txt" , "a1" )
176179 self ._write ("p2.txt" , "b1" )
177180 self ._drafter .generate_draft ("hello" , FakeBot (), sync = True )
178181 self ._write ("p1.txt" , "a2" )
179- self ._drafter .discard_draft (delete = True )
182+ self ._drafter .revert_draft (delete = True )
180183 assert self ._read ("p1.txt" ) == "a1"
181184 assert self ._read ("p2.txt" ) == "b1"
182185
186+ def test_revert_keeps_untouched_files (self ) -> None :
187+ class CustomBot (Bot ):
188+ def act (self , _goal : Goal , toolbox : Toolbox ) -> Action :
189+ toolbox .write_file (PurePosixPath ("p2.txt" ), "t2" )
190+ toolbox .write_file (PurePosixPath ("p4.txt" ), "t2" )
191+ return Action ()
192+
193+ self ._write ("p1.txt" , "t0" )
194+ self ._write ("p2.txt" , "t0" )
195+ self ._repo .git .add (all = True )
196+ self ._repo .index .commit ("update" )
197+ self ._write ("p1.txt" , "t1" )
198+ self ._write ("p2.txt" , "t1" )
199+ self ._write ("p3.txt" , "t1" )
200+ self ._drafter .generate_draft ("hello" , CustomBot ())
201+ self ._write ("p1.txt" , "t3" )
202+ self ._write ("p2.txt" , "t3" )
203+ self ._drafter .revert_draft ()
204+
205+ assert self ._read ("p1.txt" ) == "t3"
206+ assert self ._read ("p2.txt" ) == "t0"
207+ assert self ._read ("p3.txt" ) == "t1"
208+ assert self ._read ("p4.txt" ) is None
209+
183210 def test_finalize_keeps_changes (self ) -> None :
184211 self ._write ("p1.txt" , "a1" )
185212 self ._drafter .generate_draft ("hello" , FakeBot (), checkout = True )
0 commit comments