Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .common import (
PROGRAM,
Config,
Feedback,
Progress,
UnreachableError,
ensure_state_home,
)
Expand Down Expand Up @@ -167,9 +167,9 @@ async def run() -> None: # noqa: PLR0912 PLR0915
datefmt="%m-%d %H:%M",
)

feedback = Feedback.dynamic() if sys.stdin.isatty() else Feedback.static()
progress = Progress.dynamic() if sys.stdin.isatty() else Progress.static()
repo = Repo.enclosing(Path(opts.root) if opts.root else Path.cwd())
drafter = Drafter.create(repo, Store.persistent(), feedback)
drafter = Drafter.create(repo, Store.persistent(), progress)
match getattr(opts, "command", "new"):
case "new":
bot_config = None
Expand Down
46 changes: 23 additions & 23 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,32 @@ def _tagged(text: str, /, **kwargs) -> str:
return f"{text} [{', '.join(tags)}]" if tags else text


class Feedback:
"""User feedback interface"""
class Progress:
"""Progress feedback interface"""

def report(self, text: str, **tags) -> None: # pragma: no cover
raise NotImplementedError()

def spinner(
self, text: str, **tags
) -> contextlib.AbstractContextManager[
FeedbackSpinner
ProgressSpinner
]: # pragma: no cover
raise NotImplementedError()

@staticmethod
def dynamic() -> Feedback:
"""Feedback suitable for interactive terminals"""
return _DynamicFeedback()
def dynamic() -> Progress:
"""Progress suitable for interactive terminals"""
return _DynamicProgress()

@staticmethod
def static() -> Feedback:
"""Feedback suitable for pipes, etc."""
return _StaticFeedback()
def static() -> Progress:
"""Progress suitable for pipes, etc."""
return _StaticProgress()


class FeedbackSpinner:
"""Operation feedback tracker"""
class ProgressSpinner:
"""Operation progress tracker"""

@contextlib.contextmanager
def hidden(self) -> Iterator[None]:
Expand All @@ -179,9 +179,9 @@ def update(self, text: str, **tags) -> None: # pragma: no cover
raise NotImplementedError()


class _DynamicFeedback(Feedback):
class _DynamicProgress(Progress):
def __init__(self) -> None:
self._spinner: _DynamicFeedbackSpinner | None = None
self._spinner: _DynamicProgressSpinner | None = None

def report(self, text: str, **tags) -> None:
message = f"☞ {_tagged(text, **tags)}"
Expand All @@ -191,10 +191,10 @@ def report(self, text: str, **tags) -> None:
print(message) # noqa

@contextlib.contextmanager
def spinner(self, text: str, **tags) -> Iterator[FeedbackSpinner]:
def spinner(self, text: str, **tags) -> Iterator[ProgressSpinner]:
assert not self._spinner
with yaspin.yaspin(text=_tagged(text, **tags)) as spinner:
self._spinner = _DynamicFeedbackSpinner(spinner)
self._spinner = _DynamicProgressSpinner(spinner)
try:
yield self._spinner
except Exception:
Expand All @@ -206,7 +206,7 @@ def spinner(self, text: str, **tags) -> Iterator[FeedbackSpinner]:
self._spinner = None


class _DynamicFeedbackSpinner(FeedbackSpinner):
class _DynamicProgressSpinner(ProgressSpinner):
def __init__(self, yaspin: yaspin.core.Yaspin) -> None:
self.yaspin = yaspin

Expand All @@ -219,19 +219,19 @@ def update(self, text: str, **tags) -> None:
self.yaspin.text = _tagged(text, **tags)


class _StaticFeedback(Feedback):
class _StaticProgress(Progress):
def report(self, text: str, **tags) -> None:
print(_tagged(text, **tags)) # noqa

@contextlib.contextmanager
def spinner(self, text: str, **tags) -> Iterator[FeedbackSpinner]:
def spinner(self, text: str, **tags) -> Iterator[ProgressSpinner]:
self.report(text, **tags)
yield _StaticFeedbackSpinner(self)
yield _StaticProgressSpinner(self)


class _StaticFeedbackSpinner(FeedbackSpinner):
def __init__(self, feedback: _StaticFeedback) -> None:
self._feedback = feedback
class _StaticProgressSpinner(ProgressSpinner):
def __init__(self, progress: _StaticProgress) -> None:
self._progress = progress

def update(self, text: str, **tags) -> None:
self._feedback.report(text, **tags)
self._progress.report(text, **tags)
40 changes: 20 additions & 20 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Literal

from .bots import Action, Bot, Goal
from .common import Feedback, JSONObject, qualified_class_name, reindent
from .common import JSONObject, Progress, qualified_class_name, reindent
from .git import SHA, Repo
from .prompt import TemplatedPrompt
from .store import Store, sql
Expand Down Expand Up @@ -89,16 +89,16 @@ def _active_folio(repo: Repo) -> Folio | None:
class Drafter:
"""Draft state orchestrator"""

def __init__(self, store: Store, repo: Repo, feedback: Feedback) -> None:
def __init__(self, store: Store, repo: Repo, progress: Progress) -> None:
self._store = store
self._repo = repo
self._feedback = feedback
self._progress = progress

@classmethod
def create(cls, repo: Repo, store: Store, feedback: Feedback) -> Drafter:
def create(cls, repo: Repo, store: Store, progress: Progress) -> Drafter:
with store.cursor() as cursor:
cursor.executescript(sql("create-tables"))
return cls(store, repo, feedback)
return cls(store, repo, progress)

