Skip to content

Commit 4afc3b5

Browse files
authored
fix: improve unknown branch detection (#18)
1 parent 86047c6 commit 4afc3b5

File tree

5 files changed

+55
-12
lines changed

5 files changed

+55
-12
lines changed

src/git_draft/assistants/common.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,15 @@ class Toolbox(Protocol):
1111
# signature-only versions.
1212

1313
def list_files(self) -> Sequence[PurePosixPath]: ...
14+
1415
def read_file(self, path: PurePosixPath) -> str: ...
15-
def write_file(self, path: PurePosixPath, contents: str) -> None: ...
16+
17+
def write_file(
18+
self,
19+
path: PurePosixPath,
20+
contents: str,
21+
change_description: str | None = None,
22+
) -> None: ...
1623

1724

1825
@dataclasses.dataclass(frozen=True)

src/git_draft/assistants/openai.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _function_tool_param(
2626
"type": "object",
2727
"additionalProperties": False,
2828
"properties": inputs or {},
29-
"required": required_inputs or [],
29+
"required": list(inputs.keys()) if inputs else [],
3030
},
3131
"strict": True,
3232
},
@@ -47,7 +47,6 @@ def _function_tool_param(
4747
"description": "Path of the file to be read",
4848
},
4949
},
50-
required_inputs=["path"],
5150
),
5251
_function_tool_param(
5352
name="write_file",
@@ -65,8 +64,13 @@ def _function_tool_param(
6564
"type": "string",
6665
"description": "New contents of the file",
6766
},
67+
"change_description": {
68+
"type": "string",
69+
"description": """\
70+
Brief description of the changes performed on this file
71+
""",
72+
},
6873
},
69-
required_inputs=["path", "contents"],
7074
),
7175
]
7276

@@ -77,6 +81,8 @@ def _function_tool_param(
7781
You are an expert software engineer, who writes correct and concise code.
7882
Use the provided functions to find the filesyou need to answer the query,
7983
read the content of the relevant ones, and save the changes you suggest.
84+
When writing a file, include a summary description of the changes you have
85+
made.
8086
"""
8187

8288

src/git_draft/common.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import string
1212
import sys
1313
import tempfile
14-
from typing import ContextManager
14+
from typing import Iterator
1515
import xdg_base_dirs
1616

1717

@@ -30,9 +30,7 @@ def _guess_editor_binpath() -> str:
3030

3131

3232
def _get_tty_filename():
33-
if sys.platform == "win32":
34-
return "CON:"
35-
return "/dev/tty"
33+
return "CON:" if sys.platform == "win32" else "/dev/tty"
3634

3735

3836
def open_editor(placeholder="") -> str:
@@ -80,8 +78,16 @@ def persistent(cls) -> Store:
8078
def in_memory(cls) -> Store:
8179
return cls(sqlite3.connect(":memory:"))
8280

83-
def cursor(self) -> ContextManager[sqlite3.Cursor]:
84-
return contextlib.closing(self._connection.cursor())
81+
@contextlib.contextmanager
82+
def cursor(self) -> Iterator[sqlite3.Cursor]:
83+
with contextlib.closing(self._connection.cursor()) as cursor:
84+
try:
85+
yield cursor
86+
except: # noqa
87+
self._connection.rollback()
88+
raise
89+
else:
90+
self._connection.commit()
8591

8692

8793
_query_root = Path(__file__).parent / "queries"

src/git_draft/manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ def read_file(self, path: PurePosixPath) -> str:
6363
# Read the file from the index.
6464
return self._repo.git.show(f":{path}")
6565

66-
def write_file(self, path: PurePosixPath, contents: str) -> None:
66+
def write_file(
67+
self,
68+
path: PurePosixPath,
69+
contents: str,
70+
change_description: str | None = None,
71+
) -> None:
6772
# Update the index without touching the worktree.
6873
# https://stackoverflow.com/a/25352119
6974
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
@@ -187,9 +192,12 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
187192
raise RuntimeError("Not currently on a draft branch")
188193

189194
with self._store.cursor() as cursor:
190-
[(origin_branch, origin_sha, sync_sha)] = cursor.execute(
195+
rows = cursor.execute(
191196
sql("get-branch-by-suffix"), {"suffix": branch.suffix}
192197
)
198+
if not rows:
199+
raise RuntimeError("Unrecognized branch")
200+
[(origin_branch, origin_sha, sync_sha)] = rows
193201

194202
if (
195203
not apply

tests/git_draft/common_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ def test_ensure_state_home(state_home) -> None:
1717
assert path.exists()
1818

1919

20+
class TestRandomId:
21+
def test_length(self) -> None:
22+
length = 10
23+
result = sut.random_id(length)
24+
assert len(result) == length
25+
26+
def test_content(self) -> None:
27+
result = sut.random_id(1000)
28+
assert set(result).issubset(sut._alphabet)
29+
30+
2031
class TestStore:
2132
def test_cursor(self) -> None:
2233
store = sut.Store.persistent()
@@ -30,3 +41,8 @@ def test_cursor(self) -> None:
3041

3142
def test_sql() -> None:
3243
assert "create" in sut.sql("create-tables")
44+
45+
46+
def test_sql_missing() -> None:
47+
with pytest.raises(FileNotFoundError):
48+
sut.sql("non_existent_file")

0 commit comments

Comments
 (0)