diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f24a4ea7..61bfb6476 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TTY_INTERACTIVE` environment variable to force interactive mode off or on https://github.com/Textualize/rich/pull/3777 +- Allowed custom spinner animations throughout the library https://github.com/Textualize/rich/pull/3791 ## [14.0.0] - 2025-03-30 diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4b04786b9..a0cae1476 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -94,3 +94,4 @@ The following people have contributed to the development of Rich: - [Jonathan Helmus](https://github.com/jjhelmus) - [Brandon Capener](https://github.com/bcapener) - [Alex Zheng](https://github.com/alexzheng111) +- [Maddy Guthridge](https://maddyguthridge.com/) diff --git a/docs/source/console.rst b/docs/source/console.rst index 4d79886c2..5ee698773 100644 --- a/docs/source/console.rst +++ b/docs/source/console.rst @@ -136,6 +136,10 @@ Run the following command to see the available choices for ``spinner``:: python -m rich.spinner +You can use a custom spinner by providing a dictionary with the following properties + +* ``"interval"`` Intended time per frame of spinner +* ``"frames"`` Frames of the spinner. If this is a single ``str``, each character is a single frame. If a ``list[str]`` is given, each list element is a single frame. Justify / Alignment ------------------- diff --git a/rich/_spinners.py b/rich/_spinners.py index d0bb1fe75..4e5275cef 100644 --- a/rich/_spinners.py +++ b/rich/_spinners.py @@ -19,7 +19,20 @@ IN THE SOFTWARE. """ -SPINNERS = { +from typing import TypedDict, List, Dict, Union + + +class SpinnerAnimation(TypedDict): + interval: float + """Intended time per frame, in milliseconds""" + frames: Union[List[str], str] + """ + Frames of this spinner. If a single `str`, each character is a single + frame. If a `list[str]`, each list element is a single frame. + """ + + +SPINNERS: Dict[str, SpinnerAnimation] = { "dots": { "interval": 80, "frames": "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏", diff --git a/rich/console.py b/rich/console.py index 7d4d4a8f6..034a606e3 100644 --- a/rich/console.py +++ b/rich/console.py @@ -36,6 +36,7 @@ ) from rich._null_file import NULL_FILE +from rich._spinners import SpinnerAnimation from . import errors, themes from ._emoji_replace import _emoji_replace @@ -1163,7 +1164,7 @@ def status( self, status: RenderableType, *, - spinner: str = "dots", + spinner: Union[str, SpinnerAnimation] = "dots", spinner_style: StyleType = "status.spinner", speed: float = 1.0, refresh_per_second: float = 12.5, @@ -2181,9 +2182,9 @@ def export_text(self, *, clear: bool = True, styles: bool = False) -> str: str: String containing console contents. """ - assert ( - self.record - ), "To export console contents set record=True in the constructor or instance" + assert self.record, ( + "To export console contents set record=True in the constructor or instance" + ) with self._record_buffer_lock: if styles: @@ -2237,9 +2238,9 @@ def export_html( Returns: str: String containing console contents as HTML. """ - assert ( - self.record - ), "To export console contents set record=True in the constructor or instance" + assert self.record, ( + "To export console contents set record=True in the constructor or instance" + ) fragments: List[str] = [] append = fragments.append _theme = theme or DEFAULT_TERMINAL_THEME diff --git a/rich/progress.py b/rich/progress.py index ef6ad60f0..cfd16008e 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -46,7 +46,7 @@ from .jupyter import JupyterMixin from .live import Live from .progress_bar import ProgressBar -from .spinner import Spinner +from .spinner import Spinner, SpinnerAnimation from .style import StyleType from .table import Column, Table from .text import Text, TextType @@ -575,7 +575,7 @@ class SpinnerColumn(ProgressColumn): def __init__( self, - spinner_name: str = "dots", + spinner_name: Union[str, SpinnerAnimation] = "dots", style: Optional[StyleType] = "progress.spinner", speed: float = 1.0, finished_text: TextType = " ", @@ -591,7 +591,7 @@ def __init__( def set_spinner( self, - spinner_name: str, + spinner_name: Union[str, SpinnerAnimation], spinner_style: Optional[StyleType] = "progress.spinner", speed: float = 1.0, ) -> None: diff --git a/rich/spinner.py b/rich/spinner.py index a3a3caf84..4e962b218 100644 --- a/rich/spinner.py +++ b/rich/spinner.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union -from ._spinners import SPINNERS +from ._spinners import SPINNERS, SpinnerAnimation from .measure import Measurement from .table import Table from .text import Text @@ -10,11 +10,19 @@ from .style import StyleType +# Explicitly export `SpinnerInfo` to avoid linter annoyances if other people +# want to use our type definition. +__all__ = [ + "Spinner", + "SpinnerAnimation", +] + + class Spinner: """A spinner animation. Args: - name (str): Name of spinner (run python -m rich.spinner). + name (str | SpinnerInfo): Name of spinner (run python -m rich.spinner), or a dict of shape { "interval": float, "frames": str | list[str] } text (RenderableType, optional): A renderable to display at the right of the spinner (str or Text typically). Defaults to "". style (StyleType, optional): Style for spinner animation. Defaults to None. speed (float, optional): Speed factor for animation. Defaults to 1.0. @@ -25,22 +33,26 @@ class Spinner: def __init__( self, - name: str, + name: str | SpinnerAnimation, text: "RenderableType" = "", *, style: Optional["StyleType"] = None, speed: float = 1.0, ) -> None: - try: - spinner = SPINNERS[name] - except KeyError: - raise KeyError(f"no spinner called {name!r}") + if isinstance(name, str): + try: + spinner = SPINNERS[name] + except KeyError: + raise KeyError(f"no spinner called {name!r}") + else: + spinner = name + self.text: "Union[RenderableType, Text]" = ( Text.from_markup(text) if isinstance(text, str) else text ) self.name = name - self.frames = cast(List[str], spinner["frames"])[:] - self.interval = cast(float, spinner["interval"]) + self.frames = spinner["frames"][:] + self.interval = spinner["interval"] self.start_time: Optional[float] = None self.style = style self.speed = speed diff --git a/rich/status.py b/rich/status.py index 65744838e..6d1feb10b 100644 --- a/rich/status.py +++ b/rich/status.py @@ -1,6 +1,7 @@ from types import TracebackType -from typing import Optional, Type +from typing import Optional, Type, Union +from ._spinners import SpinnerAnimation from .console import Console, RenderableType from .jupyter import JupyterMixin from .live import Live @@ -25,7 +26,7 @@ def __init__( status: RenderableType, *, console: Optional[Console] = None, - spinner: str = "dots", + spinner: Union[str, SpinnerAnimation] = "dots", spinner_style: StyleType = "status.spinner", speed: float = 1.0, refresh_per_second: float = 12.5, @@ -54,7 +55,7 @@ def update( self, status: Optional[RenderableType] = None, *, - spinner: Optional[str] = None, + spinner: Union[str, SpinnerAnimation, None] = None, spinner_style: Optional[StyleType] = None, speed: Optional[float] = None, ) -> None: diff --git a/tests/test_spinner.py b/tests/test_spinner.py index efeeb7173..db1baba3c 100644 --- a/tests/test_spinner.py +++ b/tests/test_spinner.py @@ -3,7 +3,7 @@ from rich.console import Console from rich.measure import Measurement from rich.rule import Rule -from rich.spinner import Spinner +from rich.spinner import Spinner, SpinnerAnimation from rich.text import Text @@ -70,3 +70,28 @@ def test_spinner_markup(): spinner = Spinner("dots", "[bold]spinning[/bold]") assert isinstance(spinner.text, Text) assert str(spinner.text) == "spinning" + + +def test_custom_spinner_render(): + custom_spinner: SpinnerAnimation = { + "interval": 80, + "frames": "abcdef", + } + time = 0.0 + + def get_time(): + nonlocal time + return time + + console = Console( + width=80, color_system=None, force_terminal=True, get_time=get_time + ) + console.begin_capture() + spinner = Spinner(custom_spinner, "Foo") + console.print(spinner) + time += 80 / 1000 + console.print(spinner) + result = console.end_capture() + print(repr(result)) + expected = "a Foo\nb Foo\n" + assert result == expected