Skip to content

Commit c19e2d3

Browse files
authored
feat: add prompt templating (#24)
1 parent b41a73b commit c19e2d3

File tree

6 files changed

+161
-20
lines changed

6 files changed

+161
-20
lines changed

poetry.lock

Lines changed: 90 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ requires-python = '>=3.12'
99
dependencies = [
1010
'gitpython (>=3.1.44,<4)',
1111
"xdg-base-dirs (>=6.0.2,<7.0.0)",
12+
"jinja2 (>=3.1.5,<4.0.0)",
1213
]
1314

1415
[project.optional-dependencies]

src/git_draft/__main__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .common import (
1010
Config,
1111
PROGRAM,
12+
PromptRenderer,
1213
Store,
1314
UnreachableError,
1415
ensure_state_home,
@@ -74,7 +75,7 @@ def callback(_option, _opt, _value, parser) -> None:
7475
"-p",
7576
"--prompt",
7677
dest="prompt",
77-
help="draft generation prompt, read from stdin if unset",
78+
help="inline prompt",
7879
)
7980
parser.add_option(
8081
"-r",
@@ -88,6 +89,12 @@ def callback(_option, _opt, _value, parser) -> None:
8889
help="commit prior worktree changes separately",
8990
action="store_true",
9091
)
92+
parser.add_option(
93+
"-t",
94+
"--template",
95+
dest="template",
96+
help="prompt template",
97+
)
9198

9299
return parser
93100

@@ -98,7 +105,7 @@ def print_operation(op: Operation) -> None:
98105

99106
def main() -> None:
100107
config = Config.load()
101-
(opts, _args) = new_parser().parse_args()
108+
(opts, args) = new_parser().parse_args()
102109

103110
log_path = ensure_state_home() / "log"
104111
if opts.log:
@@ -123,12 +130,18 @@ def main() -> None:
123130
else:
124131
bot_config = config.bots[0]
125132
bot = load_bot(bot_config)
133+
126134
prompt = opts.prompt
127135
if not prompt:
128-
if sys.stdin.isatty():
136+
if opts.template:
137+
renderer = PromptRenderer.default()
138+
kwargs = dict(e.split("=", 1) for e in args)
139+
prompt = renderer.render(opts.template, **kwargs)
140+
elif sys.stdin.isatty():
129141
prompt = open_editor("Enter your prompt here...")
130142
else:
131143
prompt = sys.stdin.read()
144+
132145
manager.generate_draft(
133146
prompt, bot, checkout=opts.checkout, reset=opts.reset
134147
)

src/git_draft/common.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dataclasses
55
from datetime import datetime
66
import functools
7+
import jinja2
78
import logging
89
import os
910
from pathlib import Path
@@ -22,23 +23,35 @@
2223
PROGRAM = "git-draft"
2324

2425

26+
type JSONValue = Any
27+
type JSONObject = Mapping[str, JSONValue]
28+
29+
30+
_package_root = Path(__file__).parent
31+
32+
33+
def ensure_state_home() -> Path:
34+
path = xdg_base_dirs.xdg_state_home() / PROGRAM
35+
path.mkdir(parents=True, exist_ok=True)
36+
return path
37+
38+
2539
@dataclasses.dataclass(frozen=True)
2640
class Config:
2741
log_level: int
2842
bots: Sequence[BotConfig]
29-
# TODO: Add (prompt) templates.
3043

3144
@staticmethod
32-
def path() -> Path:
33-
return xdg_base_dirs.xdg_config_home() / PROGRAM / "config.toml"
45+
def folder_path() -> Path:
46+
return xdg_base_dirs.xdg_config_home() / PROGRAM
3447

3548
@classmethod
3649
def default(cls) -> Self:
3750
return cls(logging.INFO, [])
3851

3952
@classmethod
4053
def load(cls) -> Self:
41-
path = cls.path()
54+
path = cls.folder_path() / "config.toml"
4255
try:
4356
with open(path, "rb") as reader:
4457
data = tomllib.load(reader)
@@ -51,10 +64,6 @@ def load(cls) -> Self:
5164
)
5265

5366

54-
type JSONValue = Any
55-
type JSONObject = Mapping[str, JSONValue]
56-
57-
5867
@dataclasses.dataclass(frozen=True)
5968
class BotConfig:
6069
factory: str
@@ -63,10 +72,28 @@ class BotConfig:
6372
pythonpath: str | None = None
6473

6574

66-
def ensure_state_home() -> Path:
67-
path = xdg_base_dirs.xdg_state_home() / PROGRAM
68-
path.mkdir(parents=True, exist_ok=True)
69-
return path
75+
_prompt_root = _package_root / "prompts"
76+
77+
78+
class PromptRenderer:
79+
def __init__(self, env: jinja2.Environment) -> None:
80+
self._environment = env
81+
82+
@classmethod
83+
def default(cls):
84+
env = jinja2.Environment(
85+
loader=jinja2.FileSystemLoader(
86+
[Config.folder_path() / "prompts", str(_prompt_root)]
87+
),
88+
autoescape=False,
89+
keep_trailing_newline=True,
90+
auto_reload=False,
91+
)
92+
return cls(env)
93+
94+
def render(self, template_name: str, **kwargs) -> str:
95+
template = self._environment.get_template(f"{template_name}.jinja")
96+
return template.render(kwargs)
7097

7198

7299
_default_editors = ["vim", "emacs", "nano"]
@@ -141,7 +168,7 @@ def cursor(self) -> Iterator[sqlite3.Cursor]:
141168
self._connection.commit()
142169

143170

144-
_query_root = Path(__file__).parent / "queries"
171+
_query_root = _package_root / "queries"
145172

146173

147174
@functools.cache
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a test for {{ symbol }}.

tests/git_draft/common_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def test_load_ok(self) -> None:
6868
factory = "bar"
6969
config = {one=1}
7070
"""
71-
path = sut.Config.path()
72-
path.parent.mkdir(parents=True, exist_ok=True)
73-
with open(path, "w") as f:
71+
path = sut.Config.folder_path()
72+
path.mkdir(parents=True, exist_ok=True)
73+
with open(path / "config.toml", "w") as f:
7474
f.write(textwrap.dedent(text))
7575

7676
config = sut.Config.load()
@@ -85,3 +85,13 @@ def test_load_ok(self) -> None:
8585
def test_load_default(self) -> None:
8686
config = sut.Config.load()
8787
assert config.log_level == logging.INFO
88+
89+
90+
class TestPromptRenderer:
91+
@pytest.fixture(autouse=True)
92+
def setup(self) -> None:
93+
self._renderer = sut.PromptRenderer.default()
94+
95+
def test_ok(self) -> None:
96+
prompt = self._renderer.render("add-test", symbol="foo")
97+
assert "foo" in prompt

0 commit comments

Comments
 (0)