Skip to content

Commit 2d847ec

Browse files
authored
feat: add checkout option (#13)
1 parent 3bb32db commit 2d847ec

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

src/git_draft/__main__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,15 @@ def callback(_option, _opt, _value, parser) -> None:
4444
parser.add_option(
4545
"-a",
4646
"--assistant",
47-
dest="ASSISTANT",
47+
dest="assistant",
4848
help="assistant key",
49+
default="openai",
50+
)
51+
parser.add_option(
52+
"-c",
53+
"--checkout",
54+
help="check out generated changes",
55+
action="store_true",
4956
)
5057
parser.add_option(
5158
"-d",
@@ -87,7 +94,9 @@ def main() -> None:
8794
prompt = open_editor(textwrap.dedent(EDITOR_PLACEHOLDER))
8895
else:
8996
prompt = sys.stdin.read()
90-
manager.generate_draft(prompt, assistant, reset=opts.reset)
97+
manager.generate_draft(
98+
prompt, assistant, checkout=opts.checkout, reset=opts.reset
99+
)
91100
elif command == "finalize":
92101
manager.finalize_draft(delete=opts.delete)
93102
elif command == "discard":

src/git_draft/manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def needs_rebase(self, repo: git.Repo) -> bool:
7878
init_commit = repo.commit(self.init_shortsha)
7979
(origin_commit,) = init_commit.parents
8080
head_commit = repo.commit(self.init_note.origin_branch)
81-
return origin_commit == head_commit
81+
return origin_commit != head_commit
8282

8383
@classmethod
8484
def create(cls, repo: git.Repo, sync: Callable[[], str | None]) -> _Branch:
@@ -148,7 +148,7 @@ def __init__(self, repo: git.Repo) -> None:
148148
self._repo = repo
149149

150150
def generate_draft(
151-
self, prompt: str, assistant: Assistant, reset=False
151+
self, prompt: str, assistant: Assistant, checkout=False, reset=False
152152
) -> None:
153153
if not prompt.strip():
154154
raise ValueError("Empty prompt")
@@ -165,6 +165,8 @@ def generate_draft(
165165

166166
assistant.run(prompt, _Toolbox(self._repo))
167167
self._repo.index.commit(f"draft! prompt\n\n{prompt}")
168+
if checkout:
169+
self._repo.git.checkout("--", ".")
168170

169171
def finalize_draft(self, delete=False) -> None:
170172
self._exit_draft(True, delete=delete)

0 commit comments

Comments
 (0)