async def generate_draft(
self,
Expand All @@ -107,7 +107,7 @@ async def generate_draft(
merge_strategy: DraftMergeStrategy | None = None,
prompt_transform: Callable[[str], str] | None = None,
) -> Draft:
with self._feedback.spinner("Preparing prompt...") as spinner:
with self._progress.spinner("Preparing prompt...") as spinner:
# Handle prompt templating and editing. We do this first in case
# this fails, to avoid creating unnecessary branches.
toolbox, dirty = RepoToolbox.for_working_dir(self._repo)
Expand Down Expand Up @@ -141,8 +141,8 @@ async def generate_draft(
)

# Run the bot to generate the change.
operation_recorder = _OperationRecorder(self._feedback)
with self._feedback.spinner("Running bot...") as spinner:
operation_recorder = _OperationRecorder(self._progress)
with self._progress.spinner("Running bot...") as spinner:
change = await self._generate_change(
bot,
Goal(prompt_contents),
Expand All @@ -151,7 +151,7 @@ async def generate_draft(
),
)
if change.action.question:
self._feedback.report("Requested feedback.")
self._progress.report("Requested progress.")
spinner.update(
"Completed bot run.",
runtime=round(change.walltime.total_seconds(), 1),
Expand All @@ -167,7 +167,7 @@ async def generate_draft(
walltime=change.walltime,
token_count=change.action.token_count,
)
with self._feedback.spinner("Creating draft commit...") as spinner:
with self._progress.spinner("Creating draft commit...") as spinner:
if dirty:
parent_commit_rev = self._commit_tree(
toolbox.tree_sha(), "HEAD", "sync(prompt)"
Expand Down Expand Up @@ -214,7 +214,7 @@ async def generate_draft(
_logger.info("Created new draft in folio %s.", folio.id)

if merge_strategy:
with self._feedback.spinner("Merging changes...") as spinner:
with self._progress.spinner("Merging changes...") as spinner:
if parent_commit_rev != "HEAD":
# If there was a sync(prompt) commit, we move forward to
# it. This will avoid conflicts with earlier changes.
Expand Down Expand Up @@ -260,7 +260,7 @@ def quit_folio(self) -> None:
if check_call.code:
raise RuntimeError("Origin branch diverged, please rebase first")

with self._feedback.spinner("Switching branch...") as spinner:
with self._progress.spinner("Switching branch...") as spinner:
# Create a reference to the current state for later analysis.
self._sync_head("finalize")
self._repo.git("update-ref", _draft_ref(folio.id, "@"), "HEAD")
Expand All @@ -287,7 +287,7 @@ def quit_folio(self) -> None:
_logger.info("Quit %s.", folio)

def _create_folio(self) -> Folio:
with self._feedback.spinner("Creating draft branch...") as spinner:
with self._progress.spinner("Creating draft branch...") as spinner:
origin_branch = self._repo.active_branch()
if origin_branch is None:
raise RuntimeError("No currently active branch")
Expand Down Expand Up @@ -436,33 +436,33 @@ class _OperationRecorder(ToolVisitor):
analysis.
"""

def __init__(self, feedback: Feedback) -> None:
def __init__(self, progress: Progress) -> None:
self.operations = list[_Operation]()
self._feedback = feedback
self._progress = progress

def on_list_files(
self, paths: Sequence[PurePosixPath], reason: str | None
) -> None:
count = len(paths)
self._feedback.report("Listed available files.", count=count)
self._progress.report("Listed available files.", count=count)
self._record(reason, "list_files", count=count)

def on_read_file(
self, path: PurePosixPath, contents: str | None, reason: str | None
) -> None:
size = -1 if contents is None else len(contents)
self._feedback.report(f"Read {path}.", length=size)
self._progress.report(f"Read {path}.", length=size)
self._record(reason, "read_file", path=str(path), size=size)

def on_write_file(
self, path: PurePosixPath, contents: str, reason: str | None
) -> None:
size = len(contents)
self._feedback.report(f"Wrote {path}.", length=size)
self._progress.report(f"Wrote {path}.", length=size)
self._record(reason, "write_file", path=str(path), size=size)

def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
self._feedback.report(f"Deleted {path}.")
self._progress.report(f"Deleted {path}.")
self._record(reason, "delete_file", path=str(path))

def on_rename_file(
Expand All @@ -471,7 +471,7 @@ def on_rename_file(
dst_path: PurePosixPath,
reason: str | None,
) -> None:
self._feedback.report(f"Renamed {src_path} to {dst_path}.")
self._progress.report(f"Renamed {src_path} to {dst_path}.")
self._record(
reason,
"rename_file",
Expand Down
4 changes: 1 addition & 3 deletions src/git_draft/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@ def _load_prompt(
assert env.loader, "No loader in environment"
template = env.loader.load(env, str(rel_path))
context: _Context = dict(
program=name,
prompt=_load_layouts(),
toolbox=toolbox
program=name, prompt=_load_layouts(), toolbox=toolbox
)
try:
module = template.make_module(vars=cast(dict, context))
Expand Down
4 changes: 2 additions & 2 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from git_draft.bots import Action, Bot, Goal, Toolbox
from git_draft.common import Feedback
from git_draft.common import Progress
import git_draft.drafter as sut
from git_draft.git import SHA, GitError, Repo
from git_draft.store import Store
Expand Down Expand Up @@ -46,7 +46,7 @@ def setup(self, repo: Repo, repo_fs: RepoFS) -> None:
self._repo = repo
self._fs = repo_fs
self._drafter = sut.Drafter.create(
repo, Store.in_memory(), Feedback.static()
repo, Store.in_memory(), Progress.static()
)

def _commits(self, ref: str | None = None) -> Sequence[SHA]:
Expand Down