Skip to content

Commit aca8581

Browse files
authored
feat: add docstring prompt (#26)
1 parent e55bb4a commit aca8581

File tree

12 files changed

+175
-58
lines changed

12 files changed

+175
-58
lines changed

src/git_draft/__main__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from .common import (
1010
Config,
1111
PROGRAM,
12-
PromptRenderer,
1312
Store,
1413
UnreachableError,
1514
ensure_state_home,
1615
open_editor,
1716
)
1817
from .manager import Manager
18+
from .prompt import TemplatedPrompt
1919

2020

2121
def new_parser() -> optparse.OptionParser:
@@ -134,9 +134,7 @@ def main() -> None:
134134
prompt = opts.prompt
135135
if not prompt:
136136
if opts.template:
137-
renderer = PromptRenderer.default()
138-
kwargs = dict(e.split("=", 1) for e in args)
139-
prompt = renderer.render(opts.template, **kwargs)
137+
prompt = TemplatedPrompt.parse(opts.template, *args)
140138
elif sys.stdin.isatty():
141139
prompt = open_editor("Enter your prompt here...")
142140
else:

src/git_draft/bots/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@
1919

2020

2121
def load_bot(config: BotConfig) -> Bot:
22+
"""Load and return a Bot instance using the provided configuration.
23+
24+
If a pythonpath is specified in the config and not already present in
25+
sys.path, it is added. The function expects the config.factory in the
26+
format 'module:symbol' or 'symbol'. If only 'symbol' is provided, the
27+
current module is used.
28+
29+
Args:
30+
config: BotConfig object containing bot configuration details.
31+
32+
Raises:
33+
NotImplementedError: If the specified factory cannot be found.
34+
"""
2235
if config.pythonpath and config.pythonpath not in sys.path:
2336
sys.path.insert(0, config.pythonpath)
2437

@@ -39,6 +52,14 @@ def load_bot(config: BotConfig) -> Bot:
3952

4053

4154
def openai_bot(**kwargs) -> Bot:
55+
"""Instantiate and return an OpenAIBot with provided keyword arguments.
56+
57+
This function imports the OpenAIBot class from the openai module and
58+
returns an instance configured with the provided arguments.
59+
60+
Args:
61+
**kwargs: Arbitrary keyword arguments used to configure the bot.
62+
"""
4263
from .openai import OpenAIBot
4364

4465
return OpenAIBot(**kwargs)

src/git_draft/common.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import dataclasses
55
from datetime import datetime
66
import functools
7-
import jinja2
87
import logging
98
import os
109
from pathlib import Path
@@ -27,7 +26,7 @@
2726
type JSONObject = Mapping[str, JSONValue]
2827

2928

30-
_package_root = Path(__file__).parent
29+
package_root = Path(__file__).parent
3130

3231

3332
def ensure_state_home() -> Path:
@@ -72,30 +71,6 @@ class BotConfig:
7271
pythonpath: str | None = None
7372

7473

75-
_prompt_root = _package_root / "prompts"
76-
77-
78-
class PromptRenderer:
79-
def __init__(self, env: jinja2.Environment) -> None:
80-
self._environment = env
81-
82-
@classmethod
83-
def default(cls):
84-
env = jinja2.Environment(
85-
loader=jinja2.FileSystemLoader(
86-
[Config.folder_path() / "prompts", str(_prompt_root)]
87-
),
88-
autoescape=False,
89-
keep_trailing_newline=True,
90-
auto_reload=False,
91-
)
92-
return cls(env)
93-
94-
def render(self, template_name: str, **kwargs) -> str:
95-
template = self._environment.get_template(f"{template_name}.jinja")
96-
return template.render(kwargs)
97-
98-
9974
_default_editors = ["vim", "emacs", "nano"]
10075

10176

@@ -168,7 +143,7 @@ def cursor(self) -> Iterator[sqlite3.Cursor]:
168143
self._connection.commit()
169144

170145

171-
_query_root = _package_root / "queries"
146+
_query_root = package_root / "queries"
172147

173148

174149
@functools.cache

src/git_draft/manager.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from .bots import Bot, OperationHook, Toolbox
1515
from .common import Store, random_id, sql
16+
from .prompt import PromptRenderer, TemplatedPrompt
1617

1718

1819
_logger = logging.getLogger(__name__)
@@ -135,13 +136,13 @@ def _create_branch(self, sync: bool) -> _Branch:
135136

136137
def generate_draft(
137138
self,
138-
prompt: str,
139+
prompt: str | TemplatedPrompt,
139140
bot: Bot,
140141
checkout=False,
141142
reset=False,
142143
sync=False,
143144
) -> None:
144-
if not prompt.strip():
145+
if isinstance(prompt, str) and not prompt.strip():
145146
raise ValueError("Empty prompt")
146147
if self._repo.is_dirty(working_tree=False):
147148
if not reset:
@@ -157,24 +158,31 @@ def generate_draft(
157158
branch = self._create_branch(sync)
158159
_logger.debug("Created branch %s.", branch)
159160

161+
if isinstance(prompt, TemplatedPrompt):
162+
renderer = PromptRenderer.for_repo(self._repo)
163+
prompt_contents = renderer.render(prompt)
164+
else:
165+
prompt_contents = prompt
160166
with self._store.cursor() as cursor:
161167
[(prompt_id,)] = cursor.execute(
162168
sql("add-prompt"),
163169
{
164170
"branch_suffix": branch.suffix,
165-
"contents": prompt,
171+
"contents": prompt_contents,
166172
},
167173
)
168174

169175
start_time = time.perf_counter()
170176
toolbox = _Toolbox(self._repo, self._operation_hook)
171-
action = bot.act(prompt, toolbox)
177+
action = bot.act(prompt_contents, toolbox)
172178
end_time = time.perf_counter()
173179

174180
title = action.title
175181
if not title:
176-
title = textwrap.shorten(prompt, break_on_hyphens=False, width=72)
177-
commit = self._repo.index.commit(f"draft! {title}\n\n{prompt}")
182+
title = _default_title(prompt_contents)
183+
commit = self._repo.index.commit(
184+
f"draft! {title}\n\n{prompt_contents}"
185+
)
178186

179187
with self._store.cursor() as cursor:
180188
cursor.execute(
@@ -247,3 +255,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
247255
self._repo.git.checkout(sync_sha, "--", ".")
248256
if delete:
249257
self._repo.git.branch("-D", branch.name)
258+
259+
260+
def _default_title(prompt: str) -> str:
261+
return textwrap.shorten(prompt, break_on_hyphens=False, width=72)

src/git_draft/prompt.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import dataclasses
2+
import git
3+
import jinja2
4+
from typing import Mapping, Self
5+
6+
from .common import Config, package_root
7+
8+
9+
_prompt_root = package_root / "prompts"
10+
11+
12+
@dataclasses.dataclass(frozen=True)
13+
class TemplatedPrompt:
14+
template: str
15+
context: Mapping[str, str]
16+
17+
@classmethod
18+
def parse(cls, name: str, *args: str) -> Self:
19+
"""Parse arguments into a TemplatedPrompt
20+
21+
Args:
22+
name: The name of the template.
23+
*args: Additional arguments for context, expected in 'key=value'
24+
format.
25+
"""
26+
return cls(name, dict(e.split("=", 1) for e in args))
27+
28+
29+
class PromptRenderer:
30+
"""Renderer for prompt templates using Jinja2"""
31+
32+
def __init__(self, env: jinja2.Environment) -> None:
33+
self._environment = env
34+
35+
@classmethod
36+
def for_repo(cls, repo: git.Repo) -> Self:
37+
env = jinja2.Environment(
38+
auto_reload=False,
39+
autoescape=False,
40+
keep_trailing_newline=True,
41+
loader=jinja2.FileSystemLoader(
42+
[Config.folder_path() / "prompts", str(_prompt_root)]
43+
),
44+
undefined=jinja2.StrictUndefined,
45+
)
46+
env.globals["repo"] = {
47+
"file_paths": repo.git.ls_files().splitlines(),
48+
}
49+
return cls(env)
50+
51+
def render(self, prompt: TemplatedPrompt) -> str:
52+
template = self._environment.get_template(f"{prompt.template}.jinja")
53+
return template.render(prompt.context)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{% if src_path is defined %}
2+
Add docstrings to all public functions and classes in {{ src_path }}.
3+
{% else %}
4+
Add docstrings to all public functions and classes in this repository.
5+
{% endif %}
6+
7+
Be concise and do not repeat yourself.
8+
9+
Focus on highlighting aspects which are not obvious from the name of the
10+
symbols. Take time to look at the implementation to discover any behaviors
11+
which could be surprising, and make sure to mention those in the docstring.
12+
13+
Docstrings should use the "Args" format for arguments. See the following
14+
examples:
15+
16+
17+
```python
18+
def write_file(path: Path, contents: str) -> None:
19+
"""Updates a file's contents
20+
21+
Args:
22+
path: Path to the file to update.
23+
contents: New file contents.
24+
"""
25+
...
26+
27+
class Renderer:
28+
"""A smart renderer"""
29+
30+
...
31+
```
32+
33+
Additionally, the first paragraph of each docstring should fit in a single line
34+
and not include a period at the end. It should be a brief summary of the
35+
symbol's functionality.
36+
37+
{% include "includes/file-list.jinja" %}
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
1-
Add a test for {{ symbol }}.
1+
Add tests for {{ symbol }}. Follow existing conventions when implementing the
2+
tests. For example, if the surrounding code uses fixtures, do so as well.
3+
4+
{% if src_path is defined %}The symbol to be tested is defined in {{ src_path }}.{% endif %}
5+
6+
{% if test_path is defined %}The tests should be added to {{ test_path }}.{% endif %}
7+
8+
Do not stop until you have added at least one test. You should add separate
9+
tests to cover the normal execution path, and to cover any exceptional cases.
10+
11+
{% include "includes/file-list.jinja" %}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
For reference, here is the list of all currently available files in the
2+
repository:
3+
4+
{% for path in repo.file_paths -%}
5+
* {{ path }}
6+
{% endfor %}

tests/git_draft/common_test.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,3 @@ def test_load_ok(self) -> None:
8585
def test_load_default(self) -> None:
8686
config = sut.Config.load()
8787
assert config.log_level == logging.INFO
88-
89-
90-
class TestPromptRenderer:
91-
@pytest.fixture(autouse=True)
92-
def setup(self) -> None:
93-
self._renderer = sut.PromptRenderer.default()
94-
95-
def test_ok(self) -> None:
96-
prompt = self._renderer.render("add-test", symbol="foo")
97-
assert "foo" in prompt

tests/git_draft/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from pathlib import Path
2+
from typing import Iterator
3+
import git
4+
import pytest
5+
6+
7+
@pytest.fixture
8+
def repo(tmp_path: Path) -> Iterator[git.Repo]:
9+
repo = git.Repo.init(str(tmp_path / "repo"), initial_branch="main")
10+
repo.index.commit("init")
11+
yield repo

0 commit comments

Comments
 (0)