Skip to content

Commit d64fd55

Browse files
authored
feat: support setting default bot per repo (#68)
1 parent 75b1703 commit d64fd55

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

src/git_draft/__main__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .common import PROGRAM, Config, UnreachableError, ensure_state_home
1414
from .drafter import Accept, Drafter
1515
from .editor import open_editor
16+
from .git import Repo
1617
from .prompt import Template, TemplatedPrompt, find_template, templates_table
1718
from .store import Store
1819
from .toolbox import ToolVisitor
@@ -165,11 +166,13 @@ def main() -> None: # noqa: PLR0912 PLR0915
165166
return
166167
logging.basicConfig(level=config.log_level, filename=str(log_path))
167168

168-
drafter = Drafter.create(store=Store.persistent(), path=opts.root)
169+
repo = Repo.enclosing(Path(opts.root) if opts.root else Path.cwd())
170+
drafter = Drafter.create(repo, Store.persistent())
169171
match getattr(opts, "command", "generate"):
170172
case "generate":
171173
bot_config = None
172-
if opts.bot:
174+
bot_name = opts.bot or repo.default_bot()
175+
if bot_name:
173176
bot_configs = [c for c in config.bots if c.name == opts.bot]
174177
if len(bot_configs) != 1:
175178
raise ValueError(f"Found {len(bot_configs)} matching bots")

src/git_draft/drafter.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import enum
99
import json
1010
import logging
11-
from pathlib import Path, PurePosixPath
11+
from pathlib import PurePosixPath
1212
import re
1313
import textwrap
1414
import time
@@ -85,14 +85,13 @@ class Drafter:
8585
"""Draft state orchestrator"""
8686

8787
def __init__(self, store: Store, repo: Repo) -> None:
88-
with store.cursor() as cursor:
89-
cursor.executescript(sql("create-tables"))
9088
self._store = store
9189
self._repo = repo
9290

9391
@classmethod
94-
def create(cls, store: Store, path: str | None = None) -> Drafter:
95-
repo = Repo.enclosing(Path(path) if path else Path.cwd())
92+
def create(cls, repo: Repo, store: Store) -> Drafter:
93+
with store.cursor() as cursor:
94+
cursor.executescript(sql("create-tables"))
9695
return cls(store, repo)
9796

9897
def generate_draft( # noqa: PLR0913

src/git_draft/git.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class GitError(Exception):
6262

6363
class _ConfigKey(enum.StrEnum):
6464
REPO_UUID = "repouuid"
65-
DEFAULT_BOT = "bot" # TODO: Use
65+
DEFAULT_BOT = "bot"
6666

6767
@property
6868
def fullname(self) -> str:
@@ -103,17 +103,25 @@ def git(
103103
def active_branch(self) -> str | None:
104104
return self.git("branch", "--show-current").stdout or None
105105

106+
def default_bot(self) -> str | None:
107+
return _get_config_value(_ConfigKey.DEFAULT_BOT, self.working_dir)
106108

107-
def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID:
109+
110+
def _get_config_value(key: _ConfigKey, working_dir: Path) -> str | None:
108111
call = GitCall.sync(
109112
"config",
110113
"get",
111-
_ConfigKey.REPO_UUID.fullname,
114+
key.fullname,
112115
working_dir=working_dir,
113116
expect_codes=(),
114117
)
115-
if call.code == 0:
116-
return uuid.UUID(call.stdout)
118+
return None if call.code else call.stdout
119+
120+
121+
def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID:
122+
value = _get_config_value(_ConfigKey.REPO_UUID, working_dir)
123+
if value:
124+
return uuid.UUID(value)
117125
repo_uuid = uuid.uuid4()
118126
GitCall.sync(
119127
"config",

tests/git_draft/drafter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class TestDrafter:
4444
def setup(self, repo: Repo, repo_fs: RepoFS) -> None:
4545
self._repo = repo
4646
self._fs = repo_fs
47-
self._drafter = sut.Drafter(Store.in_memory(), repo)
47+
self._drafter = sut.Drafter.create(repo, Store.in_memory())
4848

4949
def _commits(self, ref: str | None = None) -> Sequence[SHA]:
5050
git = self._repo.git("log", "--pretty=format:%H", ref or "HEAD")

0 commit comments

Comments
 (0)