Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,17 @@ def generate_draft( # noqa: PLR0913
if timeout is not None:
raise NotImplementedError() # TODO: Implement

# Handle prompt templating and editing. We do this first in case this
# fails, to avoid creating unnecessary branches.
toolbox, dirty = RepoToolbox.for_working_dir(self._repo)
prompt_contents = self._prepare_prompt(
prompt, prompt_transform, toolbox
)

# Ensure that we are in a folio.
folio = _active_folio(self._repo)
if not folio:
folio = self._create_folio()
toolbox, parent_commit_sha = self._synced_toolbox()

# Handle prompt templating and editing.
prompt_contents = self._prepare_prompt(
prompt, prompt_transform, toolbox
)
with self._store.cursor() as cursor:
[(prompt_id, seqno)] = cursor.execute(
sql("add-prompt"),
Expand All @@ -139,9 +140,15 @@ def generate_draft( # noqa: PLR0913
),
)

# Update metadata.
# Create git commits, references, and update branches.
if dirty:
parent_commit_rev = self._commit_tree(
toolbox.tree_sha(), "HEAD", "sync(prompt)"
)
else:
parent_commit_rev = "HEAD"
commit_sha = self._record_change(
change, parent_commit_sha, folio, seqno
change, parent_commit_rev, folio, seqno
)
# TODO: Trim commits (sync and prompt of files which have not been
# operated on). This will improve the UX by allowing fast-forward when
Expand Down Expand Up @@ -262,17 +269,6 @@ def _create_folio(self) -> Folio:

return folio

def _synced_toolbox(self) -> tuple[RepoToolbox, SHA]:
toolbox, dirty = RepoToolbox.for_working_dir(self._repo)
if dirty:
upstream_rev = self._commit_tree(
toolbox.tree_sha(), "HEAD", "sync(prompt)"
)
_logger.debug("Created sync commit. [sha=%r]", upstream_rev)
else:
upstream_rev = "HEAD"
return toolbox, upstream_rev

def _sync_head(self, scope: str) -> None:
self._repo.git("add", "-A")
index_call = self._repo.git(
Expand Down Expand Up @@ -330,10 +326,10 @@ def _generate_change(
)

def _record_change(
self, change: _Change, parent_commit_sha: SHA, folio: Folio, seqno: int
self, change: _Change, parent_commit_rev: str, folio: Folio, seqno: int
) -> SHA:
commit_sha = self._commit_tree(
change.tree_sha, parent_commit_sha, change.commit_message
change.tree_sha, parent_commit_rev, change.commit_message
)
_logger.debug("Created prompt commit. [sha=%r]", commit_sha)
self._repo.git(
Expand Down