Skip to content

Commit b41a73b

Browse files
authored
feat: read bot config (#23)
1 parent 88ab2fb commit b41a73b

File tree

5 files changed

+93
-29
lines changed

5 files changed

+93
-29
lines changed

src/git_draft/__main__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import sys
77

88
from .bots import Operation, load_bot
9-
from .common import Config, PROGRAM, Store, ensure_state_home, open_editor
9+
from .common import (
10+
Config,
11+
PROGRAM,
12+
Store,
13+
UnreachableError,
14+
ensure_state_home,
15+
open_editor,
16+
)
1017
from .manager import Manager
1118

1219

@@ -50,7 +57,6 @@ def callback(_option, _opt, _value, parser) -> None:
5057
"--bot",
5158
dest="bot",
5259
help="bot key",
53-
default="openai",
5460
)
5561
parser.add_option(
5662
"-c",
@@ -107,7 +113,16 @@ def main() -> None:
107113
)
108114
command = getattr(opts, "command", "generate")
109115
if command == "generate":
110-
bot = load_bot(opts.bot, {})
116+
if not config.bots:
117+
raise ValueError("No bots configured")
118+
if opts.bot:
119+
bot_configs = [c for c in config.bots if c.name == opts.bot]
120+
if len(bot_configs) != 1:
121+
raise ValueError(f"Found {len(bot_configs)} matching bots")
122+
bot_config = bot_configs[0]
123+
else:
124+
bot_config = config.bots[0]
125+
bot = load_bot(bot_config)
111126
prompt = opts.prompt
112127
if not prompt:
113128
if sys.stdin.isatty():
@@ -122,7 +137,7 @@ def main() -> None:
122137
elif command == "discard":
123138
manager.discard_draft(delete=opts.delete)
124139
else:
125-
assert False, "unreachable"
140+
raise UnreachableError()
126141

127142

128143
if __name__ == "__main__":

src/git_draft/bots/__init__.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
* https://aider.chat/docs/leaderboards/
44
"""
55

6-
from typing import Any, Mapping
6+
import importlib
7+
import sys
78

9+
from ..common import BotConfig
810
from .common import Action, Bot, Operation, OperationHook, Toolbox
911

1012
__all__ = [
@@ -16,13 +18,27 @@
1618
]
1719

1820

19-
def load_bot(entry: str, kwargs: Mapping[str, Any]) -> Bot:
20-
if entry == "openai":
21-
return _load_openai_bot(**kwargs)
22-
raise NotImplementedError() # TODO
21+
def load_bot(config: BotConfig) -> Bot:
22+
if config.pythonpath and config.pythonpath not in sys.path:
23+
sys.path.insert(0, config.pythonpath)
2324

25+
parts = config.factory.split(":", 1)
26+
if len(parts) == 1:
27+
module = sys.modules[__name__] # Default to this module
28+
symbol = parts[0]
29+
else:
30+
module_name, symbol = parts
31+
module = importlib.import_module(module_name)
2432

25-
def _load_openai_bot(**kwargs) -> Bot:
33+
factory = getattr(module, symbol, None)
34+
if not factory:
35+
raise NotImplementedError(f"Unknown factory: {factory}")
36+
37+
kwargs = config.config or {}
38+
return factory(**kwargs)
39+
40+
41+
def openai_bot(**kwargs) -> Bot:
2642
from .openai import OpenAIBot
2743

2844
return OpenAIBot(**kwargs)

src/git_draft/common.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import sys
1616
import tempfile
1717
import tomllib
18-
from typing import Any, Iterator, Mapping, Self
18+
from typing import Any, Iterator, Mapping, Self, Sequence
1919
import xdg_base_dirs
2020

2121

@@ -25,17 +25,17 @@
2525
@dataclasses.dataclass(frozen=True)
2626
class Config:
2727
log_level: int
28-
bots: Mapping[str, BotConfig]
28+
bots: Sequence[BotConfig]
2929
# TODO: Add (prompt) templates.
3030

31-
@classmethod
32-
def default(cls) -> Self:
33-
return cls(logging.INFO, {})
34-
3531
@staticmethod
3632
def path() -> Path:
3733
return xdg_base_dirs.xdg_config_home() / PROGRAM / "config.toml"
3834

35+
@classmethod
36+
def default(cls) -> Self:
37+
return cls(logging.INFO, [])
38+
3939
@classmethod
4040
def load(cls) -> Self:
4141
path = cls.path()
@@ -45,10 +45,9 @@ def load(cls) -> Self:
4545
except FileNotFoundError:
4646
return cls.default()
4747
else:
48-
bot_data = data.get("bots", {})
4948
return cls(
5049
log_level=logging.getLevelName(data["log_level"]),
51-
bots={k: BotConfig(**v) for k, v in bot_data.items()},
50+
bots=[BotConfig(**v) for v in data.get("bots", [])],
5251
)
5352

5453

@@ -58,8 +57,9 @@ def load(cls) -> Self:
5857

5958
@dataclasses.dataclass(frozen=True)
6059
class BotConfig:
61-
loader: str
62-
kwargs: JSONObject | None = None
60+
factory: str
61+
name: str | None = None
62+
config: JSONObject | None = None
6363
pythonpath: str | None = None
6464

6565

@@ -157,3 +157,7 @@ def sql(name: str) -> str:
157157

158158
def random_id(n: int) -> str:
159159
return "".join(_random.choices(_alphabet, k=n))
160+
161+
162+
class UnreachableError(RuntimeError):
163+
pass
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import importlib
2+
import sys
3+
import pytest
4+
5+
from git_draft.bots import Bot, load_bot
6+
from git_draft.common import BotConfig
7+
8+
9+
class FakeBot(Bot):
10+
pass
11+
12+
13+
class TestLoadBot:
14+
def test_existing_factory(self, monkeypatch) -> None:
15+
def import_module(name):
16+
assert name == "fake_module"
17+
return sys.modules[__name__]
18+
19+
monkeypatch.setattr(importlib, "import_module", import_module)
20+
21+
config = BotConfig(factory="fake_module:FakeBot")
22+
bot = load_bot(config)
23+
assert isinstance(bot, FakeBot)
24+
25+
def test_non_existing_factory(self) -> None:
26+
config = BotConfig("git_draft:unknown_factory")
27+
with pytest.raises(NotImplementedError):
28+
load_bot(config)

tests/git_draft/common_test.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,14 @@ def test_load_ok(self) -> None:
5959
text = """\
6060
log_level = "DEBUG"
6161
62-
[bots.foo]
63-
loader = "foo:load"
62+
[[bots]]
63+
factory = "foo:load"
6464
pythonpath = "./abc"
6565
66-
[bots.bar]
67-
loader = "bar"
68-
kwargs = {one=1}
66+
[[bots]]
67+
name = "bar"
68+
factory = "bar"
69+
config = {one=1}
6970
"""
7071
path = sut.Config.path()
7172
path.parent.mkdir(parents=True, exist_ok=True)
@@ -75,10 +76,10 @@ def test_load_ok(self) -> None:
7576
config = sut.Config.load()
7677
assert config == sut.Config(
7778
log_level=logging.DEBUG,
78-
bots={
79-
"foo": sut.BotConfig(loader="foo:load", pythonpath="./abc"),
80-
"bar": sut.BotConfig(loader="bar", kwargs={"one": 1}),
81-
},
79+
bots=[
80+
sut.BotConfig(factory="foo:load", pythonpath="./abc"),
81+
sut.BotConfig(factory="bar", name="bar", config={"one": 1}),
82+
],
8283
)
8384

8485
def test_load_default(self) -> None:

0 commit comments

Comments
 (0)