diff --git a/poetry.lock b/poetry.lock index 5f25dea..cd71eca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -736,24 +736,43 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pytest" -version = "7.4.4" +version = "8.3.5" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, - {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, + {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, + {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, ] [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" +pluggy = ">=1.5,<2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.26.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0"}, + {file = "pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] name = "ruff" @@ -934,4 +953,4 @@ openai = ["openai"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "3e30f690d81bad2a4f36bfc042d3b45bcd83eeed65167c2d6b232083b30370f1" +content-hash = "1059cef5706fc2016a50259c4b1a53a6a99c757ff14b7bac804ce1b134857abf" diff --git a/pyproject.toml b/pyproject.toml index dd04d43..1c7bc68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ python = ">=3.12,<4" coverage = "^7.4.4" mypy = "^1.2.1" poethepoet = "^0.25.0" -pytest = "^7.1.0" +pytest = "^8.2.0" +pytest-asyncio = "^0.26.0" ruff = "^0.11.0" # Poe diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 01abbc9..857e2ff 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import enum import importlib.metadata import logging @@ -151,7 +152,7 @@ def edit(*, path: Path | None = None, text: str | None = None) -> str: _PROMPT_PLACEHOLDER = "Enter your prompt here..." -def main() -> None: # noqa: PLR0912 PLR0915 +async def main() -> None: # noqa: PLR0912 PLR0915 config = Config.load() (opts, args) = new_parser().parse_args() @@ -198,7 +199,7 @@ def main() -> None: # noqa: PLR0912 PLR0915 editable = False # We already edited the prompt accept = Accept(opts.accept or 0) - drafter.generate_draft( + await drafter.generate_draft( prompt, bot, prompt_transform=open_editor if editable else None, @@ -231,7 +232,7 @@ def main() -> None: # noqa: PLR0912 PLR0915 if __name__ == "__main__": try: - main() + asyncio.run(main()) except Exception as err: _logger.exception("Program failed.") print(f"Error: {err}", file=sys.stderr) diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index f9c0bb2..5e3484e 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -66,6 +66,6 @@ def state_folder_path(cls, ensure_exists: bool = False) -> Path: path.mkdir(parents=True, exist_ok=True) return path - def act(self, goal: Goal, toolbox: Toolbox) -> Action: + async def act(self, goal: Goal, toolbox: Toolbox) -> Action: """Runs the bot, striving to achieve the goal with the given toolbox""" raise NotImplementedError() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 8701460..fb03ee5 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -234,7 +234,7 @@ def __init__(self, client: openai.OpenAI, model: str) -> None: self._client = client self._model = model - def act(self, goal: Goal, toolbox: Toolbox) -> Action: + async def act(self, goal: Goal, toolbox: Toolbox) -> Action: tools = _ToolsFactory(strict=False).params() tool_handler = _CompletionsToolHandler(toolbox) @@ -318,7 +318,7 @@ def _load_assistant_id(self) -> str: f.write(assistant_id) return assistant_id - def act(self, goal: Goal, toolbox: Toolbox) -> Action: + async def act(self, goal: Goal, toolbox: Toolbox) -> Action: assistant_id = self._load_assistant_id() thread = self._client.beta.threads.create() diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index b713e5f..4ca21d3 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -100,7 +100,7 @@ def create(cls, repo: Repo, store: Store, feedback: Feedback) -> Drafter: cursor.executescript(sql("create-tables")) return cls(store, repo, feedback) - def generate_draft( + async def generate_draft( self, prompt: str | TemplatedPrompt, bot: Bot, @@ -143,7 +143,7 @@ def generate_draft( # Run the bot to generate the change. operation_recorder = _OperationRecorder(self._feedback) with self._feedback.spinner("Running bot...") as spinner: - change = self._generate_change( + change = await self._generate_change( bot, Goal(prompt_contents), toolbox.with_visitors( @@ -342,7 +342,7 @@ def _prepare_prompt( raise ValueError("Missing or empty prompt") return contents - def _generate_change( + async def _generate_change( self, bot: Bot, goal: Goal, @@ -352,7 +352,7 @@ def _generate_change( start_time = time.perf_counter() _logger.debug("Running bot... [bot=%s]", bot) - action = bot.act(goal, toolbox) + action = await bot.act(goal, toolbox) _logger.info("Completed bot action. [action=%s]", action) end_time = time.perf_counter() diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 8c9c3c3..deca8a6 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -29,7 +29,7 @@ def noop(cls) -> Self: def prompt(cls) -> Self: return cls({"PROMPT": lambda goal: goal.prompt}) - def act(self, goal: Goal, toolbox: Toolbox) -> Action: + async def act(self, goal: Goal, toolbox: Toolbox) -> Action: for key, value in self._contents.items(): path = PurePosixPath(key) if value is None: @@ -62,22 +62,25 @@ def _commit_files(self, ref: str) -> frozenset[str]: def _checkout(self) -> None: self._repo.git("checkout", "--", ".") - def test_generate_draft(self) -> None: + @pytest.mark.asyncio + async def test_generate_draft(self) -> None: self._fs.write("p1", "a") - self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"})) + await self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"})) assert len(self._commits()) == 1 assert len(self._commits("@{u}")) == 3 assert self._fs.read("p1") == "a" - def test_generate_empty_draft(self) -> None: - self._drafter.generate_draft("hello", _SimpleBot.noop()) + @pytest.mark.asyncio + async def test_generate_empty_draft(self) -> None: + await self._drafter.generate_draft("hello", _SimpleBot.noop()) assert len(self._commits()) == 1 assert len(self._commits("@{u}")) == 2 - def test_generate_draft_merge(self) -> None: + @pytest.mark.asyncio + async def test_generate_draft_merge(self) -> None: self._fs.write("p1", "a") - self._drafter.generate_draft( + await self._drafter.generate_draft( "hello", _SimpleBot({"p2": "b"}), merge_strategy="ignore-all-space" ) # No sync(merge) commit since no changes happened between. @@ -85,14 +88,15 @@ def test_generate_draft_merge(self) -> None: assert self._fs.read("p1") == "a" assert self._fs.read("p2") == "b" - def test_generate_draft_merge_no_conflict(self) -> None: + @pytest.mark.asyncio + async def test_generate_draft_merge_no_conflict(self) -> None: self._fs.write("p1", "a") def update(_goal: Goal) -> str: self._fs.write("p2", "b") return "A" - self._drafter.generate_draft( + await self._drafter.generate_draft( "hello", _SimpleBot({"p1": update}), merge_strategy="ignore-all-space", @@ -101,21 +105,23 @@ def update(_goal: Goal) -> str: assert self._fs.read("p1") == "A" assert self._fs.read("p2") == "b" - def test_generate_draft_merge_theirs(self) -> None: + @pytest.mark.asyncio + async def test_generate_draft_merge_theirs(self) -> None: self._fs.write("p1", "a") def update(_goal: Goal) -> str: self._fs.write("p1", "b") return "A" - self._drafter.generate_draft( + await self._drafter.generate_draft( "hello", _SimpleBot({"p1": update}), merge_strategy="theirs" ) # sync(merge) commit here since p1 was updated separately. assert len(self._commits()) == 5 # init, sync, prompt, sync, merge assert self._fs.read("p1") == "A" - def test_generate_draft_merge_conflict(self) -> None: + @pytest.mark.asyncio + async def test_generate_draft_merge_conflict(self) -> None: self._fs.write("p1", "a") def update(_goal: Goal) -> str: @@ -123,48 +129,57 @@ def update(_goal: Goal) -> str: return "A" with pytest.raises(GitError): - self._drafter.generate_draft( + await self._drafter.generate_draft( "hello", _SimpleBot({"p1": update}), merge_strategy="ignore-all-space", ) - def test_generate_outside_branch(self) -> None: + @pytest.mark.asyncio + async def test_generate_outside_branch(self) -> None: self._repo.git("checkout", "--detach") with pytest.raises(RuntimeError): - self._drafter.generate_draft("ok", _SimpleBot.noop()) + await self._drafter.generate_draft("ok", _SimpleBot.noop()) - def test_generate_empty_prompt(self) -> None: + @pytest.mark.asyncio + async def test_generate_empty_prompt(self) -> None: with pytest.raises(ValueError): - self._drafter.generate_draft("", _SimpleBot.noop()) + await self._drafter.generate_draft("", _SimpleBot.noop()) - def test_generate_reuse_branch(self) -> None: + @pytest.mark.asyncio + async def test_generate_reuse_branch(self) -> None: bot = _SimpleBot({"prompt": lambda goal: goal.prompt}) - self._drafter.generate_draft("prompt1", bot, "theirs") - self._drafter.generate_draft("prompt2", bot, "theirs") + await self._drafter.generate_draft("prompt1", bot, "theirs") + await self._drafter.generate_draft("prompt2", bot, "theirs") assert self._fs.read("prompt") == "prompt2" - def test_delete_unknown_file(self) -> None: - self._drafter.generate_draft("hello", _SimpleBot({"p1": None})) + @pytest.mark.asyncio + async def test_delete_unknown_file(self) -> None: + await self._drafter.generate_draft("hello", _SimpleBot({"p1": None})) - def test_quit_keeps_changes(self) -> None: + @pytest.mark.asyncio + async def test_quit_keeps_changes(self) -> None: self._fs.write("p1.txt", "a1") - self._drafter.generate_draft("hello", _SimpleBot.prompt(), "theirs") + await self._drafter.generate_draft( + "hello", _SimpleBot.prompt(), "theirs" + ) self._fs.write("p1.txt", "a2") self._drafter.quit_folio() assert self._fs.read("p1.txt") == "a2" assert self._fs.read("PROMPT") == "hello" - def test_latest_draft_prompt(self) -> None: + @pytest.mark.asyncio + async def test_latest_draft_prompt(self) -> None: bot = _SimpleBot.noop() prompt1 = "First prompt" - self._drafter.generate_draft(prompt1, bot) + await self._drafter.generate_draft(prompt1, bot) assert self._drafter.latest_draft_prompt() == prompt1 prompt2 = "Second prompt" - self._drafter.generate_draft(prompt2, bot) + await self._drafter.generate_draft(prompt2, bot) assert self._drafter.latest_draft_prompt() == prompt2 - def test_latest_draft_prompt_no_active_branch(self) -> None: + @pytest.mark.asyncio + async def test_latest_draft_prompt_no_active_branch(self) -> None: assert self._drafter.latest_draft_prompt() is None