Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .common import PROGRAM, Config, UnreachableError, ensure_state_home
from .drafter import Accept, Drafter
from .editor import open_editor
from .git import Repo
from .prompt import Template, TemplatedPrompt, find_template, templates_table
from .store import Store
from .toolbox import ToolVisitor
Expand Down Expand Up @@ -165,11 +166,13 @@ def main() -> None: # noqa: PLR0912 PLR0915
return
logging.basicConfig(level=config.log_level, filename=str(log_path))

drafter = Drafter.create(store=Store.persistent(), path=opts.root)
repo = Repo.enclosing(Path(opts.root) if opts.root else Path.cwd())
drafter = Drafter.create(repo, Store.persistent())
match getattr(opts, "command", "generate"):
case "generate":
bot_config = None
if opts.bot:
bot_name = opts.bot or repo.default_bot()
if bot_name:
bot_configs = [c for c in config.bots if c.name == opts.bot]
if len(bot_configs) != 1:
raise ValueError(f"Found {len(bot_configs)} matching bots")
Expand Down
9 changes: 4 additions & 5 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import enum
import json
import logging
from pathlib import Path, PurePosixPath
from pathlib import PurePosixPath
import re
import textwrap
import time
Expand Down Expand Up @@ -85,14 +85,13 @@ class Drafter:
"""Draft state orchestrator"""

def __init__(self, store: Store, repo: Repo) -> None:
with store.cursor() as cursor:
cursor.executescript(sql("create-tables"))
self._store = store
self._repo = repo

@classmethod
def create(cls, store: Store, path: str | None = None) -> Drafter:
repo = Repo.enclosing(Path(path) if path else Path.cwd())
def create(cls, repo: Repo, store: Store) -> Drafter:
with store.cursor() as cursor:
cursor.executescript(sql("create-tables"))
return cls(store, repo)

def generate_draft( # noqa: PLR0913
Expand Down
18 changes: 13 additions & 5 deletions src/git_draft/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GitError(Exception):

class _ConfigKey(enum.StrEnum):
REPO_UUID = "repouuid"
DEFAULT_BOT = "bot" # TODO: Use
DEFAULT_BOT = "bot"

@property
def fullname(self) -> str:
Expand Down Expand Up @@ -103,17 +103,25 @@ def git(
def active_branch(self) -> str | None:
return self.git("branch", "--show-current").stdout or None

def default_bot(self) -> str | None:
return _get_config_value(_ConfigKey.DEFAULT_BOT, self.working_dir)

def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID:

def _get_config_value(key: _ConfigKey, working_dir: Path) -> str | None:
call = GitCall.sync(
"config",
"get",
_ConfigKey.REPO_UUID.fullname,
key.fullname,
working_dir=working_dir,
expect_codes=(),
)
if call.code == 0:
return uuid.UUID(call.stdout)
return None if call.code else call.stdout


def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID:
value = _get_config_value(_ConfigKey.REPO_UUID, working_dir)
if value:
return uuid.UUID(value)
repo_uuid = uuid.uuid4()
GitCall.sync(
"config",
Expand Down
2 changes: 1 addition & 1 deletion tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TestDrafter:
def setup(self, repo: Repo, repo_fs: RepoFS) -> None:
self._repo = repo
self._fs = repo_fs
self._drafter = sut.Drafter(Store.in_memory(), repo)
self._drafter = sut.Drafter.create(repo, Store.in_memory())

def _commits(self, ref: str | None = None) -> Sequence[SHA]:
git = self._repo.git("log", "--pretty=format:%H", ref or "HEAD")
Expand Down