From b0574de455a01f0f9a87acb05f864a1e5d113221 Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Sat, 1 Mar 2025 07:50:35 -0800 Subject: [PATCH] refactor: modularize assistants --- poetry.lock | 50 ++++++++++----- pyproject.toml | 58 ++++++++++++------ src/git_draft/__init__.py | 10 +-- src/git_draft/__main__.py | 15 ++++- src/git_draft/assistant.py | 59 ------------------ src/git_draft/assistants/__init__.py | 21 +++++++ src/git_draft/assistants/common.py | 21 +++++++ src/git_draft/assistants/openai.py | 92 ++++++++++++++++++++++++++++ src/git_draft/manager.py | 12 +++- tests/git_draft/assistant_test.py | 5 -- tests/git_draft/manager_test.py | 4 +- 11 files changed, 233 insertions(+), 114 deletions(-) delete mode 100644 src/git_draft/assistant.py create mode 100644 src/git_draft/assistants/__init__.py create mode 100644 src/git_draft/assistants/common.py create mode 100644 src/git_draft/assistants/openai.py delete mode 100644 tests/git_draft/assistant_test.py diff --git a/poetry.lock b/poetry.lock index d68c14c..866d925 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4,9 +4,10 @@ name = "annotated-types" version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -16,9 +17,10 @@ files = [ name = "anyio" version = "4.8.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" -optional = false +optional = true python-versions = ">=3.9" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a"}, {file = "anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a"}, @@ -83,9 +85,10 @@ uvloop = ["uvloop (>=0.15.2)"] name = "certifi" version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." -optional = false +optional = true python-versions = ">=3.6" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -117,15 +120,16 @@ files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} +markers = {main = "extra == \"openai\" and platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "distro" version = "1.9.0" description = "Distro - an OS platform information API" -optional = false +optional = true python-versions = ">=3.6" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, @@ -203,9 +207,10 @@ test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3. name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -optional = false +optional = true python-versions = ">=3.7" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -215,9 +220,10 @@ files = [ name = "httpcore" version = "1.0.7" description = "A minimal low-level HTTP client." -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd"}, {file = "httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c"}, @@ -237,9 +243,10 @@ trio = ["trio (>=0.22.0,<1.0)"] name = "httpx" version = "0.28.1" description = "The next generation HTTP client." -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, @@ -262,9 +269,10 @@ zstd = ["zstandard (>=0.18.0)"] name = "idna" version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" -optional = false +optional = true python-versions = ">=3.6" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -289,9 +297,10 @@ files = [ name = "jiter" version = "0.8.2" description = "Fast iterable JSON parser." -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"}, {file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"}, @@ -452,9 +461,10 @@ files = [ name = "openai" version = "1.64.0" description = "The official Python library for the openai API" -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "openai-1.64.0-py3-none-any.whl", hash = "sha256:20f85cde9e95e9fbb416e3cb5a6d3119c0b28308afd6e3cc47bf100623dac623"}, {file = "openai-1.64.0.tar.gz", hash = "sha256:2861053538704d61340da56e2f176853d19f1dc5704bc306b7597155f850d57a"}, @@ -578,9 +588,10 @@ files = [ name = "pydantic" version = "2.10.6" description = "Data validation using Python type hints" -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584"}, {file = "pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236"}, @@ -599,9 +610,10 @@ timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows name = "pydantic-core" version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" -optional = false +optional = true python-versions = ">=3.8" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -757,9 +769,10 @@ files = [ name = "sniffio" version = "1.3.1" description = "Sniff out which async library your code is running under" -optional = false +optional = true python-versions = ">=3.7" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -811,9 +824,10 @@ files = [ name = "tqdm" version = "4.67.1" description = "Fast, Extensible Progress Meter" -optional = false +optional = true python-versions = ">=3.7" groups = ["main"] +markers = "extra == \"openai\"" files = [ {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, @@ -840,8 +854,12 @@ files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +markers = {main = "extra == \"openai\""} + +[extras] +openai = ["openai"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "98ba96ddfc998c9c99546fd9c3e3ce358527be48efa24a1228b19b4658fdc195" +content-hash = "97ac3b11fb233e092f7340da6e498136fd8356d8e05d48b3dd84fccad7bf5c91" diff --git a/pyproject.toml b/pyproject.toml index 0f1b3e5..bdfcaef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,22 +1,36 @@ +[project] +name = 'git-draft' +description = 'Version-controlled code assistant' +authors = [{name='Matthieu Monsch', email='mtth@apache.org'}] +license = 'MIT' +readme = 'README.md' +dynamic = ['version'] +requires-python = '>=3.12' +dependencies = [ + 'gitpython >=3.1.44,<4', +] + +[project.optional-dependencies] +openai = ['openai >=1.64.0,<2'] + +[project.scripts] +git-draft = 'git_draft.__main__:main' + +[project.urls] +repository = 'https://github.com/mtth/git-draft' +documentation = 'https://mtth.github.io/git-draft' + [build-system] requires = ['poetry-core'] build-backend = 'poetry.core.masonry.api' +# Poetry + [tool.poetry] -name = 'git-draft' version = '0.0.0' # Set programmatically -description = 'Version-controlled code assistant' -authors = ['Matthieu Monsch '] -readme = 'README.md' -repository = 'https://github.com/mtth/git-draft' -packages = [{include = 'git_draft', from = 'src'}] - -[tool.poetry.scripts] -git-draft = 'git_draft.__main__:main' +packages = [{include='git_draft', from='src'}] [tool.poetry.dependencies] -gitpython = '^3.1.44' -openai = '^1.64.0' python = '>=3.12,<4' [tool.poetry.group.dev.dependencies] @@ -27,15 +41,7 @@ mypy = '^1.2.0' poethepoet = '^0.25.0' pytest = '^7.1.2' -[tool.black] -line-length = 79 -include = '\.py$' - -[tool.flake8] -ignore = ['E203', 'E501', 'E704', 'W503'] - -[tool.mypy] -disable_error_code = 'import-untyped' +# Poe [tool.poe.tasks.fix] help = 'format source code' @@ -55,5 +61,17 @@ args = [ {name='args', help='target folders', positional=true, multiple=true, default='src tests'}, ] +# Other tools + +[tool.black] +line-length = 79 +include = '\.py$' + +[tool.flake8] +ignore = ['E203', 'E501', 'E704', 'W503'] + +[tool.mypy] +disable_error_code = 'import-untyped' + [tool.pytest.ini_options] log_level = 'DEBUG' diff --git a/src/git_draft/__init__.py b/src/git_draft/__init__.py index ad90e65..c094f06 100644 --- a/src/git_draft/__init__.py +++ b/src/git_draft/__init__.py @@ -1,11 +1,7 @@ -from .assistant import Assistant, OpenAIAssistant -from .common import open_editor -from .manager import Manager, enclosing_repo +from .assistants import Assistant, Session, Toolbox __all__ = [ "Assistant", - "OpenAIAssistant", - "Manager", - "enclosing_repo", - "open_editor", + "Session", + "Toolbox", ] diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 71d73d0..fadfcb6 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -5,7 +5,9 @@ import sys import textwrap -from . import Manager, OpenAIAssistant, enclosing_repo, open_editor +from .assistants import load_assistant +from .common import open_editor +from .manager import Manager, enclosing_repo EPILOG = """\ @@ -39,6 +41,12 @@ def callback(_option, _opt, _value, parser) -> None: add_command("finalize", help="apply the current draft to the original branch") add_command("generate", help="draft a new change from a prompt") +parser.add_option( + "-a", + "--assistant", + dest="ASSISTANT", + help="assistant key", +) parser.add_option( "-d", "--delete", @@ -65,20 +73,21 @@ def callback(_option, _opt, _value, parser) -> None: def main() -> None: - (opts, args) = parser.parse_args() + (opts, _args) = parser.parse_args() repo = enclosing_repo() manager = Manager(repo) command = getattr(opts, "command", "generate") if command == "generate": + assistant = load_assistant(opts.assistant, {}) prompt = opts.prompt if not prompt: if sys.stdin.isatty(): prompt = open_editor(textwrap.dedent(EDITOR_PLACEHOLDER)) else: prompt = sys.stdin.read() - manager.generate_draft(prompt, OpenAIAssistant(), reset=opts.reset) + manager.generate_draft(prompt, assistant, reset=opts.reset) elif command == "finalize": manager.finalize_draft(delete=opts.delete) elif command == "discard": diff --git a/src/git_draft/assistant.py b/src/git_draft/assistant.py deleted file mode 100644 index 3557984..0000000 --- a/src/git_draft/assistant.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import dataclasses -import openai -from pathlib import PurePosixPath -import textwrap -from typing import Protocol, Sequence - - -class Toolbox(Protocol): - def list_files(self) -> Sequence[PurePosixPath]: ... - def read_file(self, path: PurePosixPath) -> str: ... - def write_file(self, path: PurePosixPath, data: str) -> None: ... - - -@dataclasses.dataclass(frozen=True) -class Session: - token_count: int - calls: list[Call] - - -@dataclasses.dataclass(frozen=True) -class Call: - usage: openai.types.CompletionUsage | None - - -class Assistant: - def run(self, prompt: str, toolbox: Toolbox) -> Session: - raise NotImplementedError() - - -# https://aider.chat/docs/more-info.html -# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py -_SYSTEM_PROMPT = textwrap.dedent( - """ - You are an expert software engineer, who writes correct and concise code. -""" -) - - -class OpenAIAssistant(Assistant): - def __init__(self) -> None: - self._client = openai.OpenAI() - - def run(self, prompt: str, toolbox: Toolbox) -> Session: - # TODO: Switch to the thread run API, using tools to leverage toolbox - # methods. - # https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps - # https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py - completion = self._client.chat.completions.create( - messages=[ - {"role": "system", "content": _SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - model="gpt-4o", - ) - content = completion.choices[0].message.content or "" - toolbox.write_file(PurePosixPath(f"{completion.id}.txt"), content) - return Session(0, calls=[Call(completion.usage)]) diff --git a/src/git_draft/assistants/__init__.py b/src/git_draft/assistants/__init__.py new file mode 100644 index 0000000..05474d7 --- /dev/null +++ b/src/git_draft/assistants/__init__.py @@ -0,0 +1,21 @@ +from typing import Any, Mapping + +from .common import Assistant, Session, Toolbox + +__all__ = [ + "Assistant", + "Session", + "Toolbox", +] + + +def load_assistant(entry: str, kwargs: Mapping[str, Any]) -> Assistant: + if entry == "openai": + return _load_openai_assistant(**kwargs) + raise NotImplementedError() # TODO + + +def _load_openai_assistant(**kwargs) -> Assistant: + from .openai import OpenAIAssistant + + return OpenAIAssistant(**kwargs) diff --git a/src/git_draft/assistants/common.py b/src/git_draft/assistants/common.py new file mode 100644 index 0000000..3ea68c6 --- /dev/null +++ b/src/git_draft/assistants/common.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import dataclasses +from pathlib import PurePosixPath +from typing import Protocol, Sequence + + +class Toolbox(Protocol): + def list_files(self) -> Sequence[PurePosixPath]: ... + def read_file(self, path: PurePosixPath) -> str: ... + def write_file(self, path: PurePosixPath, data: str) -> None: ... + + +@dataclasses.dataclass(frozen=True) +class Session: + token_count: int + + +class Assistant: + def run(self, prompt: str, toolbox: Toolbox) -> Session: + raise NotImplementedError() diff --git a/src/git_draft/assistants/openai.py b/src/git_draft/assistants/openai.py new file mode 100644 index 0000000..c4ccc48 --- /dev/null +++ b/src/git_draft/assistants/openai.py @@ -0,0 +1,92 @@ +import openai +from pathlib import PurePosixPath + +from .common import Assistant, Session, Toolbox + + +# https://aider.chat/docs/more-info.html +# https://github.com/Aider-AI/aider/blob/main/aider/prompts.py +_INSTRUCTIONS = """\ + You are an expert software engineer, who writes correct and concise code. +""" + +_tools = [ # TODO + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Get the current temperature for a specific location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g., San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location.", + }, + }, + "required": ["location", "unit"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_rain_probability", + "description": "Get the probability of rain for a specific location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g., San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + }, +] + + +class OpenAIAssistant(Assistant): + """An OpenAI-backed assistant + + See the following links for resources: + + * https://platform.openai.com/docs/assistants/tools/function-calling + * https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps + * https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py + """ + + def __init__(self) -> None: + self._client = openai.OpenAI() + + def run(self, prompt: str, toolbox: Toolbox) -> Session: + # TODO: Switch to the thread run API, using tools to leverage toolbox + # methods. + # assistant = client.beta.assistants.create( + # instructions=_INSTRUCTIONS, + # model="gpt-4o", + # tools=_tools, + # ) + # thread = client.beta.threads.create() + # message = client.beta.threads.messages.create( + # thread_id=thread.id, + # role="user", + # content="What's the weather in San Francisco today and the likelihood it'll rain?", + # ) + completion = self._client.chat.completions.create( + messages=[ + {"role": "system", "content": _INSTRUCTIONS}, + {"role": "user", "content": prompt}, + ], + model="gpt-4o", + ) + content = completion.choices[0].message.content or "" + toolbox.write_file(PurePosixPath(f"{completion.id}.txt"), content) + return Session(0) diff --git a/src/git_draft/manager.py b/src/git_draft/manager.py index 41c2835..cbcd232 100644 --- a/src/git_draft/manager.py +++ b/src/git_draft/manager.py @@ -8,7 +8,7 @@ import tempfile from typing import Callable, ClassVar, Match, Self, Sequence -from .assistant import Assistant +from .assistants import Assistant, Toolbox def enclosing_repo(path: str | None = None) -> git.Repo: @@ -110,7 +110,13 @@ def active(cls, repo: git.Repo) -> _Branch | None: return _Branch(init_shortsha, init_note) -class _Toolbox: +class _Toolbox(Toolbox): + """Git-index backed toolbox + + All files are directly read from and written to the index. This allows + concurrent editing without interference. + """ + def __init__(self, repo: git.Repo) -> None: self._repo = repo @@ -136,6 +142,8 @@ def write_file(self, path: PurePosixPath, data: str) -> None: class Manager: + """Draft state manager""" + def __init__(self, repo: git.Repo) -> None: self._repo = repo diff --git a/tests/git_draft/assistant_test.py b/tests/git_draft/assistant_test.py deleted file mode 100644 index 2526843..0000000 --- a/tests/git_draft/assistant_test.py +++ /dev/null @@ -1,5 +0,0 @@ -import git_draft.assistant as sut - - -def test_assistant(): - assert sut.Assistant() diff --git a/tests/git_draft/manager_test.py b/tests/git_draft/manager_test.py index 373fb9c..475679b 100644 --- a/tests/git_draft/manager_test.py +++ b/tests/git_draft/manager_test.py @@ -5,7 +5,7 @@ import tempfile from typing import Iterator -from git_draft.assistant import Assistant, Session, Toolbox +from git_draft.assistants import Assistant, Session, Toolbox import git_draft.manager as sut @@ -49,7 +49,7 @@ def test_write_multiple(self, repo: git.Repo) -> None: class _FakeAssistant(Assistant): def run(self, prompt: str, toolbox: Toolbox) -> Session: toolbox.write_file(PurePosixPath("PROMPT"), prompt) - return Session(len(prompt), []) + return Session(len(prompt)) class TestManager: