diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index fadfcb6..e62e2db 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -44,8 +44,15 @@ def callback(_option, _opt, _value, parser) -> None: parser.add_option( "-a", "--assistant", - dest="ASSISTANT", + dest="assistant", help="assistant key", + default="openai", +) +parser.add_option( + "-c", + "--checkout", + help="check out generated changes", + action="store_true", ) parser.add_option( "-d", @@ -87,7 +94,9 @@ def main() -> None: prompt = open_editor(textwrap.dedent(EDITOR_PLACEHOLDER)) else: prompt = sys.stdin.read() - manager.generate_draft(prompt, assistant, reset=opts.reset) + manager.generate_draft( + prompt, assistant, checkout=opts.checkout, reset=opts.reset + ) elif command == "finalize": manager.finalize_draft(delete=opts.delete) elif command == "discard": diff --git a/src/git_draft/manager.py b/src/git_draft/manager.py index cbcd232..1f42db7 100644 --- a/src/git_draft/manager.py +++ b/src/git_draft/manager.py @@ -78,7 +78,7 @@ def needs_rebase(self, repo: git.Repo) -> bool: init_commit = repo.commit(self.init_shortsha) (origin_commit,) = init_commit.parents head_commit = repo.commit(self.init_note.origin_branch) - return origin_commit == head_commit + return origin_commit != head_commit @classmethod def create(cls, repo: git.Repo, sync: Callable[[], str | None]) -> _Branch: @@ -148,7 +148,7 @@ def __init__(self, repo: git.Repo) -> None: self._repo = repo def generate_draft( - self, prompt: str, assistant: Assistant, reset=False + self, prompt: str, assistant: Assistant, checkout=False, reset=False ) -> None: if not prompt.strip(): raise ValueError("Empty prompt") @@ -165,6 +165,8 @@ def generate_draft( assistant.run(prompt, _Toolbox(self._repo)) self._repo.index.commit(f"draft! prompt\n\n{prompt}") + if checkout: + self._repo.git.checkout("--", ".") def finalize_draft(self, delete=False) -> None: self._exit_draft(True, delete=delete)