@@ -36,6 +36,13 @@ class Accept(enum.Enum):
3636 NO_REGRETS = enum .auto ()
3737
3838
39+ @dataclasses .dataclass (frozen = True )
40+ class Draft :
41+ """Collection of generated changes"""
42+
43+ branch_name : str
44+
45+
3946@dataclasses .dataclass (frozen = True )
4047class _Branch :
4148 """Draft branch"""
@@ -94,7 +101,7 @@ def generate_draft( # noqa: PLR0913
94101 sync : bool = False ,
95102 timeout : float | None = None ,
96103 tool_visitors : Sequence [ToolVisitor ] | None = None ,
97- ) -> str :
104+ ) -> Draft :
98105 if timeout is not None :
99106 raise NotImplementedError () # TODO: Implement
100107
@@ -165,7 +172,7 @@ def generate_draft( # noqa: PLR0913
165172 delta .apply ()
166173 if accept .value >= Accept .FINALIZE .value :
167174 self .finalize_draft (delete = accept == Accept .NO_REGRETS , sync = sync )
168- return str (branch )
175+ return Draft ( str (branch ) )
169176
170177 def _prepare_prompt (
171178 self ,
@@ -214,7 +221,7 @@ def _generate_change(
214221
215222 def finalize_draft (
216223 self , * , delete : bool = False , sync : bool = False
217- ) -> str :
224+ ) -> Draft :
218225 branch = _Branch .active (self ._repo )
219226 if not branch :
220227 raise RuntimeError ("Not currently on a draft branch" )
@@ -240,7 +247,7 @@ def finalize_draft(
240247 _logger .debug ("Deleted branch %s." , branch )
241248
242249 _logger .info ("Exited %s." , branch )
243- return branch .name
250+ return Draft ( branch .name )
244251
245252 def _create_branch (self , sync : bool ) -> _Branch :
246253 if self ._repo .head .is_detached :
0 commit comments