diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 6fc194a..4c06575 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -13,6 +13,7 @@ from .common import PROGRAM, Config, UnreachableError, ensure_state_home from .drafter import Accept, Drafter from .editor import open_editor +from .git import Repo from .prompt import Template, TemplatedPrompt, find_template, templates_table from .store import Store from .toolbox import ToolVisitor @@ -165,11 +166,13 @@ def main() -> None: # noqa: PLR0912 PLR0915 return logging.basicConfig(level=config.log_level, filename=str(log_path)) - drafter = Drafter.create(store=Store.persistent(), path=opts.root) + repo = Repo.enclosing(Path(opts.root) if opts.root else Path.cwd()) + drafter = Drafter.create(repo, Store.persistent()) match getattr(opts, "command", "generate"): case "generate": bot_config = None - if opts.bot: + bot_name = opts.bot or repo.default_bot() + if bot_name: bot_configs = [c for c in config.bots if c.name == opts.bot] if len(bot_configs) != 1: raise ValueError(f"Found {len(bot_configs)} matching bots") diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 6bb7b84..b2fd3be 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -8,7 +8,7 @@ import enum import json import logging -from pathlib import Path, PurePosixPath +from pathlib import PurePosixPath import re import textwrap import time @@ -85,14 +85,13 @@ class Drafter: """Draft state orchestrator""" def __init__(self, store: Store, repo: Repo) -> None: - with store.cursor() as cursor: - cursor.executescript(sql("create-tables")) self._store = store self._repo = repo @classmethod - def create(cls, store: Store, path: str | None = None) -> Drafter: - repo = Repo.enclosing(Path(path) if path else Path.cwd()) + def create(cls, repo: Repo, store: Store) -> Drafter: + with store.cursor() as cursor: + cursor.executescript(sql("create-tables")) return cls(store, repo) def generate_draft( # noqa: PLR0913 diff --git a/src/git_draft/git.py b/src/git_draft/git.py index dc3476b..98fb87c 100644 --- a/src/git_draft/git.py +++ b/src/git_draft/git.py @@ -62,7 +62,7 @@ class GitError(Exception): class _ConfigKey(enum.StrEnum): REPO_UUID = "repouuid" - DEFAULT_BOT = "bot" # TODO: Use + DEFAULT_BOT = "bot" @property def fullname(self) -> str: @@ -103,17 +103,25 @@ def git( def active_branch(self) -> str | None: return self.git("branch", "--show-current").stdout or None + def default_bot(self) -> str | None: + return _get_config_value(_ConfigKey.DEFAULT_BOT, self.working_dir) -def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID: + +def _get_config_value(key: _ConfigKey, working_dir: Path) -> str | None: call = GitCall.sync( "config", "get", - _ConfigKey.REPO_UUID.fullname, + key.fullname, working_dir=working_dir, expect_codes=(), ) - if call.code == 0: - return uuid.UUID(call.stdout) + return None if call.code else call.stdout + + +def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID: + value = _get_config_value(_ConfigKey.REPO_UUID, working_dir) + if value: + return uuid.UUID(value) repo_uuid = uuid.uuid4() GitCall.sync( "config", diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 8e62489..6c9c841 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -44,7 +44,7 @@ class TestDrafter: def setup(self, repo: Repo, repo_fs: RepoFS) -> None: self._repo = repo self._fs = repo_fs - self._drafter = sut.Drafter(Store.in_memory(), repo) + self._drafter = sut.Drafter.create(repo, Store.in_memory()) def _commits(self, ref: str | None = None) -> Sequence[SHA]: git = self._repo.git("log", "--pretty=format:%H", ref or "HEAD")