diff --git a/.github/codeql.yml b/.github/codeql.yml index 7603545ab6..d883676d10 100644 --- a/.github/codeql.yml +++ b/.github/codeql.yml @@ -10,7 +10,11 @@ query-filters: - exclude: id: py/missing-call-to-init - exclude: - id: py/method-first-arg-is-not-self + id: py/method-first-arg-is-not-self + - exclude: + id: py/cyclic-import + - exclude: + id: py/unsafe-cyclic-import paths: - manim paths-ignore: diff --git a/.github/scripts/ci_build_cairo.py b/.github/scripts/ci_build_cairo.py index 469db7b452..be50793a9e 100644 --- a/.github/scripts/ci_build_cairo.py +++ b/.github/scripts/ci_build_cairo.py @@ -14,8 +14,8 @@ import sys import tarfile import tempfile -import typing import urllib.request +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path from sys import stdout @@ -67,7 +67,7 @@ def run_command(command, cwd=None, env=None): @contextmanager -def gha_group(title: str) -> typing.Generator: +def gha_group(title: str) -> Generator: if not is_ci(): yield return diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d59b4661d..0d094abdee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: python-version: ${{ matrix.python }} - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 with: enable-cache: true @@ -54,10 +54,13 @@ jobs: - name: Install Texlive (Linux) if: runner.os == 'Linux' - uses: teatimeguest/setup-texlive-action@v3 + uses: zauguin/install-texlive@v4 with: - cache: true - packages: scheme-basic fontspec inputenc fontenc tipa mathrsfs calligra xcolor standalone preview doublestroke ms everysel setspace rsfs relsize ragged2e fundus-calligra microtype wasysym physics dvisvgm jknapltx wasy cm-super babel-english gnu-freefont mathastext cbfonts-fd xetex + packages: > + scheme-basic latex fontspec tipa calligra xcolor + standalone preview doublestroke setspace rsfs relsize + ragged2e fundus-calligra microtype wasysym physics dvisvgm jknapltx + wasy cm-super babel-english gnu-freefont mathastext cbfonts-fd xetex - name: Start virtual display (Linux) if: runner.os == 'Linux' diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index b74c25b597..e8e0d92de5 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -21,7 +21,7 @@ jobs: python-version: 3.13 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 - name: Build and push release to PyPI run: | diff --git a/.github/workflows/release-publish-documentation.yml b/.github/workflows/release-publish-documentation.yml index 3983a69800..ab90ded0c3 100644 --- a/.github/workflows/release-publish-documentation.yml +++ b/.github/workflows/release-publish-documentation.yml @@ -17,7 +17,7 @@ jobs: python-version: 3.13 - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 - name: Install system dependencies run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fdc48e8371..81dcc7cf24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: check-toml name: Validate pyproject.toml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.0 + rev: v0.12.7 hooks: - id: ruff name: ruff lint @@ -22,7 +22,7 @@ repos: - id: ruff-format types: [python] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.17.1 hooks: - id: mypy additional_dependencies: diff --git a/docs/source/conf.py b/docs/source/conf.py index 821d58d852..bd55b2d341 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -156,6 +156,7 @@ # This specifies any additional css files that will override the theme's html_css_files = ["custom.css"] +latex_engine = "lualatex" # external links extlinks = { diff --git a/docs/source/contributing/docs/types.rst b/docs/source/contributing/docs/types.rst index 840d29e6c0..5457f5c126 100644 --- a/docs/source/contributing/docs/types.rst +++ b/docs/source/contributing/docs/types.rst @@ -85,14 +85,8 @@ typed as a :class:`~.Point3D`, because it represents a direction along which to shift a :class:`~.Mobject`, not a position in space. As a general rule, if a parameter is called ``direction`` or ``axis``, -it should be type hinted as some form of :class:`~.VectorND`. - -.. warning:: - - This is not always true. For example, as of Manim 0.18.0, the direction - parameter of the :class:`.Vector` Mobject should be - ``Point2DLike | Point3DLike``, as it can also accept ``tuple[float, float]`` - and ``tuple[float, float, float]``. +it should be type hinted as some form of :class:`~.VectorND` or +:class:`~.VectorNDLike`. Colors ------ diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 00b4299a52..6584daadd8 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -299,7 +299,7 @@ Animations path.become(previous_path) path.add_updater(update_path) self.add(path, dot) - self.play(Rotating(dot, radians=PI, about_point=RIGHT, run_time=2)) + self.play(Rotating(dot, angle=PI, about_point=RIGHT, run_time=2)) self.wait() self.play(dot.animate.shift(UP)) self.play(dot.animate.shift(LEFT)) diff --git a/docs/source/reference_index/utilities_misc.rst b/docs/source/reference_index/utilities_misc.rst index 874a20ef86..bda1cf4961 100644 --- a/docs/source/reference_index/utilities_misc.rst +++ b/docs/source/reference_index/utilities_misc.rst @@ -15,6 +15,7 @@ Module Index ~utils.commands ~utils.config_ops constants + data_structures ~utils.debug ~utils.deprecation ~utils.docbuild diff --git a/example_scenes/advanced_tex_fonts.py b/example_scenes/advanced_tex_fonts.py index d8d7486ff9..a5ff47b50b 100644 --- a/example_scenes/advanced_tex_fonts.py +++ b/example_scenes/advanced_tex_fonts.py @@ -52,7 +52,7 @@ class TexFontTemplateLibrary(Scene): Many of the in the TexFontTemplates collection require that specific fonts are installed on your local machine. For example, choosing the template TexFontTemplates.comic_sans will - not compile if the Comic Sans Micrososft font is not installed. + not compile if the Comic Sans Microsoft font is not installed. This scene will only render those Templates that do not cause a TeX compilation error on your system. Furthermore, some of the ones that do render, diff --git a/manim/__init__.py b/manim/__init__.py index a4034ed134..0605d4a3ae 100644 --- a/manim/__init__.py +++ b/manim/__init__.py @@ -1,9 +1,16 @@ #!/usr/bin/env python from __future__ import annotations -from importlib.metadata import version +from importlib.metadata import PackageNotFoundError, version -__version__ = version(__name__) +# Use installed distribution version if available; otherwise fall back to a +# sensible default so that importing from a source checkout works without an +# editable install (pip install -e .). +try: + __version__ = version(__name__) +except PackageNotFoundError: + # Package is not installed; provide a fallback version string. + __version__ = "0.0.0+unknown" # isort: off diff --git a/manim/_config/__init__.py b/manim/_config/__init__.py index 3eed54b481..2d3883d227 100644 --- a/manim/_config/__init__.py +++ b/manim/_config/__init__.py @@ -23,10 +23,9 @@ parser = make_config_parser() -# The logger can be accessed from anywhere as manim.logger, or as -# logging.getLogger("manim"). The console must be accessed as manim.console. -# Throughout the codebase, use manim.console.print() instead of print(). -# Use error_console to print errors so that it outputs to stderr. +# Logger usage: accessible globally as `manim.logger` or via `logging.getLogger("manim")`. +# For printing, use `manim.console.print()` instead of the built-in `print()`. +# For error output, use `error_console`, which prints to stderr. logger, console, error_console = make_logger( parser["logger"], parser["CLI"]["verbosity"], @@ -45,7 +44,7 @@ # This has to go here because it needs access to this module's config @contextmanager def tempconfig(temp: ManimConfig | dict[str, Any]) -> Generator[None, None, None]: - """Context manager that temporarily modifies the global ``config`` object. + """Temporarily modifies the global ``config`` object using a context manager. Inside the ``with`` statement, the modified config will be used. After context manager exits, the config will be restored to its original state. diff --git a/manim/_config/cli_colors.py b/manim/_config/cli_colors.py index 5b1d151bdb..e62428a75a 100644 --- a/manim/_config/cli_colors.py +++ b/manim/_config/cli_colors.py @@ -1,3 +1,9 @@ +"""Parses CLI context settings from the configuration file and returns a Cloup Context settings dictionary. + +This module reads configuration values for help formatting, theme styles, and alignment options +used when rendering command-line interfaces in Manim. +""" + from __future__ import annotations import configparser @@ -9,7 +15,7 @@ def parse_cli_ctx(parser: configparser.SectionProxy) -> dict[str, Any]: - formatter_settings: dict[str, str | int] = { + formatter_settings: dict[str, str | int | None] = { "indent_increment": int(parser["indent_increment"]), "width": int(parser["width"]), "col1_max_width": int(parser["col1_max_width"]), @@ -28,6 +34,7 @@ def parse_cli_ctx(parser: configparser.SectionProxy) -> dict[str, Any]: "col2", "epilog", } + # Extract and apply any style-related keys defined in the config section. for k, v in parser.items(): if k in theme_keys and v: theme_settings.update({k: Style(v)}) @@ -37,22 +44,24 @@ def parse_cli_ctx(parser: configparser.SectionProxy) -> dict[str, Any]: if theme is None: formatter = HelpFormatter.settings( theme=HelpTheme(**theme_settings), - **formatter_settings, # type: ignore[arg-type] + **formatter_settings, ) elif theme.lower() == "dark": formatter = HelpFormatter.settings( theme=HelpTheme.dark().with_(**theme_settings), - **formatter_settings, # type: ignore[arg-type] + **formatter_settings, ) elif theme.lower() == "light": formatter = HelpFormatter.settings( theme=HelpTheme.light().with_(**theme_settings), - **formatter_settings, # type: ignore[arg-type] + **formatter_settings, ) - return Context.settings( + return_val: dict[str, Any] = Context.settings( align_option_groups=parser["align_option_groups"].lower() == "true", align_sections=parser["align_sections"].lower() == "true", show_constraints=True, formatter_settings=formatter, ) + + return return_val diff --git a/manim/_config/utils.py b/manim/_config/utils.py index f109af6939..31290f802a 100644 --- a/manim/_config/utils.py +++ b/manim/_config/utils.py @@ -122,10 +122,14 @@ def make_config_parser( # read_file() before calling read() for any optional files." # https://docs.python.org/3/library/configparser.html#configparser.ConfigParser.read parser = configparser.ConfigParser() + logger.info(f"Reading config file: {library_wide}") with library_wide.open() as file: parser.read_file(file) # necessary file other_files = [user_wide, Path(custom_file) if custom_file else folder_wide] + for path in other_files: + if path.exists(): + logger.info(f"Reading config file: {path}") parser.read(other_files) # optional files return parser @@ -1414,7 +1418,7 @@ def window_position(self, value: str) -> None: @property def window_size(self) -> str: - """The size of the opengl window. 'default' to automatically scale the window based on the display monitor.""" + """The size of the opengl window as 'width,height' or 'default' to automatically scale the window based on the display monitor.""" return self._d["window_size"] @window_size.setter diff --git a/manim/animation/animation.py b/manim/animation/animation.py index 93a304a87e..5de7605f8f 100644 --- a/manim/animation/animation.py +++ b/manim/animation/animation.py @@ -14,10 +14,10 @@ __all__ = ["Animation", "Wait", "Add", "override_animation"] -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from functools import partialmethod -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -120,7 +120,7 @@ def __new__( if func is not None: anim = func(mobject, *args, **kwargs) logger.debug( - f"The {cls.__name__} animation has been is overridden for " + f"The {cls.__name__} animation has been overridden for " f"{type(mobject).__name__} mobjects. use_override = False can " f" be used as keyword argument to prevent animation overriding.", ) @@ -129,7 +129,7 @@ def __new__( def __init__( self, - mobject: Mobject | None, + mobject: Mobject | OpenGLMobject | None, lag_ratio: float = DEFAULT_ANIMATION_LAG_RATIO, run_time: float = DEFAULT_ANIMATION_RUN_TIME, rate_func: Callable[[float], float] = smooth, @@ -140,7 +140,7 @@ def __init__( introducer: bool = False, *, _on_finish: Callable[[], None] = lambda _: None, - **kwargs, + use_override: bool = True, # included here to avoid TypeError if passed from a subclass' constructor ) -> None: self._typecheck_input(mobject) self.run_time: float = run_time @@ -160,8 +160,6 @@ def __init__( else: self.starting_mobject: Mobject = Mobject() self.mobject: Mobject = mobject if mobject is not None else Mobject() - if kwargs: - logger.debug("Animation received extra kwargs: %s", kwargs) if hasattr(self, "CONFIG"): logger.error( @@ -264,11 +262,11 @@ def _setup_scene(self, scene: Scene) -> None: ): scene.add(self.mobject) - def create_starting_mobject(self) -> Mobject: + def create_starting_mobject(self) -> Mobject | OpenGLMobject: # Keep track of where the mobject starts return self.mobject.copy() - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> Sequence[Mobject | OpenGLMobject]: """Get all mobjects involved in the animation. Ordering must match the ordering of arguments to interpolate_submobject @@ -498,6 +496,8 @@ def __init_subclass__(cls, **kwargs) -> None: cls._original__init__ = cls.__init__ + _original__init__ = __init__ # needed if set_default() is called with no kwargs directly from Animation + @classmethod def set_default(cls, **kwargs) -> None: """Sets the default values of keyword arguments. @@ -540,7 +540,7 @@ def construct(self): def prepare_animation( - anim: Animation | mobject._AnimationBuilder, + anim: Animation | mobject._AnimationBuilder | opengl_mobject._AnimationBuilder, ) -> Animation: r"""Returns either an unchanged animation, or the animation built from a passed animation factory. diff --git a/manim/animation/changing.py b/manim/animation/changing.py index bb11cfc0a4..9b5617f157 100644 --- a/manim/animation/changing.py +++ b/manim/animation/changing.py @@ -4,8 +4,12 @@ __all__ = ["AnimatedBoundary", "TracedPath"] -from typing import Callable +from collections.abc import Callable, Sequence +from typing import Any +from typing_extensions import Self + +from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.types.vectorized_mobject import VGroup, VMobject from manim.utils.color import ( @@ -16,7 +20,7 @@ WHITE, ParsableManimColor, ) -from manim.utils.rate_functions import smooth +from manim.utils.rate_functions import RateFunction, smooth class AnimatedBoundary(VGroup): @@ -38,14 +42,14 @@ def construct(self): def __init__( self, - vmobject, - colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], - max_stroke_width=3, - cycle_rate=0.5, - back_and_forth=True, - draw_rate_func=smooth, - fade_rate_func=smooth, - **kwargs, + vmobject: VMobject, + colors: Sequence[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], + max_stroke_width: float = 3, + cycle_rate: float = 0.5, + back_and_forth: bool = True, + draw_rate_func: RateFunction = smooth, + fade_rate_func: RateFunction = smooth, + **kwargs: Any, ): super().__init__(**kwargs) self.colors = colors @@ -59,10 +63,10 @@ def __init__( vmobject.copy().set_style(stroke_width=0, fill_opacity=0) for x in range(2) ] self.add(*self.boundary_copies) - self.total_time = 0 + self.total_time = 0.0 self.add_updater(lambda m, dt: self.update_boundary_copies(dt)) - def update_boundary_copies(self, dt): + def update_boundary_copies(self, dt: float) -> None: # Not actual time, but something which passes at # an altered rate to make the implementation below # cleaner @@ -78,9 +82,9 @@ def update_boundary_copies(self, dt): fade_alpha = self.fade_rate_func(alpha) if self.back_and_forth and int(time) % 2 == 1: - bounds = (1 - draw_alpha, 1) + bounds = (1.0 - draw_alpha, 1.0) else: - bounds = (0, draw_alpha) + bounds = (0.0, draw_alpha) self.full_family_become_partial(growing, vmobject, *bounds) growing.set_stroke(colors[index], width=msw) @@ -90,7 +94,9 @@ def update_boundary_copies(self, dt): self.total_time += dt - def full_family_become_partial(self, mob1, mob2, a, b): + def full_family_become_partial( + self, mob1: VMobject, mob2: VMobject, a: float, b: float + ) -> Self: family1 = mob1.family_members_with_points() family2 = mob2.family_members_with_points() for sm1, sm2 in zip(family1, family2): @@ -146,20 +152,21 @@ def __init__( stroke_width: float = 2, stroke_color: ParsableManimColor | None = WHITE, dissipating_time: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(stroke_color=stroke_color, stroke_width=stroke_width, **kwargs) self.traced_point_func = traced_point_func self.dissipating_time = dissipating_time - self.time = 1 if self.dissipating_time else None + self.time = 1.0 if self.dissipating_time else None self.add_updater(self.update_path) - def update_path(self, mob, dt): + def update_path(self, mob: Mobject, dt: float) -> None: new_point = self.traced_point_func() if not self.has_points(): self.start_new_path(new_point) self.add_line_to(new_point) if self.dissipating_time: + assert self.time is not None self.time += dt if self.time - 1 > self.dissipating_time: nppcc = self.n_points_per_curve diff --git a/manim/animation/composition.py b/manim/animation/composition.py index 128066ba80..a9da6f8902 100644 --- a/manim/animation/composition.py +++ b/manim/animation/composition.py @@ -2,9 +2,8 @@ from __future__ import annotations -import types -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np @@ -12,7 +11,7 @@ from manim.animation.animation import Animation, prepare_animation from manim.constants import RendererType from manim.mobject.mobject import Group, Mobject -from manim.mobject.opengl.opengl_mobject import OpenGLGroup +from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject from manim.scene.scene import Scene from manim.utils.iterables import remove_list_redundancies from manim.utils.parameter_parsing import flatten_iterable_parameters @@ -54,31 +53,34 @@ class AnimationGroup(Animation): def __init__( self, - *animations: Animation | Iterable[Animation] | types.GeneratorType[Animation], - group: Group | VGroup | OpenGLGroup | OpenGLVGroup = None, + *animations: Animation | Iterable[Animation], + group: Group | VGroup | OpenGLGroup | OpenGLVGroup | None = None, run_time: float | None = None, rate_func: Callable[[float], float] = linear, lag_ratio: float = 0, - **kwargs, - ) -> None: + **kwargs: Any, + ): arg_anim = flatten_iterable_parameters(animations) self.animations = [prepare_animation(anim) for anim in arg_anim] self.rate_func = rate_func - self.group = group - if self.group is None: + if group is None: mobjects = remove_list_redundancies( [anim.mobject for anim in self.animations if not anim.is_introducer()], ) if config["renderer"] == RendererType.OPENGL: - self.group = OpenGLGroup(*mobjects) + self.group: Group | VGroup | OpenGLGroup | OpenGLVGroup = OpenGLGroup( + *mobjects + ) else: self.group = Group(*mobjects) + else: + self.group = group super().__init__( self.group, rate_func=self.rate_func, lag_ratio=lag_ratio, **kwargs ) self.run_time: float = self.init_run_time(run_time) - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> Sequence[Mobject | OpenGLMobject]: return list(self.group) def begin(self) -> None: @@ -93,7 +95,7 @@ def begin(self) -> None: for anim in self.animations: anim.begin() - def _setup_scene(self, scene) -> None: + def _setup_scene(self, scene: Scene) -> None: for anim in self.animations: anim._setup_scene(scene) @@ -118,7 +120,7 @@ def update_mobjects(self, dt: float) -> None: ]: anim.update_mobjects(dt) - def init_run_time(self, run_time) -> float: + def init_run_time(self, run_time: float | None) -> float: """Calculates the run time of the animation, if different from ``run_time``. Parameters @@ -146,9 +148,9 @@ def build_animations_with_timings(self) -> None: run_times = np.array([anim.run_time for anim in self.animations]) num_animations = run_times.shape[0] dtype = [("anim", "O"), ("start", "f8"), ("end", "f8")] - self.anims_with_timings = np.zeros(num_animations, dtype=dtype) - self.anims_begun = np.zeros(num_animations, dtype=bool) - self.anims_finished = np.zeros(num_animations, dtype=bool) + self.anims_with_timings: np.ndarray = np.zeros(num_animations, dtype=dtype) + self.anims_begun: np.ndarray = np.zeros(num_animations, dtype=bool) + self.anims_finished: np.ndarray = np.zeros(num_animations, dtype=bool) if num_animations == 0: return @@ -228,7 +230,7 @@ def construct(self): )) """ - def __init__(self, *animations: Animation, lag_ratio: float = 1, **kwargs) -> None: + def __init__(self, *animations: Animation, lag_ratio: float = 1, **kwargs: Any): super().__init__(*animations, lag_ratio=lag_ratio, **kwargs) def begin(self) -> None: @@ -247,7 +249,7 @@ def update_mobjects(self, dt: float) -> None: if self.active_animation: self.active_animation.update_mobjects(dt) - def _setup_scene(self, scene) -> None: + def _setup_scene(self, scene: Scene | None) -> None: if scene is None: return if self.is_introducer(): @@ -339,7 +341,7 @@ def __init__( self, *animations: Animation, lag_ratio: float = DEFAULT_LAGGED_START_LAG_RATIO, - **kwargs, + **kwargs: Any, ): super().__init__(*animations, lag_ratio=lag_ratio, **kwargs) @@ -384,20 +386,22 @@ def construct(self): def __init__( self, - AnimationClass: Callable[..., Animation], + animation_class: type[Animation], mobject: Mobject, - arg_creator: Callable[[Mobject], str] = None, + arg_creator: Callable[[Mobject], Iterable[Any]] | None = None, run_time: float = 2, - **kwargs, - ) -> None: - args_list = [] - for submob in mobject: - if arg_creator: - args_list.append(arg_creator(submob)) - else: - args_list.append((submob,)) + **kwargs: Any, + ): + if arg_creator is None: + + def identity(mob: Mobject) -> Mobject: + return mob + + arg_creator = identity + + args_list = [arg_creator(submob) for submob in mobject] anim_kwargs = dict(kwargs) if "lag_ratio" in anim_kwargs: anim_kwargs.pop("lag_ratio") - animations = [AnimationClass(*args, **anim_kwargs) for args in args_list] + animations = [animation_class(*args, **anim_kwargs) for args in args_list] super().__init__(*animations, run_time=run_time, **kwargs) diff --git a/manim/animation/creation.py b/manim/animation/creation.py index dc3ec69527..654100707b 100644 --- a/manim/animation/creation.py +++ b/manim/animation/creation.py @@ -76,8 +76,8 @@ def construct(self): import itertools as it -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np @@ -120,7 +120,7 @@ def __init__( ): pointwise = getattr(mobject, "pointwise_become_partial", None) if not callable(pointwise): - raise NotImplementedError("This animation is not defined for this Mobject.") + raise TypeError(f"{self.__class__.__name__} only works for VMobjects.") super().__init__(mobject, **kwargs) def interpolate_submobject( @@ -133,7 +133,7 @@ def interpolate_submobject( starting_submobject, *self._get_bounds(alpha) ) - def _get_bounds(self, alpha: float) -> None: + def _get_bounds(self, alpha: float) -> tuple[float, float]: raise NotImplementedError("Please use Create or ShowPassingFlash") @@ -173,7 +173,7 @@ def __init__( ) -> None: super().__init__(mobject, lag_ratio=lag_ratio, introducer=introducer, **kwargs) - def _get_bounds(self, alpha: float) -> tuple[int, float]: + def _get_bounds(self, alpha: float) -> tuple[float, float]: return (0, alpha) @@ -229,8 +229,6 @@ def __init__( rate_func: Callable[[float], float] = double_smooth, stroke_width: float = 2, stroke_color: str = None, - draw_border_animation_config: dict = {}, # what does this dict accept? - fill_animation_config: dict = {}, introducer: bool = True, **kwargs, ) -> None: @@ -244,8 +242,6 @@ def __init__( ) self.stroke_width = stroke_width self.stroke_color = stroke_color - self.draw_border_animation_config = draw_border_animation_config - self.fill_animation_config = fill_animation_config self.outline = self.get_outline() def _typecheck_input(self, vmobject: VMobject | OpenGLVMobject) -> None: diff --git a/manim/animation/fading.py b/manim/animation/fading.py index 79cd41a516..480ecc8fc7 100644 --- a/manim/animation/fading.py +++ b/manim/animation/fading.py @@ -19,6 +19,8 @@ def construct(self): "FadeIn", ] +from typing import Any + import numpy as np from manim.mobject.opengl.opengl_mobject import OpenGLMobject @@ -53,7 +55,7 @@ def __init__( shift: np.ndarray | None = None, target_position: np.ndarray | Mobject | None = None, scale: float = 1, - **kwargs, + **kwargs: Any, ) -> None: if not mobjects: raise ValueError("At least one mobject must be passed.") @@ -85,7 +87,7 @@ def _create_faded_mobject(self, fadeIn: bool) -> Mobject: Mobject The faded, shifted and scaled copy of the mobject. """ - faded_mobject = self.mobject.copy() + faded_mobject: Mobject = self.mobject.copy() # type: ignore[assignment] faded_mobject.fade(1) direction_modifier = -1 if fadeIn and not self.point_target else 1 faded_mobject.shift(self.shift_vector * direction_modifier) @@ -131,13 +133,13 @@ def construct(self): """ - def __init__(self, *mobjects: Mobject, **kwargs) -> None: + def __init__(self, *mobjects: Mobject, **kwargs: Any) -> None: super().__init__(*mobjects, introducer=True, **kwargs) - def create_target(self): - return self.mobject + def create_target(self) -> Mobject: + return self.mobject # type: ignore[return-value] - def create_starting_mobject(self): + def create_starting_mobject(self) -> Mobject: return self._create_faded_mobject(fadeIn=True) @@ -179,12 +181,12 @@ def construct(self): """ - def __init__(self, *mobjects: Mobject, **kwargs) -> None: + def __init__(self, *mobjects: Mobject, **kwargs: Any) -> None: super().__init__(*mobjects, remover=True, **kwargs) - def create_target(self): + def create_target(self) -> Mobject: return self._create_faded_mobject(fadeIn=False) - def clean_up_from_scene(self, scene: Scene = None) -> None: + def clean_up_from_scene(self, scene: Scene) -> None: super().clean_up_from_scene(scene) self.interpolate(0) diff --git a/manim/animation/growing.py b/manim/animation/growing.py index d9f526c136..889de79fc0 100644 --- a/manim/animation/growing.py +++ b/manim/animation/growing.py @@ -31,16 +31,17 @@ def construct(self): "SpinInFromNothing", ] -import typing - -import numpy as np +from typing import TYPE_CHECKING, Any from ..animation.transform import Transform from ..constants import PI from ..utils.paths import spiral_path -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from manim.mobject.geometry.line import Arrow + from manim.mobject.opengl.opengl_mobject import OpenGLMobject + from manim.typing import Point3DLike, Vector3DLike + from manim.utils.color import ParsableManimColor from ..mobject.mobject import Mobject @@ -76,16 +77,20 @@ def construct(self): """ def __init__( - self, mobject: Mobject, point: np.ndarray, point_color: str = None, **kwargs - ) -> None: + self, + mobject: Mobject, + point: Point3DLike, + point_color: ParsableManimColor | None = None, + **kwargs: Any, + ): self.point = point self.point_color = point_color super().__init__(mobject, introducer=True, **kwargs) - def create_target(self) -> Mobject: + def create_target(self) -> Mobject | OpenGLMobject: return self.mobject - def create_starting_mobject(self) -> Mobject: + def create_starting_mobject(self) -> Mobject | OpenGLMobject: start = super().create_starting_mobject() start.scale(0) start.move_to(self.point) @@ -118,7 +123,12 @@ def construct(self): """ - def __init__(self, mobject: Mobject, point_color: str = None, **kwargs) -> None: + def __init__( + self, + mobject: Mobject, + point_color: ParsableManimColor | None = None, + **kwargs: Any, + ): point = mobject.get_center() super().__init__(mobject, point, point_color=point_color, **kwargs) @@ -153,8 +163,12 @@ def construct(self): """ def __init__( - self, mobject: Mobject, edge: np.ndarray, point_color: str = None, **kwargs - ) -> None: + self, + mobject: Mobject, + edge: Vector3DLike, + point_color: ParsableManimColor | None = None, + **kwargs: Any, + ): point = mobject.get_critical_point(edge) super().__init__(mobject, point, point_color=point_color, **kwargs) @@ -183,11 +197,13 @@ def construct(self): """ - def __init__(self, arrow: Arrow, point_color: str = None, **kwargs) -> None: + def __init__( + self, arrow: Arrow, point_color: ParsableManimColor | None = None, **kwargs: Any + ): point = arrow.get_start() super().__init__(arrow, point, point_color=point_color, **kwargs) - def create_starting_mobject(self) -> Mobject: + def create_starting_mobject(self) -> Mobject | OpenGLMobject: start_arrow = self.mobject.copy() start_arrow.scale(0, scale_tips=True, about_point=self.point) if self.point_color: @@ -224,8 +240,12 @@ def construct(self): """ def __init__( - self, mobject: Mobject, angle: float = PI / 2, point_color: str = None, **kwargs - ) -> None: + self, + mobject: Mobject, + angle: float = PI / 2, + point_color: ParsableManimColor | None = None, + **kwargs: Any, + ): self.angle = angle super().__init__( mobject, path_func=spiral_path(angle), point_color=point_color, **kwargs diff --git a/manim/animation/indication.py b/manim/animation/indication.py index f931491b37..6008573119 100644 --- a/manim/animation/indication.py +++ b/manim/animation/indication.py @@ -40,14 +40,16 @@ def construct(self): ] from collections.abc import Iterable -from typing import Callable +from typing import Any import numpy as np +from typing_extensions import Self from manim.mobject.geometry.arc import Circle, Dot from manim.mobject.geometry.line import Line from manim.mobject.geometry.polygram import Rectangle from manim.mobject.geometry.shape_matchers import SurroundingRectangle +from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.scene.scene import Scene from .. import config @@ -61,9 +63,10 @@ def construct(self): from ..constants import * from ..mobject.mobject import Mobject from ..mobject.types.vectorized_mobject import VGroup, VMobject +from ..typing import Point3D, Point3DLike, Vector3DLike from ..utils.bezier import interpolate, inverse_interpolate from ..utils.color import GREY, YELLOW, ParsableManimColor -from ..utils.rate_functions import smooth, there_and_back, wiggle +from ..utils.rate_functions import RateFunction, smooth, there_and_back, wiggle from ..utils.space_ops import normalize @@ -95,12 +98,12 @@ def construct(self): def __init__( self, - focus_point: np.ndarray | Mobject, + focus_point: Point3DLike | Mobject, opacity: float = 0.2, - color: str = GREY, + color: ParsableManimColor = GREY, run_time: float = 2, - **kwargs, - ) -> None: + **kwargs: Any, + ): self.focus_point = focus_point self.color = color self.opacity = opacity @@ -151,15 +154,15 @@ def __init__( self, mobject: Mobject, scale_factor: float = 1.2, - color: str = YELLOW, - rate_func: Callable[[float, float | None], np.ndarray] = there_and_back, - **kwargs, - ) -> None: + color: ParsableManimColor = YELLOW, + rate_func: RateFunction = there_and_back, + **kwargs: Any, + ): self.color = color self.scale_factor = scale_factor super().__init__(mobject, rate_func=rate_func, **kwargs) - def create_target(self) -> Mobject: + def create_target(self) -> Mobject | OpenGLMobject: target = self.mobject.copy() target.scale(self.scale_factor) target.set_color(self.color) @@ -219,20 +222,20 @@ def construct(self): def __init__( self, - point: np.ndarray | Mobject, + point: Point3DLike | Mobject, line_length: float = 0.2, num_lines: int = 12, flash_radius: float = 0.1, line_stroke_width: int = 3, - color: str = YELLOW, + color: ParsableManimColor = YELLOW, time_width: float = 1, run_time: float = 1.0, - **kwargs, - ) -> None: + **kwargs: Any, + ): if isinstance(point, Mobject): - self.point = point.get_center() + self.point: Point3D = point.get_center() else: - self.point = point + self.point = np.asarray(point) self.color = color self.line_length = line_length self.num_lines = num_lines @@ -303,11 +306,13 @@ def construct(self): """ - def __init__(self, mobject: VMobject, time_width: float = 0.1, **kwargs) -> None: + def __init__( + self, mobject: VMobject, time_width: float = 0.1, **kwargs: Any + ) -> None: self.time_width = time_width super().__init__(mobject, remover=True, introducer=True, **kwargs) - def _get_bounds(self, alpha: float) -> tuple[float]: + def _get_bounds(self, alpha: float) -> tuple[float, float]: tw = self.time_width upper = interpolate(0, 1 + tw, alpha) lower = upper - tw @@ -322,7 +327,14 @@ def clean_up_from_scene(self, scene: Scene) -> None: class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup): - def __init__(self, vmobject, n_segments=10, time_width=0.1, remover=True, **kwargs): + def __init__( + self, + vmobject: VMobject, + n_segments: int = 10, + time_width: float = 0.1, + remover: bool = True, + **kwargs: Any, + ): self.n_segments = n_segments self.time_width = time_width self.remover = remover @@ -389,19 +401,19 @@ def construct(self): def __init__( self, mobject: Mobject, - direction: np.ndarray = UP, + direction: Vector3DLike = UP, amplitude: float = 0.2, - wave_func: Callable[[float], float] = smooth, + wave_func: RateFunction = smooth, time_width: float = 1, ripples: int = 1, run_time: float = 2, - **kwargs, - ) -> None: + **kwargs: Any, + ): x_min = mobject.get_left()[0] x_max = mobject.get_right()[0] vect = amplitude * normalize(direction) - def wave(t): + def wave(t: float) -> float: # Creates a wave with n ripples from a simple rate_func # This wave is build up as follows: # The time is split into 2*ripples phases. In every phase the amplitude @@ -467,7 +479,8 @@ def homotopy( relative_x = inverse_interpolate(x_min, x_max, x) wave_phase = inverse_interpolate(lower, upper, relative_x) nudge = wave(wave_phase) * vect - return np.array([x, y, z]) + nudge + return_value: tuple[float, float, float] = np.array([x, y, z]) + nudge + return return_value super().__init__(homotopy, mobject, run_time=run_time, **kwargs) @@ -511,24 +524,28 @@ def __init__( scale_value: float = 1.1, rotation_angle: float = 0.01 * TAU, n_wiggles: int = 6, - scale_about_point: np.ndarray | None = None, - rotate_about_point: np.ndarray | None = None, + scale_about_point: Point3DLike | None = None, + rotate_about_point: Point3DLike | None = None, run_time: float = 2, - **kwargs, - ) -> None: + **kwargs: Any, + ): self.scale_value = scale_value self.rotation_angle = rotation_angle self.n_wiggles = n_wiggles self.scale_about_point = scale_about_point + if scale_about_point is not None: + self.scale_about_point = np.array(scale_about_point) self.rotate_about_point = rotate_about_point + if rotate_about_point is not None: + self.rotate_about_point = np.array(rotate_about_point) super().__init__(mobject, run_time=run_time, **kwargs) - def get_scale_about_point(self) -> np.ndarray: + def get_scale_about_point(self) -> Point3D: if self.scale_about_point is None: return self.mobject.get_center() return self.scale_about_point - def get_rotate_about_point(self) -> np.ndarray: + def get_rotate_about_point(self) -> Point3D: if self.rotate_about_point is None: return self.mobject.get_center() return self.rotate_about_point @@ -538,7 +555,7 @@ def interpolate_submobject( submobject: Mobject, starting_submobject: Mobject, alpha: float, - ) -> None: + ) -> Self: submobject.points[:, :] = starting_submobject.points submobject.scale( interpolate(1, self.scale_value, there_and_back(alpha)), @@ -548,6 +565,7 @@ def interpolate_submobject( wiggle(alpha, self.n_wiggles) * self.rotation_angle, about_point=self.get_rotate_about_point(), ) + return self class Circumscribe(Succession): @@ -595,18 +613,18 @@ def construct(self): def __init__( self, mobject: Mobject, - shape: type = Rectangle, - fade_in=False, - fade_out=False, - time_width=0.3, + shape: type[Rectangle] | type[Circle] = Rectangle, + fade_in: bool = False, + fade_out: bool = False, + time_width: float = 0.3, buff: float = SMALL_BUFF, color: ParsableManimColor = YELLOW, - run_time=1, - stroke_width=DEFAULT_STROKE_WIDTH, - **kwargs, + run_time: float = 1, + stroke_width: float = DEFAULT_STROKE_WIDTH, + **kwargs: Any, ): if shape is Rectangle: - frame = SurroundingRectangle( + frame: SurroundingRectangle | Circle = SurroundingRectangle( mobject, color=color, buff=buff, @@ -685,7 +703,7 @@ def __init__( time_off: float = 0.5, blinks: int = 1, hide_at_end: bool = False, - **kwargs, + **kwargs: Any, ): animations = [ UpdateFromFunc( diff --git a/manim/animation/movement.py b/manim/animation/movement.py index b9b185db06..c0c79d4aef 100644 --- a/manim/animation/movement.py +++ b/manim/animation/movement.py @@ -18,7 +18,13 @@ from ..utils.rate_functions import linear if TYPE_CHECKING: - from ..mobject.mobject import Mobject, VMobject + from typing_extensions import Self + + from manim.mobject.types.vectorized_mobject import VMobject + from manim.typing import MappingFunction, Point3D + from manim.utils.rate_functions import RateFunction + + from ..mobject.mobject import Mobject class Homotopy(Animation): @@ -72,27 +78,33 @@ def __init__( mobject: Mobject, run_time: float = 3, apply_function_kwargs: dict[str, Any] | None = None, - **kwargs, - ) -> None: + **kwargs: Any, + ): self.homotopy = homotopy self.apply_function_kwargs = ( apply_function_kwargs if apply_function_kwargs is not None else {} ) super().__init__(mobject, run_time=run_time, **kwargs) - def function_at_time_t(self, t: float) -> tuple[float, float, float]: - return lambda p: self.homotopy(*p, t) + def function_at_time_t(self, t: float) -> MappingFunction: + def mapping_function(p: Point3D) -> Point3D: + x, y, z = p + return np.array(self.homotopy(x, y, z, t)) + + return mapping_function def interpolate_submobject( self, submobject: Mobject, starting_submobject: Mobject, alpha: float, - ) -> None: + ) -> Self: submobject.points = starting_submobject.points submobject.apply_function( - self.function_at_time_t(alpha), **self.apply_function_kwargs + self.function_at_time_t(alpha), + **self.apply_function_kwargs, ) + return self class SmoothedVectorizedHomotopy(Homotopy): @@ -101,15 +113,20 @@ def interpolate_submobject( submobject: Mobject, starting_submobject: Mobject, alpha: float, - ) -> None: + ) -> Self: + assert isinstance(submobject, VMobject) super().interpolate_submobject(submobject, starting_submobject, alpha) submobject.make_smooth() + return self class ComplexHomotopy(Homotopy): def __init__( - self, complex_homotopy: Callable[[complex], float], mobject: Mobject, **kwargs - ) -> None: + self, + complex_homotopy: Callable[[complex, float], float], + mobject: Mobject, + **kwargs: Any, + ): """Complex Homotopy a function Cx[0, 1] to C""" def homotopy( @@ -131,9 +148,9 @@ def __init__( mobject: Mobject, virtual_time: float = 1, suspend_mobject_updating: bool = False, - rate_func: Callable[[float], float] = linear, - **kwargs, - ) -> None: + rate_func: RateFunction = linear, + **kwargs: Any, + ): self.virtual_time = virtual_time self.function = function super().__init__( @@ -149,7 +166,7 @@ def interpolate_mobject(self, alpha: float) -> None: self.rate_func(alpha) - self.rate_func(self.last_alpha) ) self.mobject.apply_function(lambda p: p + dt * self.function(p)) - self.last_alpha = alpha + self.last_alpha: float = alpha class MoveAlongPath(Animation): @@ -171,9 +188,9 @@ def __init__( self, mobject: Mobject, path: VMobject, - suspend_mobject_updating: bool | None = False, - **kwargs, - ) -> None: + suspend_mobject_updating: bool = False, + **kwargs: Any, + ): self.path = path super().__init__( mobject, suspend_mobject_updating=suspend_mobject_updating, **kwargs diff --git a/manim/animation/numbers.py b/manim/animation/numbers.py index 86bfe7154b..e8724c4d34 100644 --- a/manim/animation/numbers.py +++ b/manim/animation/numbers.py @@ -5,7 +5,8 @@ __all__ = ["ChangingDecimal", "ChangeDecimalToValue"] -import typing +from collections.abc import Callable +from typing import Any from manim.mobject.text.numbers import DecimalNumber @@ -14,12 +15,47 @@ class ChangingDecimal(Animation): + """Animate a :class:`~.DecimalNumber` to values specified by a user-supplied function. + + Parameters + ---------- + decimal_mob + The :class:`~.DecimalNumber` instance to animate. + number_update_func + A function that returns the number to display at each point in the animation. + suspend_mobject_updating + If ``True``, the mobject is not updated outside this animation. + + Raises + ------ + TypeError + If ``decimal_mob`` is not an instance of :class:`~.DecimalNumber`. + + Examples + -------- + + .. manim:: ChangingDecimalExample + + class ChangingDecimalExample(Scene): + def construct(self): + number = DecimalNumber(0) + self.add(number) + self.play( + ChangingDecimal( + number, + lambda a: 5 * a, + run_time=3 + ) + ) + self.wait() + """ + def __init__( self, decimal_mob: DecimalNumber, - number_update_func: typing.Callable[[float], float], - suspend_mobject_updating: bool | None = False, - **kwargs, + number_update_func: Callable[[float], float], + suspend_mobject_updating: bool = False, + **kwargs: Any, ) -> None: self.check_validity_of_input(decimal_mob) self.number_update_func = number_update_func @@ -32,12 +68,34 @@ def check_validity_of_input(self, decimal_mob: DecimalNumber) -> None: raise TypeError("ChangingDecimal can only take in a DecimalNumber") def interpolate_mobject(self, alpha: float) -> None: - self.mobject.set_value(self.number_update_func(self.rate_func(alpha))) + self.mobject.set_value(self.number_update_func(self.rate_func(alpha))) # type: ignore[attr-defined] class ChangeDecimalToValue(ChangingDecimal): + """Animate a :class:`~.DecimalNumber` to a target value using linear interpolation. + + Parameters + ---------- + decimal_mob + The :class:`~.DecimalNumber` instance to animate. + target_number + The target value to transition to. + + Examples + -------- + + .. manim:: ChangeDecimalToValueExample + + class ChangeDecimalToValueExample(Scene): + def construct(self): + number = DecimalNumber(0) + self.add(number) + self.play(ChangeDecimalToValue(number, 10, run_time=3)) + self.wait() + """ + def __init__( - self, decimal_mob: DecimalNumber, target_number: int, **kwargs + self, decimal_mob: DecimalNumber, target_number: int, **kwargs: Any ) -> None: start_number = decimal_mob.number super().__init__( diff --git a/manim/animation/rotation.py b/manim/animation/rotation.py index 7bdd42238a..2673c8cbaf 100644 --- a/manim/animation/rotation.py +++ b/manim/animation/rotation.py @@ -4,8 +4,8 @@ __all__ = ["Rotating", "Rotate"] -from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np @@ -19,19 +19,85 @@ class Rotating(Animation): + """Animation that rotates a Mobject. + + Parameters + ---------- + mobject + The mobject to be rotated. + angle + The rotation angle in radians. Predefined constants such as ``DEGREES`` + can also be used to specify the angle in degrees. + axis + The rotation axis as a numpy vector. + about_point + The rotation center. + about_edge + If ``about_point`` is ``None``, this argument specifies + the direction of the bounding box point to be taken as + the rotation center. + run_time + The duration of the animation in seconds. + rate_func + The function defining the animation progress based on the relative + runtime (see :mod:`~.rate_functions`) . + **kwargs + Additional keyword arguments passed to :class:`~.Animation`. + + Examples + -------- + .. manim:: RotatingDemo + + class RotatingDemo(Scene): + def construct(self): + circle = Circle(radius=1, color=BLUE) + line = Line(start=ORIGIN, end=RIGHT) + arrow = Arrow(start=ORIGIN, end=RIGHT, buff=0, color=GOLD) + vg = VGroup(circle,line,arrow) + self.add(vg) + anim_kw = {"about_point": arrow.get_start(), "run_time": 1} + self.play(Rotating(arrow, 180*DEGREES, **anim_kw)) + self.play(Rotating(arrow, PI, **anim_kw)) + self.play(Rotating(vg, PI, about_point=RIGHT)) + self.play(Rotating(vg, PI, axis=UP, about_point=ORIGIN)) + self.play(Rotating(vg, PI, axis=RIGHT, about_edge=UP)) + self.play(vg.animate.move_to(ORIGIN)) + + .. manim:: RotatingDifferentAxis + + class RotatingDifferentAxis(ThreeDScene): + def construct(self): + axes = ThreeDAxes() + cube = Cube() + arrow2d = Arrow(start=[0, -1.2, 1], end=[0, 1.2, 1], color=YELLOW_E) + cube_group = VGroup(cube,arrow2d) + self.set_camera_orientation(gamma=0, phi=40*DEGREES, theta=40*DEGREES) + self.add(axes, cube_group) + play_kw = {"run_time": 1.5} + self.play(Rotating(cube_group, PI), **play_kw) + self.play(Rotating(cube_group, PI, axis=UP), **play_kw) + self.play(Rotating(cube_group, 180*DEGREES, axis=RIGHT), **play_kw) + self.wait(0.5) + + See also + -------- + :class:`~.Rotate`, :meth:`~.Mobject.rotate` + + """ + def __init__( self, mobject: Mobject, + angle: float = TAU, axis: np.ndarray = OUT, - radians: np.ndarray = TAU, about_point: np.ndarray | None = None, about_edge: np.ndarray | None = None, run_time: float = 5, rate_func: Callable[[float], float] = linear, - **kwargs, + **kwargs: Any, ) -> None: + self.angle = angle self.axis = axis - self.radians = radians self.about_point = about_point self.about_edge = about_edge super().__init__(mobject, run_time=run_time, rate_func=rate_func, **kwargs) @@ -39,7 +105,7 @@ def __init__( def interpolate_mobject(self, alpha: float) -> None: self.mobject.become(self.starting_mobject) self.mobject.rotate( - self.rate_func(alpha) * self.radians, + self.rate_func(alpha) * self.angle, axis=self.axis, about_point=self.about_point, about_edge=self.about_edge, @@ -80,6 +146,10 @@ def construct(self): Rotate(Square(side_length=0.5), angle=2*PI, rate_func=linear), ) + See also + -------- + :class:`~.Rotating`, :meth:`~.Mobject.rotate` + """ def __init__( @@ -89,7 +159,7 @@ def __init__( axis: np.ndarray = OUT, about_point: Sequence[float] | None = None, about_edge: Sequence[float] | None = None, - **kwargs, + **kwargs: Any, ) -> None: if "path_arc" not in kwargs: kwargs["path_arc"] = angle diff --git a/manim/animation/specialized.py b/manim/animation/specialized.py index e5f9e96d96..f39c29fbbb 100644 --- a/manim/animation/specialized.py +++ b/manim/animation/specialized.py @@ -6,6 +6,7 @@ from typing import Any from manim.animation.transform import Restore +from manim.mobject.mobject import Mobject from ..constants import * from .composition import LaggedStart @@ -50,7 +51,7 @@ def construct(self): def __init__( self, - mobject, + mobject: Mobject, focal_point: Sequence[float] = ORIGIN, n_mobs: int = 5, initial_opacity: float = 1, diff --git a/manim/animation/speedmodifier.py b/manim/animation/speedmodifier.py index b8ccea66d1..b762f88b65 100644 --- a/manim/animation/speedmodifier.py +++ b/manim/animation/speedmodifier.py @@ -4,7 +4,8 @@ import inspect import types -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING from numpy import piecewise diff --git a/manim/animation/transform.py b/manim/animation/transform.py index 667208d88a..5bf5b76936 100644 --- a/manim/animation/transform.py +++ b/manim/animation/transform.py @@ -28,11 +28,12 @@ import inspect import types -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np +from manim.data_structures import MethodWithArgs from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject from .. import config @@ -208,7 +209,7 @@ def begin(self) -> None: self.mobject.align_data(self.target_copy) super().begin() - def create_target(self) -> Mobject: + def create_target(self) -> Mobject | OpenGLMobject: # Has no meaningful effect here, but may be useful # in subclasses return self.target_mobject @@ -438,13 +439,13 @@ def check_validity_of_input(self, mobject: Mobject) -> None: class _MethodAnimation(MoveToTarget): - def __init__(self, mobject, methods): + def __init__(self, mobject: Mobject, methods: list[MethodWithArgs]) -> None: self.methods = methods super().__init__(mobject) def finish(self) -> None: - for method, method_args, method_kwargs in self.methods: - method.__func__(self.mobject, *method_args, **method_kwargs) + for item in self.methods: + item.method.__func__(self.mobject, *item.args, **item.kwargs) super().finish() diff --git a/manim/animation/transform_matching_parts.py b/manim/animation/transform_matching_parts.py index dbf5dd294e..03305201f1 100644 --- a/manim/animation/transform_matching_parts.py +++ b/manim/animation/transform_matching_parts.py @@ -96,7 +96,6 @@ def __init__( # target_map transform_source = group_type() transform_target = group_type() - kwargs["final_alpha_value"] = 0 for key in set(source_map).intersection(target_map): transform_source.add(source_map[key]) transform_target.add(target_map[key]) @@ -226,7 +225,8 @@ def get_mobject_key(mobject: Mobject) -> int: mobject.save_state() mobject.center() mobject.set(height=1) - result = hash(np.round(mobject.points, 3).tobytes()) + rounded_points = np.round(mobject.points, 3) + 0.0 + result = hash(rounded_points.tobytes()) mobject.restore() return result diff --git a/manim/animation/updaters/mobject_update_utils.py b/manim/animation/updaters/mobject_update_utils.py index 213180f3bd..a332b44cce 100644 --- a/manim/animation/updaters/mobject_update_utils.py +++ b/manim/animation/updaters/mobject_update_utils.py @@ -15,7 +15,8 @@ import inspect -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING import numpy as np diff --git a/manim/animation/updaters/update.py b/manim/animation/updaters/update.py index ded160cff7..29e636db5d 100644 --- a/manim/animation/updaters/update.py +++ b/manim/animation/updaters/update.py @@ -6,11 +6,12 @@ import operator as op -import typing +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from manim.animation.animation import Animation -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from manim.mobject.mobject import Mobject @@ -24,9 +25,9 @@ class UpdateFromFunc(Animation): def __init__( self, mobject: Mobject, - update_function: typing.Callable[[Mobject], typing.Any], + update_function: Callable[[Mobject], Any], suspend_mobject_updating: bool = False, - **kwargs, + **kwargs: Any, ) -> None: self.update_function = update_function super().__init__( @@ -34,16 +35,18 @@ def __init__( ) def interpolate_mobject(self, alpha: float) -> None: - self.update_function(self.mobject) + self.update_function(self.mobject) # type: ignore[arg-type] class UpdateFromAlphaFunc(UpdateFromFunc): def interpolate_mobject(self, alpha: float) -> None: - self.update_function(self.mobject, self.rate_func(alpha)) + self.update_function(self.mobject, self.rate_func(alpha)) # type: ignore[call-arg, arg-type] class MaintainPositionRelativeTo(Animation): - def __init__(self, mobject: Mobject, tracked_mobject: Mobject, **kwargs) -> None: + def __init__( + self, mobject: Mobject, tracked_mobject: Mobject, **kwargs: Any + ) -> None: self.tracked_mobject = tracked_mobject self.diff = op.sub( mobject.get_center(), diff --git a/manim/camera/camera.py b/manim/camera/camera.py index af5899c5c5..e2137fc858 100644 --- a/manim/camera/camera.py +++ b/manim/camera/camera.py @@ -8,19 +8,21 @@ import itertools as it import operator as op import pathlib -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import reduce -from typing import Any, Callable +from typing import TYPE_CHECKING, Any import cairo import numpy as np from PIL import Image from scipy.spatial.distance import pdist +from typing_extensions import Self + +from manim.typing import MatrixMN, PixelArray, Point3D, Point3D_Array from .. import config, logger from ..constants import * from ..mobject.mobject import Mobject -from ..mobject.types.image_mobject import AbstractImageMobject from ..mobject.types.point_cloud_mobject import PMobject from ..mobject.types.vectorized_mobject import VMobject from ..utils.color import ManimColor, ParsableManimColor, color_to_int_rgba @@ -29,6 +31,10 @@ from ..utils.iterables import list_difference_update from ..utils.space_ops import angle_of_vector +if TYPE_CHECKING: + from ..mobject.types.image_mobject import AbstractImageMobject + + LINE_JOIN_MAP = { LineJointType.AUTO: None, # TODO: this could be improved LineJointType.ROUND: cairo.LineJoin.ROUND, @@ -70,13 +76,13 @@ class Camera: def __init__( self, background_image: str | None = None, - frame_center: np.ndarray = ORIGIN, + frame_center: Point3D = ORIGIN, image_mode: str = "RGBA", n_channels: int = 4, pixel_array_dtype: str = "uint8", cairo_line_width_multiple: float = 0.01, use_z_index: bool = True, - background: np.ndarray | None = None, + background: PixelArray | None = None, pixel_height: int | None = None, pixel_width: int | None = None, frame_height: float | None = None, @@ -84,8 +90,8 @@ def __init__( frame_rate: float | None = None, background_color: ParsableManimColor | None = None, background_opacity: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: self.background_image = background_image self.frame_center = frame_center self.image_mode = image_mode @@ -94,6 +100,9 @@ def __init__( self.cairo_line_width_multiple = cairo_line_width_multiple self.use_z_index = use_z_index self.background = background + self.background_colored_vmobject_displayer: ( + BackgroundColoredVMobjectDisplayer | None + ) = None if pixel_height is None: pixel_height = config["pixel_height"] @@ -116,11 +125,13 @@ def __init__( self.frame_rate = frame_rate if background_color is None: - self._background_color = ManimColor.parse(config["background_color"]) + self._background_color: ManimColor = ManimColor.parse( + config["background_color"] + ) else: self._background_color = ManimColor.parse(background_color) if background_opacity is None: - self._background_opacity = config["background_opacity"] + self._background_opacity: float = config["background_opacity"] else: self._background_opacity = background_opacity @@ -129,7 +140,7 @@ def __init__( self.max_allowable_norm = config["frame_width"] self.rgb_max_val = np.iinfo(self.pixel_array_dtype).max - self.pixel_array_to_cairo_context = {} + self.pixel_array_to_cairo_context: dict[int, cairo.Context] = {} # Contains the correct method to process a list of Mobjects of the # corresponding class. If a Mobject is not an instance of a class in @@ -140,7 +151,7 @@ def __init__( self.resize_frame_shape() self.reset() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> Camera: # This is to address a strange bug where deepcopying # will result in a segfault, which is somehow related # to the aggdraw library @@ -148,24 +159,26 @@ def __deepcopy__(self, memo): return copy.copy(self) @property - def background_color(self): + def background_color(self) -> ManimColor: return self._background_color @background_color.setter - def background_color(self, color): + def background_color(self, color: ManimColor) -> None: self._background_color = color self.init_background() @property - def background_opacity(self): + def background_opacity(self) -> float: return self._background_opacity @background_opacity.setter - def background_opacity(self, alpha): + def background_opacity(self, alpha: float) -> None: self._background_opacity = alpha self.init_background() - def type_or_raise(self, mobject: Mobject): + def type_or_raise( + self, mobject: Mobject + ) -> type[VMobject] | type[PMobject] | type[AbstractImageMobject] | type[Mobject]: """Return the type of mobject, if it is a type that can be rendered. If `mobject` is an instance of a class that inherits from a class that @@ -192,8 +205,12 @@ def type_or_raise(self, mobject: Mobject): :exc:`TypeError` When mobject is not an instance of a class that can be rendered. """ - self.display_funcs = { - VMobject: self.display_multiple_vectorized_mobjects, + from ..mobject.types.image_mobject import AbstractImageMobject + + self.display_funcs: dict[ + type[Mobject], Callable[[list[Mobject], PixelArray], Any] + ] = { + VMobject: self.display_multiple_vectorized_mobjects, # type: ignore[dict-item] PMobject: self.display_multiple_point_cloud_mobjects, AbstractImageMobject: self.display_multiple_image_mobjects, Mobject: lambda batch, pa: batch, # Do nothing @@ -206,7 +223,7 @@ def type_or_raise(self, mobject: Mobject): return _type raise TypeError(f"Displaying an object of class {_type} is not supported") - def reset_pixel_shape(self, new_height: float, new_width: float): + def reset_pixel_shape(self, new_height: float, new_width: float) -> None: """This method resets the height and width of a single pixel to the passed new_height and new_width. @@ -223,7 +240,7 @@ def reset_pixel_shape(self, new_height: float, new_width: float): self.resize_frame_shape() self.reset() - def resize_frame_shape(self, fixed_dimension: int = 0): + def resize_frame_shape(self, fixed_dimension: int = 0) -> None: """ Changes frame_shape to match the aspect ratio of the pixels, where fixed_dimension determines @@ -248,7 +265,7 @@ def resize_frame_shape(self, fixed_dimension: int = 0): self.frame_height = frame_height self.frame_width = frame_width - def init_background(self): + def init_background(self) -> None: """Initialize the background. If self.background_image is the path of an image the image is set as background; else, the default @@ -274,7 +291,9 @@ def init_background(self): ) self.background[:, :] = background_rgba - def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): + def get_image( + self, pixel_array: PixelArray | list | tuple | None = None + ) -> Image.Image: """Returns an image from the passed pixel array, or from the current frame if the passed pixel array is none. @@ -286,7 +305,7 @@ def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): Returns ------- - PIL.Image + PIL.Image.Image The PIL image of the array. """ if pixel_array is None: @@ -294,8 +313,8 @@ def get_image(self, pixel_array: np.ndarray | list | tuple | None = None): return Image.fromarray(pixel_array, mode=self.image_mode) def convert_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> PixelArray: """Converts a pixel array from values that have floats in then to proper RGB values. @@ -321,8 +340,8 @@ def convert_pixel_array( return retval def set_pixel_array( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the pixel array of the camera to the passed pixel array. Parameters @@ -332,19 +351,21 @@ def set_pixel_array( convert_from_floats Whether or not to convert float values to proper RGB values, by default False """ - converted_array = self.convert_pixel_array(pixel_array, convert_from_floats) + converted_array: PixelArray = self.convert_pixel_array( + pixel_array, convert_from_floats + ) if not ( hasattr(self, "pixel_array") and self.pixel_array.shape == converted_array.shape ): - self.pixel_array = converted_array + self.pixel_array: PixelArray = converted_array else: # Set in place self.pixel_array[:, :, :] = converted_array[:, :, :] def set_background( - self, pixel_array: np.ndarray | list | tuple, convert_from_floats: bool = False - ): + self, pixel_array: PixelArray | list | tuple, convert_from_floats: bool = False + ) -> None: """Sets the background to the passed pixel_array after converting to valid RGB values. @@ -360,7 +381,7 @@ def set_background( # TODO, this should live in utils, not as a method of Camera def make_background_from_func( self, coords_to_colors_func: Callable[[np.ndarray], np.ndarray] - ): + ) -> PixelArray: """ Makes a pixel array for the background by using coords_to_colors_func to determine each pixel's color. Each input pixel's color. Each input to coords_to_colors_func is an (x, y) pair in space (in ordinary space coordinates; not @@ -386,7 +407,7 @@ def make_background_from_func( def set_background_from_func( self, coords_to_colors_func: Callable[[np.ndarray], np.ndarray] - ): + ) -> None: """ Sets the background to a pixel array using coords_to_colors_func to determine each pixel's color. Each input pixel's color. Each input to coords_to_colors_func is an (x, y) pair in space (in ordinary space coordinates; not @@ -400,7 +421,7 @@ def set_background_from_func( """ self.set_background(self.make_background_from_func(coords_to_colors_func)) - def reset(self): + def reset(self) -> Self: """Resets the camera's pixel array to that of the background @@ -412,7 +433,7 @@ def reset(self): self.set_pixel_array(self.background) return self - def set_frame_to_background(self, background): + def set_frame_to_background(self, background: PixelArray) -> None: self.set_pixel_array(background) #### @@ -422,7 +443,7 @@ def get_mobjects_to_display( mobjects: Iterable[Mobject], include_submobjects: bool = True, excluded_mobjects: list | None = None, - ): + ) -> list[Mobject]: """Used to get the list of mobjects to display with the camera. @@ -454,7 +475,7 @@ def get_mobjects_to_display( mobjects = list_difference_update(mobjects, all_excluded) return list(mobjects) - def is_in_frame(self, mobject: Mobject): + def is_in_frame(self, mobject: Mobject) -> bool: """Checks whether the passed mobject is in frame or not. @@ -481,7 +502,7 @@ def is_in_frame(self, mobject: Mobject): ], ) - def capture_mobject(self, mobject: Mobject, **kwargs: Any): + def capture_mobject(self, mobject: Mobject, **kwargs: Any) -> None: """Capture mobjects by storing it in :attr:`pixel_array`. This is a single-mobject version of :meth:`capture_mobjects`. @@ -497,7 +518,7 @@ def capture_mobject(self, mobject: Mobject, **kwargs: Any): """ return self.capture_mobjects([mobject], **kwargs) - def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: """Capture mobjects by printing them on :attr:`pixel_array`. This is the essential function that converts the contents of a Scene @@ -532,7 +553,7 @@ def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs): # NOTE: None of the methods below have been mentioned outside of their definitions. Their DocStrings are not as # detailed as possible. - def get_cached_cairo_context(self, pixel_array: np.ndarray): + def get_cached_cairo_context(self, pixel_array: PixelArray) -> cairo.Context | None: """Returns the cached cairo context of the passed pixel array if it exists, and None if it doesn't. @@ -548,7 +569,7 @@ def get_cached_cairo_context(self, pixel_array: np.ndarray): """ return self.pixel_array_to_cairo_context.get(id(pixel_array), None) - def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): + def cache_cairo_context(self, pixel_array: PixelArray, ctx: cairo.Context) -> None: """Caches the passed Pixel array into a Cairo Context Parameters @@ -560,7 +581,7 @@ def cache_cairo_context(self, pixel_array: np.ndarray, ctx: cairo.Context): """ self.pixel_array_to_cairo_context[id(pixel_array)] = ctx - def get_cairo_context(self, pixel_array: np.ndarray): + def get_cairo_context(self, pixel_array: PixelArray) -> cairo.Context: """Returns the cairo context for a pixel array after caching it to self.pixel_array_to_cairo_context If that array has already been cached, it returns the @@ -585,7 +606,7 @@ def get_cairo_context(self, pixel_array: np.ndarray): fh = self.frame_height fc = self.frame_center surface = cairo.ImageSurface.create_for_data( - pixel_array, + pixel_array.data, cairo.FORMAT_ARGB32, pw, ph, @@ -606,8 +627,8 @@ def get_cairo_context(self, pixel_array: np.ndarray): return ctx def display_multiple_vectorized_mobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: list[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the pixel_array Parameters @@ -630,8 +651,8 @@ def display_multiple_vectorized_mobjects( ) def display_multiple_non_background_colored_vmobjects( - self, vmobjects: list, pixel_array: np.ndarray - ): + self, vmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> None: """Displays multiple VMobjects in the cairo context, as long as they don't have background colors. @@ -646,7 +667,7 @@ def display_multiple_non_background_colored_vmobjects( for vmobject in vmobjects: self.display_vectorized(vmobject, ctx) - def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): + def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context) -> Self: """Displays a VMobject in the cairo context Parameters @@ -667,7 +688,7 @@ def display_vectorized(self, vmobject: VMobject, ctx: cairo.Context): self.apply_stroke(ctx, vmobject) return self - def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): + def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Sets a path for the cairo context with the vmobject passed Parameters @@ -686,7 +707,7 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): # TODO, shouldn't this be handled in transform_points_pre_display? # points = points - self.get_frame_center() if len(points) == 0: - return + return self ctx.new_path() subpaths = vmobject.gen_subpaths_from_points_2d(points) @@ -702,8 +723,8 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject): return self def set_cairo_context_color( - self, ctx: cairo.Context, rgbas: np.ndarray, vmobject: VMobject - ): + self, ctx: cairo.Context, rgbas: MatrixMN, vmobject: VMobject + ) -> Self: """Sets the color of the cairo context Parameters @@ -735,7 +756,7 @@ def set_cairo_context_color( ctx.set_source(pat) return self - def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): + def apply_fill(self, ctx: cairo.Context, vmobject: VMobject) -> Self: """Fills the cairo context Parameters @@ -756,7 +777,7 @@ def apply_fill(self, ctx: cairo.Context, vmobject: VMobject): def apply_stroke( self, ctx: cairo.Context, vmobject: VMobject, background: bool = False - ): + ) -> Self: """Applies a stroke to the VMobject in the cairo context. Parameters @@ -795,7 +816,9 @@ def apply_stroke( ctx.stroke_preserve() return self - def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): + def get_stroke_rgbas( + self, vmobject: VMobject, background: bool = False + ) -> PixelArray: """Gets the RGBA array for the stroke of the passed VMobject. @@ -814,7 +837,7 @@ def get_stroke_rgbas(self, vmobject: VMobject, background: bool = False): """ return vmobject.get_stroke_rgbas(background) - def get_fill_rgbas(self, vmobject: VMobject): + def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray: """Returns the RGBA array of the fill of the passed VMobject Parameters @@ -829,25 +852,27 @@ def get_fill_rgbas(self, vmobject: VMobject): """ return vmobject.get_fill_rgbas() - def get_background_colored_vmobject_displayer(self): + def get_background_colored_vmobject_displayer( + self, + ) -> BackgroundColoredVMobjectDisplayer: """Returns the background_colored_vmobject_displayer if it exists or makes one and returns it if not. Returns ------- - BackGroundColoredVMobjectDisplayer + BackgroundColoredVMobjectDisplayer Object that displays VMobjects that have the same color as the background. """ - # Quite wordy to type out a bunch - bcvd = "background_colored_vmobject_displayer" - if not hasattr(self, bcvd): - setattr(self, bcvd, BackgroundColoredVMobjectDisplayer(self)) - return getattr(self, bcvd) + if self.background_colored_vmobject_displayer is None: + self.background_colored_vmobject_displayer = ( + BackgroundColoredVMobjectDisplayer(self) + ) + return self.background_colored_vmobject_displayer def display_multiple_background_colored_vmobjects( - self, cvmobjects: list, pixel_array: np.ndarray - ): + self, cvmobjects: Iterable[VMobject], pixel_array: PixelArray + ) -> Self: """Displays multiple vmobjects that have the same color as the background. Parameters @@ -873,8 +898,8 @@ def display_multiple_background_colored_vmobjects( # As a result, the other methods do not have as detailed docstrings as would be preferred. def display_multiple_point_cloud_mobjects( - self, pmobjects: list, pixel_array: np.ndarray - ): + self, pmobjects: list, pixel_array: PixelArray + ) -> None: """Displays multiple PMobjects by modifying the passed pixel array. Parameters @@ -899,8 +924,8 @@ def display_point_cloud( points: list, rgbas: np.ndarray, thickness: float, - pixel_array: np.ndarray, - ): + pixel_array: PixelArray, + ) -> None: """Displays a PMobject by modifying the pixel array suitably. TODO: Write a description for the rgbas argument. @@ -948,7 +973,7 @@ def display_point_cloud( def display_multiple_image_mobjects( self, image_mobjects: list, pixel_array: np.ndarray - ): + ) -> None: """Displays multiple image mobjects by modifying the passed pixel_array. Parameters @@ -963,7 +988,7 @@ def display_multiple_image_mobjects( def display_image_mobject( self, image_mobject: AbstractImageMobject, pixel_array: np.ndarray - ): + ) -> None: """Displays an ImageMobject by changing the pixel_array suitably. Parameters @@ -1020,7 +1045,9 @@ def display_image_mobject( # Paint on top of existing pixel array self.overlay_PIL_image(pixel_array, full_image) - def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): + def overlay_rgba_array( + self, pixel_array: np.ndarray, new_array: np.ndarray + ) -> None: """Overlays an RGBA array on top of the given Pixel array. Parameters @@ -1032,7 +1059,7 @@ def overlay_rgba_array(self, pixel_array: np.ndarray, new_array: np.ndarray): """ self.overlay_PIL_image(pixel_array, self.get_image(new_array)) - def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): + def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image) -> None: """Overlays a PIL image on the passed pixel array. Parameters @@ -1047,7 +1074,7 @@ def overlay_PIL_image(self, pixel_array: np.ndarray, image: Image): dtype="uint8", ) - def adjust_out_of_range_points(self, points: np.ndarray): + def adjust_out_of_range_points(self, points: np.ndarray) -> np.ndarray: """If any of the points in the passed array are out of the viable range, they are adjusted suitably. @@ -1078,9 +1105,9 @@ def adjust_out_of_range_points(self, points: np.ndarray): def transform_points_pre_display( self, - mobject, - points, - ): # TODO: Write more detailed docstrings for this method. + mobject: Mobject, + points: Point3D_Array, + ) -> Point3D_Array: # TODO: Write more detailed docstrings for this method. # NOTE: There seems to be an unused argument `mobject`. # Subclasses (like ThreeDCamera) may want to @@ -1093,9 +1120,9 @@ def transform_points_pre_display( def points_to_pixel_coords( self, - mobject, - points, - ): # TODO: Write more detailed docstrings for this method. + mobject: Mobject, + points: np.ndarray, + ) -> np.ndarray: # TODO: Write more detailed docstrings for this method. points = self.transform_points_pre_display(mobject, points) shifted_points = points - self.frame_center @@ -1115,7 +1142,7 @@ def points_to_pixel_coords( result[:, 1] = shifted_points[:, 1] * height_mult + height_add return result.astype("int") - def on_screen_pixels(self, pixel_coords: np.ndarray): + def on_screen_pixels(self, pixel_coords: np.ndarray) -> PixelArray: """Returns array of pixels that are on the screen from a given array of pixel_coordinates @@ -1154,12 +1181,12 @@ def adjusted_thickness(self, thickness: float) -> float: the camera. """ # TODO: This seems...unsystematic - big_sum = op.add(config["pixel_height"], config["pixel_width"]) - this_sum = op.add(self.pixel_height, self.pixel_width) + big_sum: float = op.add(config["pixel_height"], config["pixel_width"]) + this_sum: float = op.add(self.pixel_height, self.pixel_width) factor = big_sum / this_sum return 1 + (thickness - 1) * factor - def get_thickening_nudges(self, thickness: float): + def get_thickening_nudges(self, thickness: float) -> PixelArray: """Determine a list of vectors used to nudge two-dimensional pixel coordinates. @@ -1176,7 +1203,9 @@ def get_thickening_nudges(self, thickness: float): _range = list(range(-thickness // 2 + 1, thickness // 2 + 1)) return np.array(list(it.product(_range, _range))) - def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): + def thickened_coordinates( + self, pixel_coords: np.ndarray, thickness: float + ) -> PixelArray: """Returns thickened coordinates for a passed array of pixel coords and a thickness to thicken by. @@ -1198,7 +1227,7 @@ def thickened_coordinates(self, pixel_coords: np.ndarray, thickness: float): return pixel_coords.reshape((size // 2, 2)) # TODO, reimplement using cairo matrix - def get_coords_of_all_pixels(self): + def get_coords_of_all_pixels(self) -> PixelArray: """Returns the cartesian coordinates of each pixel. Returns @@ -1246,20 +1275,20 @@ class BackgroundColoredVMobjectDisplayer: def __init__(self, camera: Camera): self.camera = camera - self.file_name_to_pixel_array_map = {} + self.file_name_to_pixel_array_map: dict[str, PixelArray] = {} self.pixel_array = np.array(camera.pixel_array) self.reset_pixel_array() - def reset_pixel_array(self): + def reset_pixel_array(self) -> None: self.pixel_array[:, :] = 0 def resize_background_array( self, - background_array: np.ndarray, + background_array: PixelArray, new_width: float, new_height: float, mode: str = "RGBA", - ): + ) -> PixelArray: """Resizes the pixel array representing the background. Parameters @@ -1284,8 +1313,8 @@ def resize_background_array( return np.array(resized_image) def resize_background_array_to_match( - self, background_array: np.ndarray, pixel_array: np.ndarray - ): + self, background_array: PixelArray, pixel_array: PixelArray + ) -> PixelArray: """Resizes the background array to match the passed pixel array. Parameters @@ -1304,7 +1333,9 @@ def resize_background_array_to_match( mode = "RGBA" if pixel_array.shape[2] == 4 else "RGB" return self.resize_background_array(background_array, width, height, mode) - def get_background_array(self, image: Image.Image | pathlib.Path | str): + def get_background_array( + self, image: Image.Image | pathlib.Path | str + ) -> PixelArray: """Gets the background array that has the passed file_name. Parameters @@ -1333,7 +1364,7 @@ def get_background_array(self, image: Image.Image | pathlib.Path | str): self.file_name_to_pixel_array_map[image_key] = back_array return back_array - def display(self, *cvmobjects: VMobject): + def display(self, *cvmobjects: VMobject) -> PixelArray | None: """Displays the colored VMobjects. Parameters diff --git a/manim/camera/mapping_camera.py b/manim/camera/mapping_camera.py index 03f0afc3b4..4d347d02a3 100644 --- a/manim/camera/mapping_camera.py +++ b/manim/camera/mapping_camera.py @@ -1,4 +1,4 @@ -"""A camera that allows mapping between objects.""" +"""A camera module that supports spatial mapping between objects for distortion effects.""" from __future__ import annotations @@ -17,8 +17,16 @@ class MappingCamera(Camera): - """Camera object that allows mapping - between objects. + """Parameters + ---------- + mapping_func : callable + Function to map 3D points to new 3D points (identity by default). + min_num_curves : int + Minimum number of curves for VMobjects to avoid visual glitches. + allow_object_intrusion : bool + If True, modifies original mobjects; else works on copies. + kwargs : dict + Additional arguments passed to Camera base class. """ def __init__( @@ -34,12 +42,18 @@ def __init__( super().__init__(**kwargs) def points_to_pixel_coords(self, mobject, points): + # Map points with custom function before converting to pixels return super().points_to_pixel_coords( mobject, np.apply_along_axis(self.mapping_func, 1, points), ) def capture_mobjects(self, mobjects, **kwargs): + """Capture mobjects for rendering after applying the spatial mapping. + + Copies mobjects unless intrusion is allowed, and ensures + vector objects have enough curves for smooth distortion. + """ mobjects = self.get_mobjects_to_display(mobjects, **kwargs) if self.allow_object_intrusion: mobject_copies = mobjects @@ -67,6 +81,13 @@ def capture_mobjects(self, mobjects, **kwargs): # TODO, the classes below should likely be deleted class OldMultiCamera(Camera): + """Parameters + ---------- + cameras_with_start_positions : tuple + Tuples of (Camera, (start_y, start_x)) indicating camera and + its pixel offset on the final frame. + """ + def __init__(self, *cameras_with_start_positions, **kwargs): self.shifted_cameras = [ DictAsObject( @@ -125,6 +146,15 @@ def init_background(self): class SplitScreenCamera(OldMultiCamera): + """Initializes a split screen camera setup with two side-by-side cameras. + + Parameters + ---------- + left_camera : Camera + right_camera : Camera + kwargs : dict + """ + def __init__(self, left_camera, right_camera, **kwargs): Camera.__init__(self, **kwargs) # to set attributes such as pixel_width self.left_camera = left_camera diff --git a/manim/camera/moving_camera.py b/manim/camera/moving_camera.py index 1d01d01e22..deff555b85 100644 --- a/manim/camera/moving_camera.py +++ b/manim/camera/moving_camera.py @@ -1,15 +1,17 @@ -"""A camera able to move through a scene. +"""Defines the MovingCamera class, a camera that can pan and zoom through a scene. .. SEEALSO:: :mod:`.moving_camera_scene` - """ from __future__ import annotations __all__ = ["MovingCamera"] +from collections.abc import Iterable +from typing import Any + import numpy as np from .. import config @@ -17,29 +19,28 @@ from ..constants import DOWN, LEFT, RIGHT, UP from ..mobject.frame import ScreenRectangle from ..mobject.mobject import Mobject -from ..utils.color import WHITE +from ..utils.color import WHITE, ManimColor class MovingCamera(Camera): - """ - Stays in line with the height, width and position of it's 'frame', which is a Rectangle + """A camera that follows and matches the size and position of its 'frame', a Rectangle (or similar Mobject). + + The frame defines the region of space the camera displays and can move or resize dynamically. .. SEEALSO:: :class:`.MovingCameraScene` - """ def __init__( self, frame=None, - fixed_dimension=0, # width - default_frame_stroke_color=WHITE, - default_frame_stroke_width=0, - **kwargs, - ): - """ - Frame is a Mobject, (should almost certainly be a rectangle) + fixed_dimension: int = 0, # width + default_frame_stroke_color: ManimColor = WHITE, + default_frame_stroke_width: int = 0, + **kwargs: Any, + ) -> None: + """Frame is a Mobject, (should almost certainly be a rectangle) determining which region of space the camera displays """ self.fixed_dimension = fixed_dimension @@ -123,7 +124,7 @@ def frame_center(self, frame_center: np.ndarray | list | tuple | Mobject): """ self.frame.move_to(frame_center) - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: # self.reset_frame_center() # self.realign_frame_shape() super().capture_mobjects(mobjects, **kwargs) @@ -132,16 +133,14 @@ def capture_mobjects(self, mobjects, **kwargs): # context used for updating should be regenerated # at each frame. So no caching. def get_cached_cairo_context(self, pixel_array): - """ - Since the frame can be moving around, the cairo + """Since the frame can be moving around, the cairo context used for updating should be regenerated at each frame. So no caching. """ return None def cache_cairo_context(self, pixel_array, ctx): - """ - Since the frame can be moving around, the cairo + """Since the frame can be moving around, the cairo context used for updating should be regenerated at each frame. So no caching. """ @@ -159,8 +158,7 @@ def cache_cairo_context(self, pixel_array, ctx): # self.resize_frame_shape(fixed_dimension=self.fixed_dimension) def get_mobjects_indicating_movement(self): - """ - Returns all mobjects whose movement implies that the camera + """Returns all mobjects whose movement implies that the camera should think of all other mobjects on the screen as moving Returns diff --git a/manim/camera/multi_camera.py b/manim/camera/multi_camera.py index a5202135e9..f4bd18a47c 100644 --- a/manim/camera/multi_camera.py +++ b/manim/camera/multi_camera.py @@ -5,7 +5,13 @@ __all__ = ["MultiCamera"] -from manim.mobject.types.image_mobject import ImageMobject +from collections.abc import Iterable +from typing import Any + +from typing_extensions import Self + +from manim.mobject.mobject import Mobject +from manim.mobject.types.image_mobject import ImageMobjectFromCamera from ..camera.moving_camera import MovingCamera from ..utils.iterables import list_difference_update @@ -16,10 +22,10 @@ class MultiCamera(MovingCamera): def __init__( self, - image_mobjects_from_cameras: ImageMobject | None = None, - allow_cameras_to_capture_their_own_display=False, - **kwargs, - ): + image_mobjects_from_cameras: Iterable[ImageMobjectFromCamera] | None = None, + allow_cameras_to_capture_their_own_display: bool = False, + **kwargs: Any, + ) -> None: """Initialises the MultiCamera Parameters @@ -29,7 +35,7 @@ def __init__( kwargs Any valid keyword arguments of MovingCamera. """ - self.image_mobjects_from_cameras = [] + self.image_mobjects_from_cameras: list[ImageMobjectFromCamera] = [] if image_mobjects_from_cameras is not None: for imfc in image_mobjects_from_cameras: self.add_image_mobject_from_camera(imfc) @@ -38,7 +44,9 @@ def __init__( ) super().__init__(**kwargs) - def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject): + def add_image_mobject_from_camera( + self, image_mobject_from_camera: ImageMobjectFromCamera + ) -> None: """Adds an ImageMobject that's been obtained from the camera into the list ``self.image_mobject_from_cameras`` @@ -53,20 +61,20 @@ def add_image_mobject_from_camera(self, image_mobject_from_camera: ImageMobject) assert isinstance(imfc.camera, MovingCamera) self.image_mobjects_from_cameras.append(imfc) - def update_sub_cameras(self): + def update_sub_cameras(self) -> None: """Reshape sub_camera pixel_arrays""" for imfc in self.image_mobjects_from_cameras: pixel_height, pixel_width = self.pixel_array.shape[:2] - imfc.camera.frame_shape = ( - imfc.camera.frame.height, - imfc.camera.frame.width, - ) + # imfc.camera.frame_shape = ( + # imfc.camera.frame.height, + # imfc.camera.frame.width, + # ) imfc.camera.reset_pixel_shape( int(pixel_height * imfc.height / self.frame_height), int(pixel_width * imfc.width / self.frame_width), ) - def reset(self): + def reset(self) -> Self: """Resets the MultiCamera. Returns @@ -79,7 +87,7 @@ def reset(self): super().reset() return self - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: self.update_sub_cameras() for imfc in self.image_mobjects_from_cameras: to_add = list(mobjects) @@ -88,7 +96,7 @@ def capture_mobjects(self, mobjects, **kwargs): imfc.camera.capture_mobjects(to_add, **kwargs) super().capture_mobjects(mobjects, **kwargs) - def get_mobjects_indicating_movement(self): + def get_mobjects_indicating_movement(self) -> list[Mobject]: """Returns all mobjects whose movement implies that the camera should think of all other mobjects on the screen as moving diff --git a/manim/camera/three_d_camera.py b/manim/camera/three_d_camera.py index f45854e810..3d44b3e910 100644 --- a/manim/camera/three_d_camera.py +++ b/manim/camera/three_d_camera.py @@ -5,7 +5,8 @@ __all__ = ["ThreeDCamera"] -from typing import Callable +from collections.abc import Callable, Iterable +from typing import Any import numpy as np @@ -16,7 +17,14 @@ get_3d_vmob_start_corner, get_3d_vmob_start_corner_unit_normal, ) +from manim.mobject.types.vectorized_mobject import VMobject from manim.mobject.value_tracker import ValueTracker +from manim.typing import ( + MatrixMN, + Point3D, + Point3D_Array, + Point3DLike, +) from .. import config from ..camera.camera import Camera @@ -30,17 +38,17 @@ class ThreeDCamera(Camera): def __init__( self, - focal_distance=20.0, - shading_factor=0.2, - default_distance=5.0, - light_source_start_point=9 * DOWN + 7 * LEFT + 10 * OUT, - should_apply_shading=True, - exponential_projection=False, - phi=0, - theta=-90 * DEGREES, - gamma=0, - zoom=1, - **kwargs, + focal_distance: float = 20.0, + shading_factor: float = 0.2, + default_distance: float = 5.0, + light_source_start_point: Point3DLike = 9 * DOWN + 7 * LEFT + 10 * OUT, + should_apply_shading: bool = True, + exponential_projection: bool = False, + phi: float = 0, + theta: float = -90 * DEGREES, + gamma: float = 0, + zoom: float = 1, + **kwargs: Any, ): """Initializes the ThreeDCamera @@ -68,23 +76,23 @@ def __init__( self.focal_distance_tracker = ValueTracker(self.focal_distance) self.gamma_tracker = ValueTracker(self.gamma) self.zoom_tracker = ValueTracker(self.zoom) - self.fixed_orientation_mobjects = {} - self.fixed_in_frame_mobjects = set() + self.fixed_orientation_mobjects: dict[Mobject, Callable[[], Point3D]] = {} + self.fixed_in_frame_mobjects: set[Mobject] = set() self.reset_rotation_matrix() @property - def frame_center(self): + def frame_center(self) -> Point3D: return self._frame_center.points[0] @frame_center.setter - def frame_center(self, point): + def frame_center(self, point: Point3DLike) -> None: self._frame_center.move_to(point) - def capture_mobjects(self, mobjects, **kwargs): + def capture_mobjects(self, mobjects: Iterable[Mobject], **kwargs: Any) -> None: self.reset_rotation_matrix() super().capture_mobjects(mobjects, **kwargs) - def get_value_trackers(self): + def get_value_trackers(self) -> list[ValueTracker]: """A list of :class:`ValueTrackers <.ValueTracker>` of phi, theta, focal_distance, gamma and zoom. @@ -101,7 +109,7 @@ def get_value_trackers(self): self.zoom_tracker, ] - def modified_rgbas(self, vmobject, rgbas): + def modified_rgbas(self, vmobject: VMobject, rgbas: MatrixMN) -> MatrixMN: if not self.should_apply_shading: return rgbas if vmobject.shade_in_3d and (vmobject.get_num_points() > 0): @@ -127,28 +135,33 @@ def modified_rgbas(self, vmobject, rgbas): def get_stroke_rgbas( self, - vmobject, - background=False, - ): # NOTE : DocStrings From parent + vmobject: VMobject, + background: bool = False, + ) -> MatrixMN: # NOTE : DocStrings From parent return self.modified_rgbas(vmobject, vmobject.get_stroke_rgbas(background)) - def get_fill_rgbas(self, vmobject): # NOTE : DocStrings From parent + def get_fill_rgbas( + self, vmobject: VMobject + ) -> MatrixMN: # NOTE : DocStrings From parent return self.modified_rgbas(vmobject, vmobject.get_fill_rgbas()) - def get_mobjects_to_display(self, *args, **kwargs): # NOTE : DocStrings From parent + def get_mobjects_to_display( + self, *args: Any, **kwargs: Any + ) -> list[Mobject]: # NOTE : DocStrings From parent mobjects = super().get_mobjects_to_display(*args, **kwargs) rot_matrix = self.get_rotation_matrix() - def z_key(mob): + def z_key(mob: Mobject) -> float: if not (hasattr(mob, "shade_in_3d") and mob.shade_in_3d): - return np.inf + return np.inf # type: ignore[no-any-return] # Assign a number to a three dimensional mobjects # based on how close it is to the camera - return np.dot(mob.get_z_index_reference_point(), rot_matrix.T)[2] + distance: float = np.dot(mob.get_z_index_reference_point(), rot_matrix.T)[2] + return distance return sorted(mobjects, key=z_key) - def get_phi(self): + def get_phi(self) -> float: """Returns the Polar angle (the angle off Z_AXIS) phi. Returns @@ -158,7 +171,7 @@ def get_phi(self): """ return self.phi_tracker.get_value() - def get_theta(self): + def get_theta(self) -> float: """Returns the Azimuthal i.e the angle that spins the camera around the Z_AXIS. Returns @@ -168,7 +181,7 @@ def get_theta(self): """ return self.theta_tracker.get_value() - def get_focal_distance(self): + def get_focal_distance(self) -> float: """Returns focal_distance of the Camera. Returns @@ -178,7 +191,7 @@ def get_focal_distance(self): """ return self.focal_distance_tracker.get_value() - def get_gamma(self): + def get_gamma(self) -> float: """Returns the rotation of the camera about the vector from the ORIGIN to the Camera. Returns @@ -189,7 +202,7 @@ def get_gamma(self): """ return self.gamma_tracker.get_value() - def get_zoom(self): + def get_zoom(self) -> float: """Returns the zoom amount of the camera. Returns @@ -199,7 +212,7 @@ def get_zoom(self): """ return self.zoom_tracker.get_value() - def set_phi(self, value: float): + def set_phi(self, value: float) -> None: """Sets the polar angle i.e the angle between Z_AXIS and Camera through ORIGIN in radians. Parameters @@ -209,7 +222,7 @@ def set_phi(self, value: float): """ self.phi_tracker.set_value(value) - def set_theta(self, value: float): + def set_theta(self, value: float) -> None: """Sets the azimuthal angle i.e the angle that spins the camera around Z_AXIS in radians. Parameters @@ -219,7 +232,7 @@ def set_theta(self, value: float): """ self.theta_tracker.set_value(value) - def set_focal_distance(self, value: float): + def set_focal_distance(self, value: float) -> None: """Sets the focal_distance of the Camera. Parameters @@ -229,7 +242,7 @@ def set_focal_distance(self, value: float): """ self.focal_distance_tracker.set_value(value) - def set_gamma(self, value: float): + def set_gamma(self, value: float) -> None: """Sets the angle of rotation of the camera about the vector from the ORIGIN to the Camera. Parameters @@ -239,7 +252,7 @@ def set_gamma(self, value: float): """ self.gamma_tracker.set_value(value) - def set_zoom(self, value: float): + def set_zoom(self, value: float) -> None: """Sets the zoom amount of the camera. Parameters @@ -249,13 +262,13 @@ def set_zoom(self, value: float): """ self.zoom_tracker.set_value(value) - def reset_rotation_matrix(self): + def reset_rotation_matrix(self) -> None: """Sets the value of self.rotation_matrix to the matrix corresponding to the current position of the camera """ self.rotation_matrix = self.generate_rotation_matrix() - def get_rotation_matrix(self): + def get_rotation_matrix(self) -> MatrixMN: """Returns the matrix corresponding to the current position of the camera. Returns @@ -265,7 +278,7 @@ def get_rotation_matrix(self): """ return self.rotation_matrix - def generate_rotation_matrix(self): + def generate_rotation_matrix(self) -> MatrixMN: """Generates a rotation matrix based off the current position of the camera. Returns @@ -286,7 +299,7 @@ def generate_rotation_matrix(self): result = np.dot(matrix, result) return result - def project_points(self, points: np.ndarray | list): + def project_points(self, points: Point3D_Array) -> Point3D_Array: """Applies the current rotation_matrix as a projection matrix to the passed array of points. @@ -323,7 +336,7 @@ def project_points(self, points: np.ndarray | list): points[:, i] *= factor * zoom return points - def project_point(self, point: list | np.ndarray): + def project_point(self, point: Point3D) -> Point3D: """Applies the current rotation_matrix as a projection matrix to the passed point. @@ -341,9 +354,9 @@ def project_point(self, point: list | np.ndarray): def transform_points_pre_display( self, - mobject, - points, - ): # TODO: Write Docstrings for this Method. + mobject: Mobject, + points: Point3D_Array, + ) -> Point3D_Array: # TODO: Write Docstrings for this Method. points = super().transform_points_pre_display(mobject, points) fixed_orientation = mobject in self.fixed_orientation_mobjects fixed_in_frame = mobject in self.fixed_in_frame_mobjects @@ -362,8 +375,8 @@ def add_fixed_orientation_mobjects( self, *mobjects: Mobject, use_static_center_func: bool = False, - center_func: Callable[[], np.ndarray] | None = None, - ): + center_func: Callable[[], Point3D] | None = None, + ) -> None: """This method allows the mobject to have a fixed orientation, even when the camera moves around. E.G If it was passed through this method, facing the camera, it @@ -384,7 +397,7 @@ def add_fixed_orientation_mobjects( # This prevents the computation of mobject.get_center # every single time a projection happens - def get_static_center_func(mobject): + def get_static_center_func(mobject: Mobject) -> Callable[[], Point3D]: point = mobject.get_center() return lambda: point @@ -398,7 +411,7 @@ def get_static_center_func(mobject): for submob in mobject.get_family(): self.fixed_orientation_mobjects[submob] = func - def add_fixed_in_frame_mobjects(self, *mobjects: Mobject): + def add_fixed_in_frame_mobjects(self, *mobjects: Mobject) -> None: """This method allows the mobject to have a fixed position, even when the camera moves around. E.G If it was passed through this method, at the top of the frame, it @@ -414,7 +427,7 @@ def add_fixed_in_frame_mobjects(self, *mobjects: Mobject): for mobject in extract_mobject_family_members(mobjects): self.fixed_in_frame_mobjects.add(mobject) - def remove_fixed_orientation_mobjects(self, *mobjects: Mobject): + def remove_fixed_orientation_mobjects(self, *mobjects: Mobject) -> None: """If a mobject was fixed in its orientation by passing it through :meth:`.add_fixed_orientation_mobjects`, then this undoes that fixing. The Mobject will no longer have a fixed orientation. @@ -428,7 +441,7 @@ def remove_fixed_orientation_mobjects(self, *mobjects: Mobject): if mobject in self.fixed_orientation_mobjects: del self.fixed_orientation_mobjects[mobject] - def remove_fixed_in_frame_mobjects(self, *mobjects: Mobject): + def remove_fixed_in_frame_mobjects(self, *mobjects: Mobject) -> None: """If a mobject was fixed in frame by passing it through :meth:`.add_fixed_in_frame_mobjects`, then this undoes that fixing. The Mobject will no longer be fixed in frame. diff --git a/manim/cli/cfg/group.py b/manim/cli/cfg/group.py index 13834311ab..3945499f8b 100644 --- a/manim/cli/cfg/group.py +++ b/manim/cli/cfg/group.py @@ -267,6 +267,12 @@ def write(level: str | None = None, openfile: bool = False) -> None: @cfg.command(context_settings=cli_ctx_settings) def show() -> None: + console.print("CONFIG FILES READ", style="bold green underline") + for path in config_file_paths(): + if path.exists(): + console.print(f"{path}") + console.print() + parser = make_config_parser() rich_non_style_entries = [a.replace(".", "_") for a in RICH_NON_STYLE_ENTRIES] for category in parser: diff --git a/manim/cli/checkhealth/checks.py b/manim/cli/checkhealth/checks.py index ec9c07dec7..aabb2b3b47 100644 --- a/manim/cli/checkhealth/checks.py +++ b/manim/cli/checkhealth/checks.py @@ -6,7 +6,8 @@ import os import shutil -from typing import Callable, Protocol, cast +from collections.abc import Callable +from typing import Protocol, cast __all__ = ["HEALTH_CHECKS"] diff --git a/manim/cli/default_group.py b/manim/cli/default_group.py index 579a3e3a05..e3cdcb710c 100644 --- a/manim/cli/default_group.py +++ b/manim/cli/default_group.py @@ -13,7 +13,8 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import cloup diff --git a/manim/cli/init/commands.py b/manim/cli/init/commands.py index dd9d64837f..197e97bf71 100644 --- a/manim/cli/init/commands.py +++ b/manim/cli/init/commands.py @@ -76,8 +76,8 @@ def update_cfg(cfg_dict: dict[str, Any], project_cfg_path: Path) -> None: cli_config = config["CLI"] for key, value in cfg_dict.items(): if key == "resolution": - cli_config["pixel_height"] = str(value[0]) - cli_config["pixel_width"] = str(value[1]) + cli_config["pixel_width"] = str(value[0]) + cli_config["pixel_height"] = str(value[1]) else: cli_config[key] = str(value) diff --git a/manim/data_structures.py b/manim/data_structures.py new file mode 100644 index 0000000000..0b9309f0b1 --- /dev/null +++ b/manim/data_structures.py @@ -0,0 +1,31 @@ +"""Data classes and other necessary data structures for use in Manim.""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from types import MethodType +from typing import Any + + +@dataclass +class MethodWithArgs: + """Object containing a :attr:`method` which is intended to be called later + with the positional arguments :attr:`args` and the keyword arguments + :attr:`kwargs`. + + Attributes + ---------- + method : MethodType + A callable representing a method of some class. + args : Iterable[Any] + Positional arguments for :attr:`method`. + kwargs : dict[str, Any] + Keyword arguments for :attr:`method`. + """ + + __slots__ = ["method", "args", "kwargs"] + + method: MethodType + args: Iterable[Any] + kwargs: dict[str, Any] diff --git a/manim/gui/__init__.py b/manim/gui/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/manim/gui/gui.py b/manim/gui/gui.py deleted file mode 100644 index 75ec67312c..0000000000 --- a/manim/gui/gui.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -try: - import dearpygui.dearpygui as dpg - - dearpygui_imported = True -except ImportError: - dearpygui_imported = False - - -from .. import __version__, config -from ..utils.module_ops import scene_classes_from_file - -__all__ = ["configure_pygui"] - -if dearpygui_imported: - dpg.create_context() - window = dpg.generate_uuid() - - -def configure_pygui(renderer, widgets, update=True): - if not dearpygui_imported: - raise RuntimeError("Attempted to use DearPyGUI when it isn't imported.") - if update: - dpg.delete_item(window) - else: - dpg.create_viewport() - dpg.setup_dearpygui() - dpg.show_viewport() - - dpg.set_viewport_title(title=f"Manim Community v{__version__}") - dpg.set_viewport_width(1015) - dpg.set_viewport_height(540) - - def rerun_callback(sender, data): - renderer.scene.queue.put(("rerun_gui", [], {})) - - def continue_callback(sender, data): - renderer.scene.queue.put(("exit_gui", [], {})) - - def scene_selection_callback(sender, data): - config["scene_names"] = (dpg.get_value(sender),) - renderer.scene.queue.put(("rerun_gui", [], {})) - - scene_classes = scene_classes_from_file(Path(config["input_file"]), full_list=True) - scene_names = [scene_class.__name__ for scene_class in scene_classes] - - with dpg.window( - id=window, - label="Manim GUI", - pos=[config["gui_location"][0], config["gui_location"][1]], - width=1000, - height=500, - ): - dpg.set_global_font_scale(2) - dpg.add_button(label="Rerun", callback=rerun_callback) - dpg.add_button(label="Continue", callback=continue_callback) - dpg.add_combo( - label="Selected scene", - items=scene_names, - callback=scene_selection_callback, - default_value=config["scene_names"][0], - ) - dpg.add_separator() - if len(widgets) != 0: - with dpg.collapsing_header( - label=f"{config['scene_names'][0]} widgets", - default_open=True, - ): - for widget_config in widgets: - widget_config_copy = widget_config.copy() - name = widget_config_copy["name"] - widget = widget_config_copy["widget"] - if widget != "separator": - del widget_config_copy["name"] - del widget_config_copy["widget"] - getattr(dpg, f"add_{widget}")(label=name, **widget_config_copy) - else: - dpg.add_separator() - - if not update: - dpg.start_dearpygui() diff --git a/manim/mobject/frame.py b/manim/mobject/frame.py index 639e8c384e..698e1b7247 100644 --- a/manim/mobject/frame.py +++ b/manim/mobject/frame.py @@ -8,17 +8,21 @@ ] +from typing import Any + from manim.mobject.geometry.polygram import Rectangle from .. import config class ScreenRectangle(Rectangle): - def __init__(self, aspect_ratio=16.0 / 9.0, height=4, **kwargs): + def __init__( + self, aspect_ratio: float = 16.0 / 9.0, height: float = 4, **kwargs: Any + ) -> None: super().__init__(width=aspect_ratio * height, height=height, **kwargs) @property - def aspect_ratio(self): + def aspect_ratio(self) -> float: """The aspect ratio. When set, the width is stretched to accommodate @@ -27,11 +31,11 @@ def aspect_ratio(self): return self.width / self.height @aspect_ratio.setter - def aspect_ratio(self, value): + def aspect_ratio(self, value: float) -> None: self.stretch_to_fit_width(value * self.height) class FullScreenRectangle(ScreenRectangle): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.height = config["frame_height"] diff --git a/manim/mobject/geometry/arc.py b/manim/mobject/geometry/arc.py index 2923239944..8b9e832a5b 100644 --- a/manim/mobject/geometry/arc.py +++ b/manim/mobject/geometry/arc.py @@ -44,7 +44,7 @@ def construct(self): import itertools import warnings -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np from typing_extensions import Self @@ -64,7 +64,6 @@ def construct(self): if TYPE_CHECKING: from collections.abc import Iterable - from typing import Any import manim.mobject.geometry.tips as tips from manim.mobject.mobject import Mobject @@ -74,7 +73,7 @@ def construct(self): Point3D, Point3DLike, QuadraticSpline, - Vector3D, + Vector3DLike, ) @@ -99,12 +98,12 @@ class TipableVMobject(VMobject, metaclass=ConvertToOpenGL): def __init__( self, tip_length: float = DEFAULT_ARROW_TIP_LENGTH, - normal_vector: Vector3D = OUT, + normal_vector: Vector3DLike = OUT, tip_style: dict = {}, **kwargs: Any, ) -> None: self.tip_length: float = tip_length - self.normal_vector: Vector3D = normal_vector + self.normal_vector = normal_vector self.tip_style: dict = tip_style super().__init__(**kwargs) @@ -916,7 +915,8 @@ def generate_points(self) -> None: self.append_points(outer_arc.points) self.add_line_to(inner_arc.points[0]) - init_points = generate_points + def init_points(self) -> None: + self.generate_points() class Sector(AnnularSector): @@ -990,7 +990,8 @@ def generate_points(self) -> None: self.append_points(inner_circle.points) self.shift(self.arc_center) - init_points = generate_points + def init_points(self) -> None: + self.generate_points() class CubicBezier(VMobject, metaclass=ConvertToOpenGL): diff --git a/manim/mobject/geometry/boolean_ops.py b/manim/mobject/geometry/boolean_ops.py index f02b4f7be6..ea5aa38ef1 100644 --- a/manim/mobject/geometry/boolean_ops.py +++ b/manim/mobject/geometry/boolean_ops.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np from pathops import Path as SkiaPath @@ -13,8 +13,6 @@ from manim.mobject.types.vectorized_mobject import VMobject if TYPE_CHECKING: - from typing import Any - from manim.typing import Point2DLike_Array, Point3D_Array, Point3DLike_Array from ...constants import RendererType diff --git a/manim/mobject/geometry/labeled.py b/manim/mobject/geometry/labeled.py index 51b74ccb44..a9fa8c891a 100644 --- a/manim/mobject/geometry/labeled.py +++ b/manim/mobject/geometry/labeled.py @@ -4,7 +4,7 @@ __all__ = ["Label", "LabeledLine", "LabeledArrow", "LabeledPolygram"] -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -22,8 +22,6 @@ from manim.utils.polylabel import polylabel if TYPE_CHECKING: - from typing import Any - from manim.typing import Point3DLike_Array diff --git a/manim/mobject/geometry/line.py b/manim/mobject/geometry/line.py index 2f1c37fa19..9b0553d1be 100644 --- a/manim/mobject/geometry/line.py +++ b/manim/mobject/geometry/line.py @@ -14,7 +14,7 @@ "RightAngle", ] -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -30,11 +30,9 @@ from manim.utils.space_ops import angle_of_vector, line_intersection, normalize if TYPE_CHECKING: - from typing import Any + from typing_extensions import Self, TypeAlias - from typing_extensions import Literal, Self, TypeAlias - - from manim.typing import Point2DLike, Point3D, Point3DLike, Vector3D + from manim.typing import Point3D, Point3DLike, Vector2DLike, Vector3D, Vector3DLike from manim.utils.color import ParsableManimColor from ..matrix import Matrix # Avoid circular import @@ -147,7 +145,8 @@ def set_points_by_ends( self._account_for_buff(buff) - init_points = generate_points + def init_points(self) -> None: + self.generate_points() def _account_for_buff(self, buff: float) -> None: if buff <= 0: @@ -175,7 +174,7 @@ def _set_start_and_end_attrs( def _pointify( self, mob_or_point: Mobject | Point3DLike, - direction: Vector3D | None = None, + direction: Vector3DLike | None = None, ) -> Point3D: """Transforms a mobject into its corresponding point. Does nothing if a point is passed. @@ -738,7 +737,7 @@ def construct(self): def __init__( self, - direction: Point2DLike | Point3DLike = RIGHT, + direction: Vector2DLike | Vector3DLike = RIGHT, buff: float = 0, **kwargs: Any, ) -> None: diff --git a/manim/mobject/geometry/polygram.py b/manim/mobject/geometry/polygram.py index 5a8dabdaca..3274d29ff8 100644 --- a/manim/mobject/geometry/polygram.py +++ b/manim/mobject/geometry/polygram.py @@ -18,7 +18,7 @@ from math import ceil -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -32,8 +32,6 @@ from manim.utils.space_ops import angle_between_vectors, normalize, regular_vertices if TYPE_CHECKING: - from typing import Any, Literal - import numpy.typing as npt from typing_extensions import Self diff --git a/manim/mobject/geometry/tips.py b/manim/mobject/geometry/tips.py index ea7c6c2414..5080aa8777 100644 --- a/manim/mobject/geometry/tips.py +++ b/manim/mobject/geometry/tips.py @@ -13,7 +13,7 @@ "StealthTip", ] -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -25,8 +25,6 @@ from manim.utils.space_ops import angle_of_vector if TYPE_CHECKING: - from typing import Any - from manim.typing import Point3D, Vector3D diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index b21879b90b..7e6f626ba3 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -13,8 +13,8 @@ import fractions as fr import numbers -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, TypeVar, overload import numpy as np from typing_extensions import Self @@ -64,6 +64,7 @@ Point3D, Point3DLike, Vector3D, + Vector3DLike, ) LineType = TypeVar("LineType", bound=Line) @@ -126,7 +127,7 @@ def __init__( x_length: float | None = None, y_length: float | None = None, dimension: int = 2, - ) -> None: + ): self.dimension = dimension default_step = 1 @@ -153,11 +154,14 @@ def __init__( self.x_length = x_length self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 + self.x_axis: NumberLine - def coords_to_point(self, *coords: ManimFloat): + def coords_to_point(self, *coords: ManimFloat) -> Point3D: + # TODO: I think the method should be able to return more than just a single point. + # E.g. see the implementation of it on line 2065. raise NotImplementedError() - def point_to_coords(self, point: Point3DLike): + def point_to_coords(self, point: Point3DLike) -> list[ManimFloat]: raise NotImplementedError() def polar_to_point(self, radius: float, azimuth: float) -> Point2D: @@ -201,7 +205,7 @@ def point_to_polar(self, point: Point2DLike) -> Point2D: Returns ------- - Tuple[:class:`float`, :class:`float`] + Point2D The coordinate radius (:math:`r`) and the coordinate azimuth (:math:`\theta`). """ x, y = self.point_to_coords(point) @@ -213,7 +217,7 @@ def c2p( """Abbreviation for :meth:`coords_to_point`""" return self.coords_to_point(*coords) - def p2c(self, point: Point3DLike): + def p2c(self, point: Point3DLike) -> list[ManimFloat]: """Abbreviation for :meth:`point_to_coords`""" return self.point_to_coords(point) @@ -221,17 +225,18 @@ def pr2pt(self, radius: float, azimuth: float) -> np.ndarray: """Abbreviation for :meth:`polar_to_point`""" return self.polar_to_point(radius, azimuth) - def pt2pr(self, point: np.ndarray) -> tuple[float, float]: + def pt2pr(self, point: np.ndarray) -> Point2D: """Abbreviation for :meth:`point_to_polar`""" return self.point_to_polar(point) - def get_axes(self): + def get_axes(self) -> VGroup: raise NotImplementedError() - def get_axis(self, index: int) -> Mobject: - return self.get_axes()[index] + def get_axis(self, index: int) -> NumberLine: + val: NumberLine = self.get_axes()[index] + return val - def get_origin(self) -> np.ndarray: + def get_origin(self) -> Point3D: """Gets the origin of :class:`~.Axes`. Returns @@ -241,13 +246,13 @@ def get_origin(self) -> np.ndarray: """ return self.coords_to_point(0, 0) - def get_x_axis(self) -> Mobject: + def get_x_axis(self) -> NumberLine: return self.get_axis(0) - def get_y_axis(self) -> Mobject: + def get_y_axis(self) -> NumberLine: return self.get_axis(1) - def get_z_axis(self) -> Mobject: + def get_z_axis(self) -> NumberLine: return self.get_axis(2) def get_x_unit_size(self) -> float: @@ -258,11 +263,11 @@ def get_y_unit_size(self) -> float: def get_x_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UR, + label: float | str | VMobject, + edge: Vector3D = UR, + direction: Vector3D = UR, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate an x-axis label. @@ -301,11 +306,11 @@ def construct(self): def get_y_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UP * 0.5 + RIGHT, + label: float | str | VMobject, + edge: Vector3D = UR, + direction: Vector3D = UP * 0.5 + RIGHT, buff: float = SMALL_BUFF, - **kwargs, + **kwargs: Any, ) -> Mobject: """Generate a y-axis label. @@ -347,10 +352,10 @@ def construct(self): def _get_axis_label( self, - label: float | str | Mobject, + label: float | str | VMobject, axis: Mobject, - edge: Sequence[float], - direction: Sequence[float], + edge: Vector3DLike, + direction: Vector3DLike, buff: float = SMALL_BUFF, ) -> Mobject: """Gets the label for an axis. @@ -373,12 +378,14 @@ def _get_axis_label( :class:`~.Mobject` The positioned label along the given axis. """ - label = self.x_axis._create_label_tex(label) - label.next_to(axis.get_edge_center(edge), direction=direction, buff=buff) - label.shift_onto_screen(buff=MED_SMALL_BUFF) - return label + label_mobject: Mobject = self.x_axis._create_label_tex(label) + label_mobject.next_to( + axis.get_edge_center(edge), direction=direction, buff=buff + ) + label_mobject.shift_onto_screen(buff=MED_SMALL_BUFF) + return label_mobject - def get_axis_labels(self): + def get_axis_labels(self) -> VGroup: raise NotImplementedError() def add_coordinates( @@ -453,7 +460,7 @@ def add_coordinates( def get_line_from_axis_to_point( self, index: int, - point: Sequence[float], + point: Point3DLike, line_config: dict | None = ..., color: ParsableManimColor | None = ..., stroke_width: float = ..., @@ -463,7 +470,7 @@ def get_line_from_axis_to_point( def get_line_from_axis_to_point( self, index: int, - point: Sequence[float], + point: Point3DLike, line_func: type[LineType], line_config: dict | None = ..., color: ParsableManimColor | None = ..., @@ -518,7 +525,7 @@ def get_line_from_axis_to_point( # type: ignore[no-untyped-def] line = line_func(axis.get_projection(point), point, **line_config) return line - def get_vertical_line(self, point: Sequence[float], **kwargs: Any) -> Line: + def get_vertical_line(self, point: Point3DLike, **kwargs: Any) -> Line: """A vertical line from the x-axis to a given point in the scene. Parameters @@ -552,7 +559,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(0, point, **kwargs) - def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line: + def get_horizontal_line(self, point: Point3DLike, **kwargs: Any) -> Line: """A horizontal line from the y-axis to a given point in the scene. Parameters @@ -584,7 +591,7 @@ def construct(self): """ return self.get_line_from_axis_to_point(1, point, **kwargs) - def get_lines_to_point(self, point: Sequence[float], **kwargs) -> VGroup: + def get_lines_to_point(self, point: Point3DLike, **kwargs: Any) -> VGroup: """Generate both horizontal and vertical lines from the axis to a point. Parameters @@ -630,7 +637,9 @@ def plot( function: Callable[[float], float], x_range: Sequence[float] | None = None, use_vectorized: bool = False, - colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None, + colorscale: Iterable[ParsableManimColor] + | Iterable[ParsableManimColor, float] + | None = None, colorscale_axis: int = 1, **kwargs: Any, ) -> ParametricFunction: @@ -1093,7 +1102,7 @@ def i2gp(self, x: float, graph: ParametricFunction) -> np.ndarray: def get_graph_label( self, graph: ParametricFunction, - label: float | str | Mobject = "f(x)", + label: float | str | VMobject = "f(x)", x_val: float | None = None, direction: Sequence[float] = RIGHT, buff: float = MED_SMALL_BUFF, @@ -1150,7 +1159,7 @@ def construct(self): dot_config = {} if color is None: color = graph.get_color() - label = self.x_axis._create_label_tex(label).set_color(color) + label_object: Mobject = self.x_axis._create_label_tex(label).set_color(color) if x_val is None: # Search from right to left @@ -1161,14 +1170,14 @@ def construct(self): else: point = self.input_to_graph_point(x_val, graph) - label.next_to(point, direction, buff=buff) - label.shift_onto_screen() + label_object.next_to(point, direction, buff=buff) + label_object.shift_onto_screen() if dot: dot = Dot(point=point, **dot_config) - label.add(dot) - label.dot = dot - return label + label_object.add(dot) + label_object.dot = dot + return label_object # calculus @@ -1176,14 +1185,14 @@ def get_riemann_rectangles( self, graph: ParametricFunction, x_range: Sequence[float] | None = None, - dx: float | None = 0.1, + dx: float = 0.1, input_sample_type: str = "left", stroke_width: float = 1, stroke_color: ParsableManimColor = BLACK, fill_opacity: float = 1, color: Iterable[ParsableManimColor] | ParsableManimColor = (BLUE, GREEN), show_signed_area: bool = True, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, blend: bool = False, width_scale_factor: float = 1.001, ) -> VGroup: @@ -1277,16 +1286,16 @@ def construct(self): x_range = [*x_range[:2], dx] rectangles = VGroup() - x_range = np.arange(*x_range) + x_range_array = np.arange(*x_range) if isinstance(color, (list, tuple)): color = [ManimColor(c) for c in color] else: color = [ManimColor(color)] - colors = color_gradient(color, len(x_range)) + colors = color_gradient(color, len(x_range_array)) - for x, color in zip(x_range, colors): + for x, color in zip(x_range_array, colors): if input_sample_type == "left": sample_input = x elif input_sample_type == "right": @@ -1341,7 +1350,7 @@ def get_area( x_range: tuple[float, float] | None = None, color: ParsableManimColor | Iterable[ParsableManimColor] = (BLUE, GREEN), opacity: float = 0.3, - bounded_graph: ParametricFunction = None, + bounded_graph: ParametricFunction | None = None, **kwargs: Any, ) -> Polygon: """Returns a :class:`~.Polygon` representing the area under the graph passed. @@ -1485,10 +1494,14 @@ def slope_of_tangent( ax.slope_of_tangent(x=-2, graph=curve) # -3.5000000259052038 """ - return np.tan(self.angle_of_tangent(x, graph, **kwargs)) + val: float = np.tan(self.angle_of_tangent(x, graph, **kwargs)) + return val def plot_derivative_graph( - self, graph: ParametricFunction, color: ParsableManimColor = GREEN, **kwargs + self, + graph: ParametricFunction, + color: ParsableManimColor = GREEN, + **kwargs: Any, ) -> ParametricFunction: """Returns the curve of the derivative of the passed graph. @@ -1526,7 +1539,7 @@ def construct(self): self.add(ax, curves, labels) """ - def deriv(x): + def deriv(x: float) -> float: return self.slope_of_tangent(x, graph) return self.plot(deriv, color=color, **kwargs) @@ -1587,7 +1600,7 @@ def antideriv(x): x_vals = np.linspace(0, x, samples, axis=1 if use_vectorized else 0) f_vec = np.vectorize(graph.underlying_function) y_vals = f_vec(x_vals) - return np.trapz(y_vals, x_vals) + y_intercept + return np.trapezoid(y_vals, x_vals) + y_intercept return self.plot(antideriv, use_vectorized=use_vectorized, **kwargs) @@ -1843,14 +1856,17 @@ def construct(self): return T_label_group - def __matmul__(self, coord: Point3DLike | Mobject): + def __matmul__(self, coord: Point3DLike | Mobject) -> Point3DLike: if isinstance(coord, Mobject): coord = coord.get_center() return self.coords_to_point(*coord) - def __rmatmul__(self, point: Point3DLike): + def __rmatmul__(self, point: Point3DLike) -> Point3DLike: return self.point_to_coords(point) + @staticmethod + def _origin_shift(axis_range: Sequence[float]) -> float: ... + class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): """Creates a set of axes. @@ -1918,7 +1934,7 @@ def __init__( y_axis_config: dict | None = None, tips: bool = True, **kwargs: Any, - ) -> None: + ): VGroup.__init__(self, **kwargs) CoordinateSystem.__init__(self, x_range, y_range, x_length, y_length) @@ -1926,8 +1942,11 @@ def __init__( "include_tip": tips, "numbers_to_exclude": [0], } - self.x_axis_config = {} - self.y_axis_config = {"rotation": 90 * DEGREES, "label_direction": LEFT} + self.x_axis_config: dict[str, Any] = {} + self.y_axis_config: dict[str, Any] = { + "rotation": 90 * DEGREES, + "label_direction": LEFT, + } self._update_default_configs( (self.axis_config, self.x_axis_config, self.y_axis_config), @@ -2414,14 +2433,14 @@ def __init__( y_length: float | None = config.frame_height + 2.5, z_length: float | None = config.frame_height - 1.5, z_axis_config: dict[str, Any] | None = None, - z_normal: Vector3D = DOWN, + z_normal: Vector3DLike = DOWN, num_axis_pieces: int = 20, - light_source: Sequence[float] = 9 * DOWN + 7 * LEFT + 10 * OUT, + light_source: Point3DLike = 9 * DOWN + 7 * LEFT + 10 * OUT, # opengl stuff (?) - depth=None, - gloss=0.5, + depth: Any = None, + gloss: float = 0.5, **kwargs: dict[str, Any], - ) -> None: + ): super().__init__( x_range=x_range, x_length=x_length, @@ -2433,7 +2452,7 @@ def __init__( self.z_range = z_range self.z_length = z_length - self.z_axis_config = {} + self.z_axis_config: dict[str, Any] = {} self._update_default_configs((self.z_axis_config,), (z_axis_config,)) self.z_axis_config = merge_dicts_recursively( self.axis_config, @@ -2443,7 +2462,7 @@ def __init__( self.z_normal = z_normal self.num_axis_pieces = num_axis_pieces - self.light_source = light_source + self.light_source = np.array(light_source) self.dimension = 3 @@ -2500,13 +2519,13 @@ def make_func(axis): def get_y_axis_label( self, - label: float | str | Mobject, - edge: Sequence[float] = UR, - direction: Sequence[float] = UR, + label: float | str | VMobject, + edge: Vector3DLike = UR, + direction: Vector3DLike = UR, buff: float = SMALL_BUFF, rotation: float = PI / 2, - rotation_axis: Vector3D = OUT, - **kwargs, + rotation_axis: Vector3DLike = OUT, + **kwargs: dict[str, Any], ) -> Mobject: """Generate a y-axis label. @@ -2550,12 +2569,12 @@ def construct(self): def get_z_axis_label( self, - label: float | str | Mobject, - edge: Vector3D = OUT, - direction: Vector3D = RIGHT, + label: float | str | VMobject, + edge: Vector3DLike = OUT, + direction: Vector3DLike = RIGHT, buff: float = SMALL_BUFF, rotation: float = PI / 2, - rotation_axis: Vector3D = RIGHT, + rotation_axis: Vector3DLike = RIGHT, **kwargs: Any, ) -> Mobject: """Generate a z-axis label. @@ -2600,9 +2619,9 @@ def construct(self): def get_axis_labels( self, - x_label: float | str | Mobject = "x", - y_label: float | str | Mobject = "y", - z_label: float | str | Mobject = "z", + x_label: float | str | VMobject = "x", + y_label: float | str | VMobject = "y", + z_label: float | str | VMobject = "z", ) -> VGroup: """Defines labels for the x_axis and y_axis of the graph. @@ -2741,7 +2760,7 @@ def __init__( **kwargs: dict[str, Any], ): # configs - self.axis_config = { + self.axis_config: dict[str, Any] = { "stroke_width": 2, "include_ticks": False, "include_tip": False, @@ -2749,8 +2768,8 @@ def __init__( "label_direction": DR, "font_size": 24, } - self.y_axis_config = {"label_direction": DR} - self.background_line_style = { + self.y_axis_config: dict[str, Any] = {"label_direction": DR} + self.background_line_style: dict[str, Any] = { "stroke_color": BLUE_D, "stroke_width": 2, "stroke_opacity": 1, @@ -2997,7 +3016,7 @@ def __init__( size: float | None = None, radius_step: float = 1, azimuth_step: float | None = None, - azimuth_units: str | None = "PI radians", + azimuth_units: str = "PI radians", azimuth_compact_fraction: bool = True, azimuth_offset: float = 0, azimuth_direction: str = "CCW", @@ -3009,7 +3028,7 @@ def __init__( faded_line_ratio: int = 1, make_smooth_after_applying_functions: bool = True, **kwargs: Any, - ) -> None: + ): # error catching if azimuth_units in ["PI radians", "TAU radians", "degrees", "gradians", None]: self.azimuth_units = azimuth_units @@ -3130,11 +3149,11 @@ def _get_lines(self) -> tuple[VGroup, VGroup]: unit_vector = self.x_axis.get_unit_vector()[0] for k, x in enumerate(rinput): - new_line = Circle(radius=x * unit_vector) + new_circle = Circle(radius=x * unit_vector) if k % ratio_faded_lines == 0: - alines1.add(new_line) + alines1.add(new_circle) else: - alines2.add(new_line) + alines2.add(new_circle) line = Line(center, self.get_x_axis().get_end()) @@ -3292,7 +3311,9 @@ def add_coordinates( self.add(self.get_coordinate_labels(r_values, a_values)) return self - def get_radian_label(self, number, font_size: float = 24, **kwargs: Any) -> MathTex: + def get_radian_label( + self, number: float, font_size: float = 24, **kwargs: Any + ) -> MathTex: constant_label = {"PI radians": r"\pi", "TAU radians": r"\tau"}[ self.azimuth_units ] @@ -3361,7 +3382,7 @@ def construct(self): """ - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs: Any): super().__init__( **kwargs, ) diff --git a/manim/mobject/graphing/functions.py b/manim/mobject/graphing/functions.py index 83c48b1092..0ce39267a3 100644 --- a/manim/mobject/graphing/functions.py +++ b/manim/mobject/graphing/functions.py @@ -5,8 +5,8 @@ __all__ = ["ParametricFunction", "FunctionGraph", "ImplicitFunction"] -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np from isosurfaces import plot_isoline @@ -17,9 +17,12 @@ from manim.mobject.types.vectorized_mobject import VMobject if TYPE_CHECKING: + from typing import Any + from typing_extensions import Self from manim.typing import Point3D, Point3DLike + from manim.utils.color import ParsableManimColor from manim.utils.color import YELLOW @@ -111,7 +114,7 @@ def __init__( discontinuities: Iterable[float] | None = None, use_smoothing: bool = True, use_vectorized: bool = False, - **kwargs, + **kwargs: Any, ): def internal_parametric_function(t: float) -> Point3D: """Wrap ``function``'s output inside a NumPy array.""" @@ -143,13 +146,13 @@ def generate_points(self) -> Self: lambda t: self.t_min <= t <= self.t_max, self.discontinuities, ) - discontinuities = np.array(list(discontinuities)) + discontinuities_array = np.array(list(discontinuities)) boundary_times = np.array( [ self.t_min, self.t_max, - *(discontinuities - self.dt), - *(discontinuities + self.dt), + *(discontinuities_array - self.dt), + *(discontinuities_array + self.dt), ], ) boundary_times.sort() @@ -179,7 +182,8 @@ def generate_points(self) -> Self: self.make_smooth() return self - init_points = generate_points + def init_points(self) -> None: + self.generate_points() class FunctionGraph(ParametricFunction): @@ -211,19 +215,27 @@ def construct(self): self.add(cos_func, sin_func_1, sin_func_2) """ - def __init__(self, function, x_range=None, color=YELLOW, **kwargs): + def __init__( + self, + function: Callable[[float], Any], + x_range: tuple[float, float] | tuple[float, float, float] | None = None, + color: ParsableManimColor = YELLOW, + **kwargs: Any, + ) -> None: if x_range is None: - x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]]) + x_range = (-config["frame_x_radius"], config["frame_x_radius"]) self.x_range = x_range - self.parametric_function = lambda t: np.array([t, function(t), 0]) - self.function = function + self.parametric_function: Callable[[float], Point3D] = lambda t: np.array( + [t, function(t), 0] + ) + self.function = function # type: ignore[assignment] super().__init__(self.parametric_function, self.x_range, color=color, **kwargs) - def get_function(self): + def get_function(self) -> Callable[[float], Any]: return self.function - def get_point_from_function(self, x): + def get_point_from_function(self, x: float) -> Point3D: return self.parametric_function(x) @@ -236,7 +248,7 @@ def __init__( min_depth: int = 5, max_quads: int = 1500, use_smoothing: bool = True, - **kwargs, + **kwargs: Any, ): """An implicit function. @@ -295,7 +307,7 @@ def construct(self): super().__init__(**kwargs) - def generate_points(self): + def generate_points(self) -> Self: p_min, p_max = ( np.array([self.x_range[0], self.y_range[0]]), np.array([self.x_range[1], self.y_range[1]]), @@ -317,4 +329,5 @@ def generate_points(self): self.make_smooth() return self - init_points = generate_points + def init_points(self) -> None: + self.generate_points() diff --git a/manim/mobject/graphing/number_line.py b/manim/mobject/graphing/number_line.py index 017fac5bcb..66772e74d4 100644 --- a/manim/mobject/graphing/number_line.py +++ b/manim/mobject/graphing/number_line.py @@ -8,12 +8,16 @@ __all__ = ["NumberLine", "UnitInterval"] -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Any + + from typing_extensions import Self + from manim.mobject.geometry.tips import ArrowTip - from manim.typing import Point3DLike + from manim.typing import Point3D, Point3DLike, Vector3D import numpy as np @@ -21,8 +25,9 @@ from manim.constants import * from manim.mobject.geometry.line import Line from manim.mobject.graphing.scale import LinearBase, _ScaleBase -from manim.mobject.text.numbers import DecimalNumber +from manim.mobject.text.numbers import DecimalNumber, Integer from manim.mobject.text.tex_mobject import MathTex, Tex +from manim.mobject.text.text_mobject import Text from manim.mobject.types.vectorized_mobject import VGroup, VMobject from manim.utils.bezier import interpolate from manim.utils.config_ops import merge_dicts_recursively @@ -157,14 +162,14 @@ def __init__( # numbers/labels include_numbers: bool = False, font_size: float = 36, - label_direction: Sequence[float] = DOWN, - label_constructor: VMobject = MathTex, + label_direction: Point3DLike = DOWN, + label_constructor: type[MathTex] = MathTex, scaling: _ScaleBase = LinearBase(), line_to_number_buff: float = MED_SMALL_BUFF, decimal_number_config: dict | None = None, numbers_to_exclude: Iterable[float] | None = None, numbers_to_include: Iterable[float] | None = None, - **kwargs, + **kwargs: Any, ): # avoid mutable arguments in defaults if numbers_to_exclude is None: @@ -189,6 +194,9 @@ def __init__( # turn into a NumPy array to scale by just applying the function self.x_range = np.array(x_range, dtype=float) + self.x_min: float + self.x_max: float + self.x_step: float self.x_min, self.x_max, self.x_step = scaling.function(self.x_range) self.length = length self.unit_size = unit_size @@ -246,16 +254,16 @@ def __init__( if self.scaling.custom_labels: tick_range = self.get_tick_range() + custom_labels = self.scaling.get_custom_labels( + tick_range, + unit_decimal_places=decimal_number_config["num_decimal_places"], + ) + self.add_labels( dict( zip( tick_range, - self.scaling.get_custom_labels( - tick_range, - unit_decimal_places=decimal_number_config[ - "num_decimal_places" - ], - ), + custom_labels, ) ), ) @@ -267,21 +275,25 @@ def __init__( font_size=self.font_size, ) - def rotate_about_zero(self, angle: float, axis: Sequence[float] = OUT, **kwargs): + def rotate_about_zero( + self, angle: float, axis: Vector3D = OUT, **kwargs: Any + ) -> Self: return self.rotate_about_number(0, angle, axis, **kwargs) def rotate_about_number( - self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs - ): + self, number: float, angle: float, axis: Vector3D = OUT, **kwargs: Any + ) -> Self: return self.rotate(angle, axis, about_point=self.n2p(number), **kwargs) - def add_ticks(self): + def add_ticks(self) -> None: """Adds ticks to the number line. Ticks can be accessed after creation via ``self.ticks``. """ ticks = VGroup() elongated_tick_size = self.tick_size * self.longer_tick_multiple - elongated_tick_offsets = self.numbers_with_elongated_ticks - self.x_min + elongated_tick_offsets = ( + np.array(self.numbers_with_elongated_ticks) - self.x_min + ) for x in self.get_tick_range(): size = self.tick_size if np.any(np.isclose(x - self.x_min, elongated_tick_offsets)): @@ -413,31 +425,34 @@ def point_to_number(self, point: Sequence[float]) -> float: point = np.asarray(point) start, end = self.get_start_and_end() unit_vect = normalize(end - start) - proportion = np.dot(point - start, unit_vect) / np.dot(end - start, unit_vect) + proportion: float = np.dot(point - start, unit_vect) / np.dot( + end - start, unit_vect + ) return interpolate(self.x_min, self.x_max, proportion) - def n2p(self, number: float | np.ndarray) -> np.ndarray: + def n2p(self, number: float | np.ndarray) -> Point3D: """Abbreviation for :meth:`~.NumberLine.number_to_point`.""" return self.number_to_point(number) - def p2n(self, point: Sequence[float]) -> float: + def p2n(self, point: Point3DLike) -> float: """Abbreviation for :meth:`~.NumberLine.point_to_number`.""" return self.point_to_number(point) def get_unit_size(self) -> float: - return self.get_length() / (self.x_range[1] - self.x_range[0]) + val: float = self.get_length() / (self.x_range[1] - self.x_range[0]) + return val - def get_unit_vector(self) -> np.ndarray: + def get_unit_vector(self) -> Vector3D: return super().get_unit_vector() * self.unit_size def get_number_mobject( self, x: float, - direction: Sequence[float] | None = None, + direction: Vector3D | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **number_config, + label_constructor: type[MathTex] | None = None, + **number_config: dict[str, Any], ) -> VMobject: """Generates a positioned :class:`~.DecimalNumber` mobject generated according to ``label_constructor``. @@ -462,7 +477,7 @@ def get_number_mobject( :class:`~.DecimalNumber` The positioned mobject. """ - number_config = merge_dicts_recursively( + number_config_merged = merge_dicts_recursively( self.decimal_number_config, number_config, ) @@ -476,7 +491,10 @@ def get_number_mobject( label_constructor = self.label_constructor num_mob = DecimalNumber( - x, font_size=font_size, mob_class=label_constructor, **number_config + x, + font_size=font_size, + mob_class=label_constructor, + **number_config_merged, ) num_mob.next_to(self.number_to_point(x), direction=direction, buff=buff) @@ -485,7 +503,7 @@ def get_number_mobject( num_mob.shift(num_mob[0].width * LEFT / 2) return num_mob - def get_number_mobjects(self, *numbers, **kwargs) -> VGroup: + def get_number_mobjects(self, *numbers: float, **kwargs: Any) -> VGroup: if len(numbers) == 0: numbers = self.default_numbers_to_display() return VGroup([self.get_number_mobject(number, **kwargs) for number in numbers]) @@ -498,9 +516,9 @@ def add_numbers( x_values: Iterable[float] | None = None, excluding: Iterable[float] | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - **kwargs, - ): + label_constructor: type[MathTex] | None = None, + **kwargs: Any, + ) -> Self: """Adds :class:`~.DecimalNumber` mobjects representing their position at each tick of the number line. The numbers can be accessed after creation via ``self.numbers``. @@ -551,11 +569,11 @@ def add_numbers( def add_labels( self, dict_values: dict[float, str | float | VMobject], - direction: Sequence[float] = None, + direction: Point3DLike | None = None, buff: float | None = None, font_size: float | None = None, - label_constructor: VMobject | None = None, - ): + label_constructor: type[MathTex] | None = None, + ) -> Self: """Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``. The labels can be accessed after creation via ``self.labels``. @@ -598,6 +616,7 @@ def add_labels( label = self._create_label_tex(label, label_constructor) if hasattr(label, "font_size"): + assert isinstance(label, (MathTex, Tex, Text, Integer)), label label.font_size = font_size else: raise AttributeError(f"{label} is not compatible with add_labels.") @@ -612,7 +631,7 @@ def _create_label_tex( self, label_tex: str | float | VMobject, label_constructor: Callable | None = None, - **kwargs, + **kwargs: Any, ) -> VMobject: """Checks if the label is a :class:`~.VMobject`, otherwise, creates a label by passing ``label_tex`` to ``label_constructor``. @@ -633,24 +652,25 @@ def _create_label_tex( :class:`~.VMobject` The label. """ - if label_constructor is None: - label_constructor = self.label_constructor if isinstance(label_tex, (VMobject, OpenGLVMobject)): return label_tex - else: + if label_constructor is None: + label_constructor = self.label_constructor + if isinstance(label_tex, str): return label_constructor(label_tex, **kwargs) + return label_constructor(str(label_tex), **kwargs) @staticmethod - def _decimal_places_from_step(step) -> int: - step = str(step) - if "." not in step: + def _decimal_places_from_step(step: float) -> int: + step_str = str(step) + if "." not in step_str: return 0 - return len(step.split(".")[-1]) + return len(step_str.split(".")[-1]) - def __matmul__(self, other: float): + def __matmul__(self, other: float) -> Point3D: return self.n2p(other) - def __rmatmul__(self, other: Point3DLike | Mobject): + def __rmatmul__(self, other: Point3DLike | Mobject) -> float: if isinstance(other, Mobject): other = other.get_center() return self.p2n(other) @@ -659,10 +679,10 @@ def __rmatmul__(self, other: Point3DLike | Mobject): class UnitInterval(NumberLine): def __init__( self, - unit_size=10, - numbers_with_elongated_ticks=None, - decimal_number_config=None, - **kwargs, + unit_size: float = 10, + numbers_with_elongated_ticks: list[float] | None = None, + decimal_number_config: dict[str, Any] | None = None, + **kwargs: Any, ): numbers_with_elongated_ticks = ( [0, 1] diff --git a/manim/mobject/graphing/probability.py b/manim/mobject/graphing/probability.py index 24134c0a7a..9b88179bdc 100644 --- a/manim/mobject/graphing/probability.py +++ b/manim/mobject/graphing/probability.py @@ -6,6 +6,7 @@ from collections.abc import Iterable, MutableSequence, Sequence +from typing import Any import numpy as np @@ -13,11 +14,11 @@ from manim.constants import * from manim.mobject.geometry.polygram import Rectangle from manim.mobject.graphing.coordinate_systems import Axes -from manim.mobject.mobject import Mobject -from manim.mobject.opengl.opengl_mobject import OpenGLMobject +from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject from manim.mobject.svg.brace import Brace from manim.mobject.text.tex_mobject import MathTex, Tex from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.typing import Vector3D from manim.utils.color import ( BLUE_E, DARK_GREY, @@ -54,13 +55,13 @@ def construct(self): def __init__( self, - height=3, - width=3, - fill_color=DARK_GREY, - fill_opacity=1, - stroke_width=0.5, - stroke_color=LIGHT_GREY, - default_label_scale_val=1, + height: float = 3, + width: float = 3, + fill_color: ParsableManimColor = DARK_GREY, + fill_opacity: float = 1, + stroke_width: float = 0.5, + stroke_color: ParsableManimColor = LIGHT_GREY, + default_label_scale_val: float = 1, ): super().__init__( height=height, @@ -72,7 +73,9 @@ def __init__( ) self.default_label_scale_val = default_label_scale_val - def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): + def add_title( + self, title: str = "Sample space", buff: float = MED_SMALL_BUFF + ) -> None: # TODO, should this really exist in SampleSpaceScene title_mob = Tex(title) if title_mob.width > self.width: @@ -81,23 +84,32 @@ def add_title(self, title="Sample space", buff=MED_SMALL_BUFF): self.title = title_mob self.add(title_mob) - def add_label(self, label): + def add_label(self, label: str) -> None: self.label = label - def complete_p_list(self, p_list): - new_p_list = list(tuplify(p_list)) + def complete_p_list(self, p_list: float | Iterable[float]) -> list[float]: + p_list_tuplified: tuple[float] = tuplify(p_list) + new_p_list = list(p_list_tuplified) remainder = 1.0 - sum(new_p_list) if abs(remainder) > EPSILON: new_p_list.append(remainder) return new_p_list - def get_division_along_dimension(self, p_list, dim, colors, vect): - p_list = self.complete_p_list(p_list) - colors = color_gradient(colors, len(p_list)) + def get_division_along_dimension( + self, + p_list: float | Iterable[float], + dim: int, + colors: Sequence[ParsableManimColor], + vect: Vector3D, + ) -> VGroup: + p_list_complete = self.complete_p_list(p_list) + colors_in_gradient = color_gradient(colors, len(p_list_complete)) + + assert isinstance(colors_in_gradient, list) last_point = self.get_edge_center(-vect) parts = VGroup() - for factor, color in zip(p_list, colors): + for factor, color in zip(p_list_complete, colors_in_gradient): part = SampleSpace() part.set_fill(color, 1) part.replace(self, stretch=True) @@ -107,33 +119,43 @@ def get_division_along_dimension(self, p_list, dim, colors, vect): parts.add(part) return parts - def get_horizontal_division(self, p_list, colors=[GREEN_E, BLUE_E], vect=DOWN): + def get_horizontal_division( + self, + p_list: float | Iterable[float], + colors: Sequence[ParsableManimColor] = [GREEN_E, BLUE_E], + vect: Vector3D = DOWN, + ) -> VGroup: return self.get_division_along_dimension(p_list, 1, colors, vect) - def get_vertical_division(self, p_list, colors=[MAROON_B, YELLOW], vect=RIGHT): + def get_vertical_division( + self, + p_list: float | Iterable[float], + colors: Sequence[ParsableManimColor] = [MAROON_B, YELLOW], + vect: Vector3D = RIGHT, + ) -> VGroup: return self.get_division_along_dimension(p_list, 0, colors, vect) - def divide_horizontally(self, *args, **kwargs): + def divide_horizontally(self, *args: Any, **kwargs: Any) -> None: self.horizontal_parts = self.get_horizontal_division(*args, **kwargs) self.add(self.horizontal_parts) - def divide_vertically(self, *args, **kwargs): + def divide_vertically(self, *args: Any, **kwargs: Any) -> None: self.vertical_parts = self.get_vertical_division(*args, **kwargs) self.add(self.vertical_parts) def get_subdivision_braces_and_labels( self, - parts, - labels, - direction, - buff=SMALL_BUFF, - min_num_quads=1, - ): + parts: VGroup, + labels: list[str | VMobject | OpenGLVMobject], + direction: Vector3D, + buff: float = SMALL_BUFF, + min_num_quads: int = 1, + ) -> VGroup: label_mobs = VGroup() braces = VGroup() for label, part in zip(labels, parts): brace = Brace(part, direction, min_num_quads=min_num_quads, buff=buff) - if isinstance(label, (Mobject, OpenGLMobject)): + if isinstance(label, (VMobject, OpenGLVMobject)): label_mob = label else: label_mob = MathTex(label) @@ -141,34 +163,44 @@ def get_subdivision_braces_and_labels( label_mob.next_to(brace, direction, buff) braces.add(brace) + assert isinstance(label_mob, VMobject) label_mobs.add(label_mob) - parts.braces = braces - parts.labels = label_mobs - parts.label_kwargs = { + parts.braces = braces # type: ignore[attr-defined] + parts.labels = label_mobs # type: ignore[attr-defined] + parts.label_kwargs = { # type: ignore[attr-defined] "labels": label_mobs.copy(), "direction": direction, "buff": buff, } - return VGroup(parts.braces, parts.labels) + return VGroup(parts.braces, parts.labels) # type: ignore[arg-type] - def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs): + def get_side_braces_and_labels( + self, + labels: list[str | VMobject | OpenGLVMobject], + direction: Vector3D = LEFT, + **kwargs: Any, + ) -> VGroup: assert hasattr(self, "horizontal_parts") parts = self.horizontal_parts return self.get_subdivision_braces_and_labels( parts, labels, direction, **kwargs ) - def get_top_braces_and_labels(self, labels, **kwargs): + def get_top_braces_and_labels( + self, labels: list[str | VMobject | OpenGLVMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs) - def get_bottom_braces_and_labels(self, labels, **kwargs): + def get_bottom_braces_and_labels( + self, labels: list[str | VMobject | OpenGLVMobject], **kwargs: Any + ) -> VGroup: assert hasattr(self, "vertical_parts") parts = self.vertical_parts return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs) - def add_braces_and_labels(self): + def add_braces_and_labels(self) -> None: for attr in "horizontal_parts", "vertical_parts": if not hasattr(self, attr): continue @@ -177,11 +209,13 @@ def add_braces_and_labels(self): if hasattr(parts, subattr): self.add(getattr(parts, subattr)) - def __getitem__(self, index): + def __getitem__(self, index: int) -> SampleSpace: if hasattr(self, "horizontal_parts"): - return self.horizontal_parts[index] + val: SampleSpace = self.horizontal_parts[index] + return val elif hasattr(self, "vertical_parts"): - return self.vertical_parts[index] + val = self.vertical_parts[index] + return val return self.split()[index] @@ -253,7 +287,7 @@ def __init__( bar_width: float = 0.6, bar_fill_opacity: float = 0.7, bar_stroke_width: float = 3, - **kwargs, + **kwargs: Any, ): if isinstance(bar_colors, str): logger.warning( @@ -311,7 +345,7 @@ def __init__( self.y_axis.add_numbers() - def _update_colors(self): + def _update_colors(self) -> None: """Initialize the colors of the bars of the chart. Sets the color of ``self.bars`` via ``self.bar_colors``. @@ -321,13 +355,14 @@ def _update_colors(self): """ self.bars.set_color_by_gradient(*self.bar_colors) - def _add_x_axis_labels(self): + def _add_x_axis_labels(self) -> None: """Essentially :meth`:~.NumberLine.add_labels`, but differs in that the direction of the label with respect to the x_axis changes to UP or DOWN depending on the value. UP for negative values and DOWN for positive values. """ + assert isinstance(self.bar_names, list) val_range = np.arange( 0.5, len(self.bar_names), 1 ) # 0.5 shifted so that labels are centered, not on ticks @@ -338,7 +373,7 @@ def _add_x_axis_labels(self): # to accommodate negative bars, the label may need to be # below or above the x_axis depending on the value of the bar direction = UP if self.values[i] < 0 else DOWN - bar_name_label = self.x_axis.label_constructor(bar_name) + bar_name_label: MathTex = self.x_axis.label_constructor(bar_name) bar_name_label.font_size = self.x_axis.font_size bar_name_label.next_to( @@ -398,8 +433,8 @@ def get_bar_labels( color: ParsableManimColor | None = None, font_size: float = 24, buff: float = MED_SMALL_BUFF, - label_constructor: type[VMobject] = Tex, - ): + label_constructor: type[MathTex] = Tex, + ) -> VGroup: """Annotates each bar with its corresponding value. Use ``self.bar_labels`` to access the labels after creation. @@ -431,7 +466,7 @@ def construct(self): """ bar_labels = VGroup() for bar, value in zip(self.bars, self.values): - bar_lbl = label_constructor(str(value)) + bar_lbl: MathTex = label_constructor(str(value)) if color is None: bar_lbl.set_color(bar.get_fill_color()) @@ -446,7 +481,9 @@ def construct(self): return bar_labels - def change_bar_values(self, values: Iterable[float], update_colors: bool = True): + def change_bar_values( + self, values: Iterable[float], update_colors: bool = True + ) -> None: """Updates the height of the bars of the chart. Parameters @@ -512,4 +549,4 @@ def construct(self): if update_colors: self._update_colors() - self.values[: len(values)] = values + self.values[: len(list(values))] = values diff --git a/manim/mobject/graphing/scale.py b/manim/mobject/graphing/scale.py index 78ffa2308b..b6ed2b4ce3 100644 --- a/manim/mobject/graphing/scale.py +++ b/manim/mobject/graphing/scale.py @@ -2,7 +2,7 @@ import math from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload import numpy as np @@ -11,7 +11,9 @@ from manim.mobject.text.numbers import Integer if TYPE_CHECKING: - from manim.mobject.mobject import Mobject + from typing import Callable + + from manim.mobject.types.vectorized_mobject import VMobject class _ScaleBase: @@ -26,6 +28,12 @@ class _ScaleBase: def __init__(self, custom_labels: bool = False): self.custom_labels = custom_labels + @overload + def function(self, value: float) -> float: ... + + @overload + def function(self, value: np.ndarray) -> np.ndarray: ... + def function(self, value: float) -> float: """The function that will be used to scale the values. @@ -59,7 +67,8 @@ def inverse_function(self, value: float) -> float: def get_custom_labels( self, val_range: Iterable[float], - ) -> Iterable[Mobject]: + **kw_args: Any, + ) -> Iterable[VMobject]: """Custom instructions for generating labels along an axis. Parameters @@ -139,15 +148,19 @@ def __init__(self, base: float = 10, custom_labels: bool = True): def function(self, value: float) -> float: """Scales the value to fit it to a logarithmic scale.``self.function(5)==10**5``""" - return self.base**value + return_value: float = self.base**value + return return_value def inverse_function(self, value: float) -> float: """Inverse of ``function``. The value must be greater than 0""" if isinstance(value, np.ndarray): condition = value.any() <= 0 - def func(value, base): - return np.log(value) / np.log(base) + func: Callable[[float, float], float] + + def func(value: float, base: float) -> float: + return_value: float = np.log(value) / np.log(base) + return return_value else: condition = value <= 0 func = math.log @@ -163,8 +176,8 @@ def get_custom_labels( self, val_range: Iterable[float], unit_decimal_places: int = 0, - **base_config: dict[str, Any], - ) -> list[Mobject]: + **base_config: Any, + ) -> list[Integer]: """Produces custom :class:`~.Integer` labels in the form of ``10^2``. Parameters @@ -177,7 +190,7 @@ def get_custom_labels( Additional arguments to be passed to :class:`~.Integer`. """ # uses `format` syntax to control the number of decimal places. - tex_labels = [ + tex_labels: list[Integer] = [ Integer( self.base, unit="^{%s}" % (f"{self.inverse_function(i):.{unit_decimal_places}f}"), # noqa: UP031 diff --git a/manim/mobject/logo.py b/manim/mobject/logo.py index 6242a4c645..3cdab467f2 100644 --- a/manim/mobject/logo.py +++ b/manim/mobject/logo.py @@ -4,11 +4,15 @@ __all__ = ["ManimBanner"] +from typing import Any + import svgelements as se from manim.animation.updaters.update import UpdateFromAlphaFunc from manim.mobject.geometry.arc import Circle from manim.mobject.geometry.polygram import Square, Triangle +from manim.mobject.mobject import Mobject +from manim.typing import Vector3D from .. import constants as cst from ..animation.animation import override_animation @@ -146,7 +150,7 @@ def __init__(self, dark_theme: bool = True): m_height_over_anim_height = 0.75748 self.font_color = "#ece6e2" if dark_theme else "#343434" - self.scale_factor = 1 + self.scale_factor = 1.0 self.M = VMobjectFromSVGPath(MANIM_SVG_PATHS[0]).flip(cst.RIGHT).center() self.M.set(stroke_width=0).scale( @@ -180,7 +184,7 @@ def __init__(self, dark_theme: bool = True): # and thus not yet added to the submobjects of self. self.anim = anim - def scale(self, scale_factor: float, **kwargs) -> ManimBanner: + def scale(self, scale_factor: float, **kwargs: Any) -> ManimBanner: """Scale the banner by the specified scale factor. Parameters @@ -219,7 +223,7 @@ def create(self, run_time: float = 2) -> AnimationGroup: lag_ratio=0.1, ) - def expand(self, run_time: float = 1.5, direction="center") -> Succession: + def expand(self, run_time: float = 1.5, direction: str = "center") -> Succession: """An animation that expands Manim's logo into its banner. The returned animation transforms the banner from its initial @@ -277,7 +281,7 @@ def construct(self): self.M.save_state() left_group = VGroup(self.M, self.anim, m_clone) - def shift(vector): + def shift(vector: Vector3D) -> None: self.shapes.restore() left_group.align_to(self.M.saved_state, cst.LEFT) if direction == "right": @@ -288,7 +292,7 @@ def shift(vector): elif direction == "left": left_group.shift(-vector) - def slide_and_uncover(mob, alpha): + def slide_and_uncover(mob: Mobject, alpha: float) -> None: shift(alpha * (m_shape_offset + shape_sliding_overshoot) * cst.RIGHT) # Add letters when they are covered @@ -305,7 +309,7 @@ def slide_and_uncover(mob, alpha): mob.shapes.save_state() mob.M.save_state() - def slide_back(mob, alpha): + def slide_back(mob: Mobject, alpha: float) -> None: if alpha == 0: m_clone.set_opacity(1) m_clone.move_to(mob.anim[-1]) diff --git a/manim/mobject/matrix.py b/manim/mobject/matrix.py index 673aba1877..36513e4e1d 100644 --- a/manim/mobject/matrix.py +++ b/manim/mobject/matrix.py @@ -40,9 +40,11 @@ def construct(self): import itertools as it -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence +from typing import Any import numpy as np +from typing_extensions import Self from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL @@ -56,7 +58,7 @@ def construct(self): # Not sure if we should keep it or not. -def matrix_to_tex_string(matrix): +def matrix_to_tex_string(matrix: np.ndarray) -> str: matrix = np.array(matrix).astype("str") if matrix.ndim == 1: matrix = matrix.reshape((matrix.size, 1)) @@ -67,7 +69,7 @@ def matrix_to_tex_string(matrix): return prefix + " \\\\ ".join(rows) + suffix -def matrix_to_mobject(matrix): +def matrix_to_mobject(matrix: np.ndarray) -> MathTex: return MathTex(matrix_to_tex_string(matrix)) @@ -170,14 +172,14 @@ def __init__( bracket_v_buff: float = MED_SMALL_BUFF, add_background_rectangles_to_entries: bool = False, include_background_rectangle: bool = False, - element_to_mobject: type[MathTex] = MathTex, + element_to_mobject: type[Mobject] | Callable[..., Mobject] = MathTex, element_to_mobject_config: dict = {}, element_alignment_corner: Sequence[float] = DR, left_bracket: str = "[", right_bracket: str = "]", stretch_brackets: bool = True, bracket_config: dict = {}, - **kwargs, + **kwargs: Any, ): self.v_buff = v_buff self.h_buff = h_buff @@ -205,7 +207,7 @@ def __init__( if self.include_background_rectangle: self.add_background_rectangle() - def _matrix_to_mob_matrix(self, matrix): + def _matrix_to_mob_matrix(self, matrix: np.ndarray) -> list[list[Mobject]]: return [ [ self.element_to_mobject(item, **self.element_to_mobject_config) @@ -214,7 +216,7 @@ def _matrix_to_mob_matrix(self, matrix): for row in matrix ] - def _organize_mob_matrix(self, matrix): + def _organize_mob_matrix(self, matrix: list[list[Mobject]]) -> Self: for i, row in enumerate(matrix): for j, _ in enumerate(row): mob = matrix[i][j] @@ -224,7 +226,7 @@ def _organize_mob_matrix(self, matrix): ) return self - def _add_brackets(self, left: str = "[", right: str = "]", **kwargs): + def _add_brackets(self, left: str = "[", right: str = "]", **kwargs: Any) -> Self: """Adds the brackets to the Matrix mobject. See Latex document for various bracket types. @@ -278,13 +280,13 @@ def _add_brackets(self, left: str = "[", right: str = "]", **kwargs): self.add(l_bracket, r_bracket) return self - def get_columns(self): + def get_columns(self) -> VGroup: r"""Return columns of the matrix as VGroups. Returns -------- - List[:class:`~.VGroup`] - Each VGroup contains a column of the matrix. + :class:`~.VGroup` + The VGroup contains a nested VGroup for each column of the matrix. Examples -------- @@ -305,7 +307,7 @@ def construct(self): ) ) - def set_column_colors(self, *colors: str): + def set_column_colors(self, *colors: str) -> Self: r"""Set individual colors for each columns of the matrix. Parameters @@ -335,13 +337,13 @@ def construct(self): column.set_color(color) return self - def get_rows(self): + def get_rows(self) -> VGroup: r"""Return rows of the matrix as VGroups. Returns -------- - List[:class:`~.VGroup`] - Each VGroup contains a row of the matrix. + :class:`~.VGroup` + The VGroup contains a nested VGroup for each row of the matrix. Examples -------- @@ -357,7 +359,7 @@ def construct(self): """ return VGroup(*(VGroup(*row) for row in self.mob_matrix)) - def set_row_colors(self, *colors: str): + def set_row_colors(self, *colors: str) -> Self: r"""Set individual colors for each row of the matrix. Parameters @@ -387,7 +389,7 @@ def construct(self): row.set_color(color) return self - def add_background_to_entries(self): + def add_background_to_entries(self) -> Self: """Add a black background rectangle to the matrix, see above for an example. @@ -400,7 +402,7 @@ def add_background_to_entries(self): mob.add_background_rectangle() return self - def get_mob_matrix(self): + def get_mob_matrix(self) -> list[list[Mobject]]: """Return the underlying mob matrix mobjects. Returns @@ -410,7 +412,7 @@ def get_mob_matrix(self): """ return self.mob_matrix - def get_entries(self): + def get_entries(self) -> VGroup: """Return the individual entries of the matrix. Returns @@ -435,13 +437,13 @@ def construct(self): """ return self.elements - def get_brackets(self): + def get_brackets(self) -> VGroup: r"""Return the bracket mobjects. Returns -------- - List[:class:`~.VGroup`] - Each VGroup contains a bracket + :class:`~.VGroup` + A VGroup containing the left and right bracket. Examples -------- @@ -483,9 +485,9 @@ def construct(self): def __init__( self, matrix: Iterable, - element_to_mobject: Mobject = DecimalNumber, - element_to_mobject_config: dict[str, Mobject] = {"num_decimal_places": 1}, - **kwargs, + element_to_mobject: type[Mobject] = DecimalNumber, + element_to_mobject_config: dict[str, Any] = {"num_decimal_places": 1}, + **kwargs: Any, ): """ Will round/truncate the decimal places as per the provided config. @@ -526,7 +528,10 @@ def construct(self): """ def __init__( - self, matrix: Iterable, element_to_mobject: Mobject = Integer, **kwargs + self, + matrix: Iterable, + element_to_mobject: type[Mobject] = Integer, + **kwargs: Any, ): """ Will round if there are decimal entries in the matrix. @@ -560,7 +565,12 @@ def construct(self): self.add(m0) """ - def __init__(self, matrix, element_to_mobject=lambda m: m, **kwargs): + def __init__( + self, + matrix: Iterable, + element_to_mobject: type[Mobject] | Callable[..., Mobject] = lambda m: m, + **kwargs: Any, + ): super().__init__(matrix, element_to_mobject=element_to_mobject, **kwargs) @@ -569,7 +579,7 @@ def get_det_text( determinant: int | str | None = None, background_rect: bool = False, initial_scale_factor: float = 2, -): +) -> VGroup: r"""Helper function to create determinant. Parameters diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index f46679e75d..e313a9e718 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -14,13 +14,14 @@ import sys import types import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import partialmethod, reduce from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal import numpy as np +from manim.data_structures import MethodWithArgs from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from .. import config, logger @@ -40,8 +41,6 @@ from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix if TYPE_CHECKING: - from typing import Any, Callable, Literal - from typing_extensions import Self, TypeAlias from manim.typing import ( @@ -53,7 +52,7 @@ Point3D, Point3DLike, Point3DLike_Array, - Vector3D, + Vector3DLike, ) from ..animation.animation import Animation @@ -387,9 +386,9 @@ def construct(self): will interpolate the :class:`~.Mobject` between its points prior to ``.animate`` and its points after applying ``.animate`` to it. This may result in unexpected behavior when attempting to interpolate along paths, - or rotations. + or rotations (see :meth:`.rotate`). If you want animations to consider the points between, consider using - :class:`~.ValueTracker` with updaters instead. + :class:`~.ValueTracker` with updaters instead (see :meth:`.add_updater`). """ return _AnimationBuilder(self) @@ -1002,16 +1001,16 @@ def add_updater( class NextToUpdater(Scene): def construct(self): - def dot_position(mobject): + def update_label(mobject): mobject.set_value(dot.get_center()[0]) mobject.next_to(dot) dot = Dot(RIGHT*3) label = DecimalNumber() - label.add_updater(dot_position) + label.add_updater(update_label) self.add(dot, label) - self.play(Rotating(dot, about_point=ORIGIN, angle=TAU, run_time=TAU, rate_func=linear)) + self.play(Rotating(dot, angle=TAU, about_point=ORIGIN, run_time=TAU, rate_func=linear)) .. manim:: DtUpdater @@ -1029,6 +1028,9 @@ def construct(self): :meth:`get_updaters` :meth:`remove_updater` :class:`~.UpdateFromFunc` + :class:`~.Rotating` + :meth:`rotate` + :attr:`~.Mobject.animate` """ if index is None: self.updaters.append(update_function) @@ -1200,7 +1202,7 @@ def apply_to_family(self, func: Callable[[Mobject], None]) -> None: for mob in self.family_members_with_points(): func(mob) - def shift(self, *vectors: Vector3D) -> Self: + def shift(self, *vectors: Vector3DLike) -> Self: """Shift by the given vectors. Parameters @@ -1271,25 +1273,82 @@ def construct(self): ) return self - def rotate_about_origin(self, angle: float, axis: Vector3D = OUT, axes=[]) -> Self: + def rotate_about_origin(self, angle: float, axis: Vector3DLike = OUT) -> Self: """Rotates the :class:`~.Mobject` about the ORIGIN, which is at [0,0,0].""" return self.rotate(angle, axis, about_point=ORIGIN) def rotate( self, angle: float, - axis: Vector3D = OUT, + axis: Vector3DLike = OUT, about_point: Point3DLike | None = None, **kwargs, ) -> Self: - """Rotates the :class:`~.Mobject` about a certain point.""" + """Rotates the :class:`~.Mobject` around a specified axis and point. + + Parameters + ---------- + angle + The angle of rotation in radians. Predefined constants such as ``DEGREES`` + can also be used to specify the angle in degrees. + axis + The rotation axis (see :class:`~.Rotating` for more). + about_point + The point about which the mobject rotates. If ``None``, rotation occurs around + the center of the mobject. + **kwargs + Additional keyword arguments passed to :meth:`apply_points_function_about_point`, + such as ``about_edge``. + + Returns + ------- + :class:`Mobject` + ``self`` (for method chaining) + + + .. note:: + To animate a rotation, use :class:`~.Rotating` or :class:`~.Rotate` + instead of ``.animate.rotate(...)``. + The ``.animate.rotate(...)`` syntax only applies a transformation + from the initial state to the final rotated state + (interpolation between the two states), without showing proper rotational motion + based on the angle (from 0 to the given angle). + + Examples + -------- + + .. manim:: RotateMethodExample + :save_last_frame: + + class RotateMethodExample(Scene): + def construct(self): + circle = Circle(radius=1, color=BLUE) + line = Line(start=ORIGIN, end=RIGHT) + arrow1 = Arrow(start=ORIGIN, end=RIGHT, buff=0, color=GOLD) + group1 = VGroup(circle, line, arrow1) + + group2 = group1.copy() + arrow2 = group2[2] + arrow2.rotate(angle=PI / 4, about_point=arrow2.get_start()) + + group3 = group1.copy() + arrow3 = group3[2] + arrow3.rotate(angle=120 * DEGREES, about_point=arrow3.get_start()) + + self.add(VGroup(group1, group2, group3).arrange(RIGHT, buff=1)) + + See also + -------- + :class:`~.Rotating`, :class:`~.Rotate`, :attr:`~.Mobject.animate`, :meth:`apply_points_function_about_point` + + """ rot_matrix = rotation_matrix(angle, axis) self.apply_points_function_about_point( lambda points: np.dot(points, rot_matrix.T), about_point, **kwargs ) return self - def flip(self, axis: Vector3D = UP, **kwargs) -> Self: + def flip(self, axis: Vector3DLike = UP, **kwargs) -> Self: """Flips/Mirrors an mobject about its center. Examples @@ -1409,7 +1468,7 @@ def apply_points_function_about_point( self, func: MultiMappingFunction, about_point: Point3DLike | None = None, - about_edge: Vector3D | None = None, + about_edge: Vector3DLike | None = None, ) -> Self: if about_point is None: if about_edge is None: @@ -1439,7 +1498,7 @@ def center(self) -> Self: return self def align_on_border( - self, direction: Vector3D, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + self, direction: Vector3DLike, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER ) -> Self: """Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -1456,7 +1515,7 @@ def align_on_border( return self def to_corner( - self, corner: Vector3D = DL, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + self, corner: Vector3DLike = DL, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER ) -> Self: """Moves this :class:`~.Mobject` to the given corner of the screen. @@ -1484,7 +1543,7 @@ def construct(self): return self.align_on_border(corner, buff) def to_edge( - self, edge: Vector3D = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + self, edge: Vector3DLike = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER ) -> Self: """Moves this :class:`~.Mobject` to the given edge of the screen, without affecting its position in the other dimension. @@ -1516,12 +1575,12 @@ def construct(self): def next_to( self, mobject_or_point: Mobject | Point3DLike, - direction: Vector3D = RIGHT, + direction: Vector3DLike = RIGHT, buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge: Vector3D = ORIGIN, + aligned_edge: Vector3DLike = ORIGIN, submobject_to_align: Mobject | None = None, index_of_submobject_to_align: int | None = None, - coor_mask: Vector3D = np.array([1, 1, 1]), + coor_mask: Vector3DLike = np.array([1, 1, 1]), ) -> Self: """Move this :class:`~.Mobject` next to another's :class:`~.Mobject` or Point3D. @@ -1543,13 +1602,18 @@ def construct(self): self.add(d, c, s, t) """ + np_direction = np.asarray(direction) + np_aligned_edge = np.asarray(aligned_edge) + if isinstance(mobject_or_point, Mobject): mob = mobject_or_point if index_of_submobject_to_align is not None: target_aligner = mob[index_of_submobject_to_align] else: target_aligner = mob - target_point = target_aligner.get_critical_point(aligned_edge + direction) + target_point = target_aligner.get_critical_point( + np_aligned_edge + np_direction + ) else: target_point = mobject_or_point if submobject_to_align is not None: @@ -1558,8 +1622,8 @@ def construct(self): aligner = self[index_of_submobject_to_align] else: aligner = self - point_to_align = aligner.get_critical_point(aligned_edge - direction) - self.shift((target_point - point_to_align + buff * direction) * coor_mask) + point_to_align = aligner.get_critical_point(np_aligned_edge - np_direction) + self.shift((target_point - point_to_align + buff * np_direction) * coor_mask) return self def shift_onto_screen(self, **kwargs) -> Self: @@ -1705,22 +1769,22 @@ def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self: """Stretches the :class:`~.Mobject` to fit a depth, not keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_coord(self, value, dim: int, direction: Vector3D = ORIGIN) -> Self: + def set_coord(self, value, dim: int, direction: Vector3DLike = ORIGIN) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x: float, direction: Vector3D = ORIGIN) -> Self: + def set_x(self, x: float, direction: Vector3DLike = ORIGIN) -> Self: """Set x value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(x, 0, direction) - def set_y(self, y: float, direction: Vector3D = ORIGIN) -> Self: + def set_y(self, y: float, direction: Vector3DLike = ORIGIN) -> Self: """Set y value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(y, 1, direction) - def set_z(self, z: float, direction: Vector3D = ORIGIN) -> Self: + def set_z(self, z: float, direction: Vector3DLike = ORIGIN) -> Self: """Set z value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) @@ -1733,8 +1797,8 @@ def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: def move_to( self, point_or_mobject: Point3DLike | Mobject, - aligned_edge: Vector3D = ORIGIN, - coor_mask: Vector3D = np.array([1, 1, 1]), + aligned_edge: Vector3DLike = ORIGIN, + coor_mask: Vector3DLike = np.array([1, 1, 1]), ) -> Self: """Move center of the :class:`~.Mobject` to certain Point3D.""" if isinstance(point_or_mobject, Mobject): @@ -2053,7 +2117,7 @@ def get_extremum_along_dim( else: return np.max(values) - def get_critical_point(self, direction: Vector3D) -> Point3D: + def get_critical_point(self, direction: Vector3DLike) -> Point3D: """Picture a box bounding the :class:`~.Mobject`. Such a box has 9 'critical points': 4 corners, 4 edge center, the center. This returns one of them, along the given direction. @@ -2082,11 +2146,11 @@ def get_critical_point(self, direction: Vector3D) -> Point3D: # Pseudonyms for more general get_critical_point method - def get_edge_center(self, direction: Vector3D) -> Point3D: + def get_edge_center(self, direction: Vector3DLike) -> Point3D: """Get edge Point3Ds for certain direction.""" return self.get_critical_point(direction) - def get_corner(self, direction: Vector3D) -> Point3D: + def get_corner(self, direction: Vector3DLike) -> Point3D: """Get corner Point3Ds for certain direction.""" return self.get_critical_point(direction) @@ -2097,9 +2161,9 @@ def get_center(self) -> Point3D: def get_center_of_mass(self) -> Point3D: return np.apply_along_axis(np.mean, 0, self.get_all_points()) - def get_boundary_point(self, direction: Vector3D) -> Point3D: + def get_boundary_point(self, direction: Vector3DLike) -> Point3D: all_points = self.get_points_defining_boundary() - index = np.argmax(np.dot(all_points, np.array(direction).T)) + index = np.argmax(np.dot(all_points, direction)) return all_points[index] def get_midpoint(self) -> Point3D: @@ -2156,19 +2220,19 @@ def length_over_dim(self, dim: int) -> float: dim, ) - self.reduce_across_dimension(min, dim) - def get_coord(self, dim: int, direction: Vector3D = ORIGIN): + def get_coord(self, dim: int, direction: Vector3DLike = ORIGIN) -> float: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" return self.get_extremum_along_dim(dim=dim, key=direction[dim]) - def get_x(self, direction: Vector3D = ORIGIN) -> float: + def get_x(self, direction: Vector3DLike = ORIGIN) -> float: """Returns x Point3D of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(0, direction) - def get_y(self, direction: Vector3D = ORIGIN) -> float: + def get_y(self, direction: Vector3DLike = ORIGIN) -> float: """Returns y Point3D of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(1, direction) - def get_z(self, direction: Vector3D = ORIGIN) -> float: + def get_z(self, direction: Vector3DLike = ORIGIN) -> float: """Returns z Point3D of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(2, direction) @@ -2239,7 +2303,7 @@ def match_depth(self, mobject: Mobject, **kwargs) -> Self: return self.match_dim_size(mobject, 2, **kwargs) def match_coord( - self, mobject: Mobject, dim: int, direction: Vector3D = ORIGIN + self, mobject: Mobject, dim: int, direction: Vector3DLike = ORIGIN ) -> Self: """Match the Point3Ds with the Point3Ds of another :class:`~.Mobject`.""" return self.set_coord( @@ -2263,7 +2327,7 @@ def match_z(self, mobject: Mobject, direction=ORIGIN) -> Self: def align_to( self, mobject_or_point: Mobject | Point3DLike, - direction: Vector3D = ORIGIN, + direction: Vector3DLike = ORIGIN, ) -> Self: """Aligns mobject to another :class:`~.Mobject` in a certain direction. @@ -2370,7 +2434,7 @@ def family_members_with_points(self) -> list[Self]: def arrange( self, - direction: Vector3D = RIGHT, + direction: Vector3DLike = RIGHT, buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, center: bool = True, **kwargs, @@ -2403,7 +2467,7 @@ def arrange_in_grid( rows: int | None = None, cols: int | None = None, buff: float | tuple[float, float] = MED_SMALL_BUFF, - cell_alignment: Vector3D = ORIGIN, + cell_alignment: Vector3DLike = ORIGIN, row_alignments: str | None = None, # "ucd" col_alignments: str | None = None, # "lcr" row_heights: Iterable[float | None] | None = None, @@ -3042,21 +3106,22 @@ def construct(self): -------- :meth:`~.Mobject.align_data`, :meth:`~.VMobject.interpolate_color` """ - mobject = mobject.copy() - if stretch: - mobject.stretch_to_fit_height(self.height) - mobject.stretch_to_fit_width(self.width) - mobject.stretch_to_fit_depth(self.depth) - else: - if match_height: - mobject.match_height(self) - if match_width: - mobject.match_width(self) - if match_depth: - mobject.match_depth(self) + if stretch or match_height or match_width or match_depth or match_center: + mobject = mobject.copy() + if stretch: + mobject.stretch_to_fit_height(self.height) + mobject.stretch_to_fit_width(self.width) + mobject.stretch_to_fit_depth(self.depth) + else: + if match_height: + mobject.match_height(self) + if match_width: + mobject.match_width(self) + if match_depth: + mobject.match_depth(self) - if match_center: - mobject.move_to(self.get_center()) + if match_center: + mobject.move_to(self.get_center()) self.align_data(mobject, skip_point_alignment=True) for sm1, sm2 in zip(self.get_family(), mobject.get_family()): @@ -3172,7 +3237,7 @@ def __init__(self, mobject) -> None: self.overridden_animation = None self.is_chaining = False - self.methods = [] + self.methods: list[MethodWithArgs] = [] # Whether animation args can be passed self.cannot_pass_args = False @@ -3207,7 +3272,7 @@ def update_target(*method_args, **method_kwargs): **method_kwargs, ) else: - self.methods.append([method, method_args, method_kwargs]) + self.methods.append(MethodWithArgs(method, method_args, method_kwargs)) method(*method_args, **method_kwargs) return self @@ -3221,10 +3286,7 @@ def build(self) -> Animation: _MethodAnimation, ) - if self.overridden_animation: - anim = self.overridden_animation - else: - anim = _MethodAnimation(self.mobject, self.methods) + anim = self.overridden_animation or _MethodAnimation(self.mobject, self.methods) for attr, value in self.anim_args.items(): setattr(anim, attr, value) @@ -3241,7 +3303,7 @@ def override_animate(method) -> types.FunctionType: .. seealso:: - :attr:`Mobject.animate` + :attr:`~.Mobject.animate` .. note:: diff --git a/manim/mobject/opengl/dot_cloud.py b/manim/mobject/opengl/dot_cloud.py index 4cb0ed8bc7..7161c46376 100644 --- a/manim/mobject/opengl/dot_cloud.py +++ b/manim/mobject/opengl/dot_cloud.py @@ -2,16 +2,25 @@ __all__ = ["TrueDot", "DotCloud"] +from typing import Any + import numpy as np +from typing_extensions import Self from manim.constants import ORIGIN, RIGHT, UP from manim.mobject.opengl.opengl_point_cloud_mobject import OpenGLPMobject -from manim.utils.color import YELLOW +from manim.typing import Point3DLike +from manim.utils.color import YELLOW, ParsableManimColor class DotCloud(OpenGLPMobject): def __init__( - self, color=YELLOW, stroke_width=2.0, radius=2.0, density=10, **kwargs + self, + color: ParsableManimColor = YELLOW, + stroke_width: float = 2.0, + radius: float = 2.0, + density: float = 10, + **kwargs: Any, ): self.radius = radius self.epsilon = 1.0 / density @@ -19,7 +28,7 @@ def __init__( stroke_width=stroke_width, density=density, color=color, **kwargs ) - def init_points(self): + def init_points(self) -> None: self.points = np.array( [ r * (np.cos(theta) * RIGHT + np.sin(theta) * UP) @@ -34,7 +43,7 @@ def init_points(self): dtype=np.float32, ) - def make_3d(self, gloss=0.5, shadow=0.2): + def make_3d(self, gloss: float = 0.5, shadow: float = 0.2) -> Self: self.set_gloss(gloss) self.set_shadow(shadow) self.apply_depth_test() @@ -42,6 +51,8 @@ def make_3d(self, gloss=0.5, shadow=0.2): class TrueDot(DotCloud): - def __init__(self, center=ORIGIN, stroke_width=2.0, **kwargs): + def __init__( + self, center: Point3DLike = ORIGIN, stroke_width: float = 2.0, **kwargs: Any + ): self.radius = stroke_width super().__init__(points=[center], stroke_width=stroke_width, **kwargs) diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 6428995cd5..c8312eddc9 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -6,16 +6,17 @@ import random import sys import types -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from functools import partialmethod, wraps from math import ceil -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar import moderngl import numpy as np from manim import config, logger from manim.constants import * +from manim.data_structures import MethodWithArgs from manim.renderer.shader_wrapper import get_colormap_code from manim.utils.bezier import integer_interpolate, interpolate from manim.utils.color import ( @@ -62,6 +63,7 @@ Point3DLike, Point3DLike_Array, Vector3D, + Vector3DLike, ) TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] @@ -635,7 +637,7 @@ def apply_points_function( self, func: MultiMappingFunction, about_point: Point3DLike | None = None, - about_edge: Vector3D | None = ORIGIN, + about_edge: Vector3DLike | None = ORIGIN, works_on_bounding_box: bool = False, ) -> Self: if about_point is None and about_edge is not None: @@ -991,7 +993,7 @@ def replace_submobject(self, index: int, new_submob: OpenGLMobject) -> Self: # Submobject organization def arrange( - self, direction: Vector3D = RIGHT, center: bool = True, **kwargs + self, direction: Vector3DLike = RIGHT, center: bool = True, **kwargs ) -> Self: """Sorts :class:`~.OpenGLMobject` next to each other on screen. @@ -1021,7 +1023,7 @@ def arrange_in_grid( rows: int | None = None, cols: int | None = None, buff: float | tuple[float, float] = MED_SMALL_BUFF, - cell_alignment: Vector3D = ORIGIN, + cell_alignment: Vector3DLike = ORIGIN, row_alignments: str | None = None, # "ucd" col_alignments: str | None = None, # "lcr" row_heights: Sequence[float | None] | None = None, @@ -1552,7 +1554,7 @@ def refresh_has_updater_status(self) -> Self: # Transforming operations - def shift(self, vector: Vector3D) -> Self: + def shift(self, vector: Vector3DLike) -> Self: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -1630,14 +1632,14 @@ def func(points: Point3D_Array) -> Point3D_Array: self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle: float, axis: Vector3D = OUT) -> Self: + def rotate_about_origin(self, angle: float, axis: Vector3DLike = OUT) -> Self: return self.rotate(angle, axis, about_point=ORIGIN) def rotate( self, angle: float, - axis: Vector3D = OUT, - about_point: Sequence[float] | None = None, + axis: Vector3DLike = OUT, + about_point: Point3DLike | None = None, **kwargs, ) -> Self: """Rotates the :class:`~.OpenGLMobject` about a certain point.""" @@ -1649,7 +1651,7 @@ def rotate( ) return self - def flip(self, axis: Vector3D = UP, **kwargs) -> Self: + def flip(self, axis: Vector3DLike = UP, **kwargs) -> Self: """Flips/Mirrors an mobject about its center. Examples @@ -1750,8 +1752,8 @@ def hierarchical_model_matrix(self) -> MatrixMN: def wag( self, - direction: Vector3D = RIGHT, - axis: Vector3D = DOWN, + direction: Vector3DLike = RIGHT, + axis: Vector3DLike = DOWN, wag_factor: float = 1.0, ) -> Self: for mob in self.family_members_with_points(): @@ -1777,7 +1779,7 @@ def center(self) -> Self: def align_on_border( self, - direction: Vector3D, + direction: Vector3DLike, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, ) -> Self: """ @@ -1790,21 +1792,21 @@ def align_on_border( 0, ) point_to_align = self.get_bounding_box_point(direction) - shift_val = target_point - point_to_align - buff * np.array(direction) + shift_val = target_point - point_to_align - buff * np.asarray(direction) shift_val = shift_val * abs(np.sign(direction)) self.shift(shift_val) return self def to_corner( self, - corner: Vector3D = LEFT + DOWN, + corner: Vector3DLike = LEFT + DOWN, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, ) -> Self: return self.align_on_border(corner, buff) def to_edge( self, - edge: Vector3D = LEFT, + edge: Vector3DLike = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, ) -> Self: return self.align_on_border(edge, buff) @@ -1812,12 +1814,12 @@ def to_edge( def next_to( self, mobject_or_point: OpenGLMobject | Point3DLike, - direction: Vector3D = RIGHT, + direction: Vector3DLike = RIGHT, buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge: Vector3D = ORIGIN, + aligned_edge: Vector3DLike = ORIGIN, submobject_to_align: OpenGLMobject | None = None, index_of_submobject_to_align: int | None = None, - coor_mask: Point3DLike = np.array([1, 1, 1]), + coor_mask: Vector3DLike = np.array([1, 1, 1]), ) -> Self: """Move this :class:`~.OpenGLMobject` next to another's :class:`~.OpenGLMobject` or coordinate. @@ -1839,6 +1841,9 @@ def construct(self): self.add(d, c, s, t) """ + np_direction = np.asarray(direction) + np_aligned_edge = np.asarray(aligned_edge) + if isinstance(mobject_or_point, OpenGLMobject): mob = mobject_or_point if index_of_submobject_to_align is not None: @@ -1846,7 +1851,7 @@ def construct(self): else: target_aligner = mob target_point = target_aligner.get_bounding_box_point( - aligned_edge + direction, + np_aligned_edge + np_direction, ) else: target_point = mobject_or_point @@ -1856,8 +1861,8 @@ def construct(self): aligner = self[index_of_submobject_to_align] else: aligner = self - point_to_align = aligner.get_bounding_box_point(aligned_edge - direction) - self.shift((target_point - point_to_align + buff * direction) * coor_mask) + point_to_align = aligner.get_bounding_box_point(np_aligned_edge - np_direction) + self.shift((target_point - point_to_align + buff * np_direction) * coor_mask) return self def shift_onto_screen(self, **kwargs) -> Self: @@ -1969,22 +1974,24 @@ def set_depth(self, depth: float, stretch: bool = False, **kwargs): scale_to_fit_depth = set_depth - def set_coord(self, value: float, dim: int, direction: Vector3D = ORIGIN) -> Self: + def set_coord( + self, value: float, dim: int, direction: Vector3DLike = ORIGIN + ) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x: float, direction: Vector3D = ORIGIN) -> Self: + def set_x(self, x: float, direction: Vector3DLike = ORIGIN) -> Self: """Set x value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(x, 0, direction) - def set_y(self, y: float, direction: Vector3D = ORIGIN) -> Self: + def set_y(self, y: float, direction: Vector3DLike = ORIGIN) -> Self: """Set y value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(y, 1, direction) - def set_z(self, z: float, direction: Vector3D = ORIGIN) -> Self: + def set_z(self, z: float, direction: Vector3DLike = ORIGIN) -> Self: """Set z value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) @@ -1997,8 +2004,8 @@ def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: def move_to( self, point_or_mobject: Point3DLike | OpenGLMobject, - aligned_edge: Vector3D = ORIGIN, - coor_mask: Point3DLike = np.array([1, 1, 1]), + aligned_edge: Vector3DLike = ORIGIN, + coor_mask: Vector3DLike = np.array([1, 1, 1]), ) -> Self: """Move center of the :class:`~.OpenGLMobject` to certain coordinate.""" if isinstance(point_or_mobject, OpenGLMobject): @@ -2251,16 +2258,16 @@ def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Se # Getters - def get_bounding_box_point(self, direction: Vector3D) -> Point3D: + def get_bounding_box_point(self, direction: Vector3DLike) -> Point3D: bb = self.get_bounding_box() indices = (np.sign(direction) + 1).astype(int) return np.array([bb[indices[i]][i] for i in range(3)]) - def get_edge_center(self, direction: Vector3D) -> Point3D: + def get_edge_center(self, direction: Vector3DLike) -> Point3D: """Get edge coordinates for certain direction.""" return self.get_bounding_box_point(direction) - def get_corner(self, direction: Vector3D) -> Point3D: + def get_corner(self, direction: Vector3DLike) -> Point3D: """Get corner coordinates for certain direction.""" return self.get_bounding_box_point(direction) @@ -2271,23 +2278,24 @@ def get_center(self) -> Point3D: def get_center_of_mass(self) -> Point3D: return self.get_all_points().mean(0) - def get_boundary_point(self, direction: Vector3D) -> Point3D: + def get_boundary_point(self, direction: Vector3DLike) -> Point3D: all_points = self.get_all_points() boundary_directions = all_points - self.get_center() norms = np.linalg.norm(boundary_directions, axis=1) boundary_directions /= np.repeat(norms, 3).reshape((len(norms), 3)) - index = np.argmax(np.dot(boundary_directions, np.array(direction).T)) + index = np.argmax(np.dot(boundary_directions, direction)) return all_points[index] - def get_continuous_bounding_box_point(self, direction: Vector3D) -> Point3D: + def get_continuous_bounding_box_point(self, direction: Vector3DLike) -> Point3D: dl, center, ur = self.get_bounding_box() corner_vect = ur - center - return center + direction / np.max( + np_direction = np.asarray(direction) + return center + np_direction / np.max( np.abs( np.true_divide( - direction, + np_direction, corner_vect, - out=np.zeros(len(direction)), + out=np.zeros(len(np_direction)), where=((corner_vect) != 0), ), ), @@ -2333,19 +2341,19 @@ def get_depth(self) -> float: """Returns the depth of the mobject.""" return self.length_over_dim(2) - def get_coord(self, dim: int, direction: Vector3D = ORIGIN) -> ManimFloat: + def get_coord(self, dim: int, direction: Vector3DLike = ORIGIN) -> ManimFloat: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" return self.get_bounding_box_point(direction)[dim] - def get_x(self, direction: Vector3D = ORIGIN) -> ManimFloat: + def get_x(self, direction: Vector3DLike = ORIGIN) -> ManimFloat: """Returns x coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(0, direction) - def get_y(self, direction: Vector3D = ORIGIN) -> ManimFloat: + def get_y(self, direction: Vector3DLike = ORIGIN) -> ManimFloat: """Returns y coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(1, direction) - def get_z(self, direction: Vector3D = ORIGIN) -> ManimFloat: + def get_z(self, direction: Vector3DLike = ORIGIN) -> ManimFloat: """Returns z coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(2, direction) @@ -2411,7 +2419,7 @@ def match_depth(self, mobject: OpenGLMobject, **kwargs) -> Self: return self.match_dim_size(mobject, 2, **kwargs) def match_coord( - self, mobject: OpenGLMobject, dim: int, direction: Vector3D = ORIGIN + self, mobject: OpenGLMobject, dim: int, direction: Vector3DLike = ORIGIN ) -> Self: """Match the coordinates with the coordinates of another :class:`~.OpenGLMobject`.""" return self.set_coord( @@ -2420,22 +2428,22 @@ def match_coord( direction=direction, ) - def match_x(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: + def match_x(self, mobject: OpenGLMobject, direction: Vector3DLike = ORIGIN) -> Self: """Match x coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 0, direction) - def match_y(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: + def match_y(self, mobject: OpenGLMobject, direction: Vector3DLike = ORIGIN) -> Self: """Match y coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 1, direction) - def match_z(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: + def match_z(self, mobject: OpenGLMobject, direction: Vector3DLike = ORIGIN) -> Self: """Match z coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 2, direction) def align_to( self, mobject_or_point: OpenGLMobject | Point3DLike, - direction: Vector3D = ORIGIN, + direction: Vector3DLike = ORIGIN, ) -> Self: """ Examples: @@ -2938,7 +2946,7 @@ def __init__(self, mobject: OpenGLMobject): self.overridden_animation = None self.is_chaining = False - self.methods = [] + self.methods: list[MethodWithArgs] = [] # Whether animation args can be passed self.cannot_pass_args = False @@ -2973,7 +2981,7 @@ def update_target(*method_args, **method_kwargs): **method_kwargs, ) else: - self.methods.append([method, method_args, method_kwargs]) + self.methods.append(MethodWithArgs(method, method_args, method_kwargs)) method(*method_args, **method_kwargs) return self @@ -2985,10 +2993,7 @@ def update_target(*method_args, **method_kwargs): def build(self) -> _MethodAnimation: from manim.animation.transform import _MethodAnimation - if self.overridden_animation: - anim = self.overridden_animation - else: - anim = _MethodAnimation(self.mobject, self.methods) + anim = self.overridden_animation or _MethodAnimation(self.mobject, self.methods) for attr, value in self.anim_args.items(): setattr(anim, attr, value) diff --git a/manim/mobject/opengl/opengl_point_cloud_mobject.py b/manim/mobject/opengl/opengl_point_cloud_mobject.py index 1725eccee8..72e196fb9b 100644 --- a/manim/mobject/opengl/opengl_point_cloud_mobject.py +++ b/manim/mobject/opengl/opengl_point_cloud_mobject.py @@ -4,11 +4,19 @@ import moderngl import numpy as np +from typing_extensions import Self from manim.constants import * from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.utils.bezier import interpolate -from manim.utils.color import BLACK, WHITE, YELLOW, color_gradient, color_to_rgba +from manim.utils.color import ( + BLACK, + WHITE, + YELLOW, + ParsableManimColor, + color_gradient, + color_to_rgba, +) from manim.utils.config_ops import _Uniforms from manim.utils.iterables import resize_with_interpolation @@ -27,7 +35,11 @@ class OpenGLPMobject(OpenGLMobject): point_radius = _Uniforms() def __init__( - self, stroke_width=2.0, color=YELLOW, render_primitive=moderngl.POINTS, **kwargs + self, + stroke_width: float = 2.0, + color: ParsableManimColor = YELLOW, + render_primitive: int = moderngl.POINTS, + **kwargs, ): self.stroke_width = stroke_width super().__init__(color=color, render_primitive=render_primitive, **kwargs) @@ -35,7 +47,7 @@ def __init__( self.stroke_width * OpenGLPMobject.OPENGL_POINT_RADIUS_SCALE_FACTOR ) - def reset_points(self): + def reset_points(self) -> Self: self.rgbas = np.zeros((1, 4)) self.points = np.zeros((0, 3)) return self diff --git a/manim/mobject/opengl/opengl_surface.py b/manim/mobject/opengl/opengl_surface.py index 565b8c71cf..e6fdeee456 100644 --- a/manim/mobject/opengl/opengl_surface.py +++ b/manim/mobject/opengl/opengl_surface.py @@ -9,6 +9,7 @@ from manim.constants import * from manim.mobject.opengl.opengl_mobject import OpenGLMobject +from manim.typing import Point3D_Array, Vector3D_Array from manim.utils.bezier import integer_interpolate, interpolate from manim.utils.color import * from manim.utils.config_ops import _Data, _Uniforms @@ -160,12 +161,14 @@ def compute_triangle_indices(self): def get_triangle_indices(self): return self.triangle_indices - def get_surface_points_and_nudged_points(self): + def get_surface_points_and_nudged_points( + self, + ) -> tuple[Point3D_Array, Point3D_Array, Point3D_Array]: points = self.points k = len(points) // 3 return points[:k], points[k : 2 * k], points[2 * k :] - def get_unit_normals(self): + def get_unit_normals(self) -> Vector3D_Array: s_points, du_points, dv_points = self.get_surface_points_and_nudged_points() normals = np.cross( (du_points - s_points) / self.epsilon, diff --git a/manim/mobject/opengl/opengl_three_dimensions.py b/manim/mobject/opengl/opengl_three_dimensions.py index 930ff9ef20..56ed68a416 100644 --- a/manim/mobject/opengl/opengl_three_dimensions.py +++ b/manim/mobject/opengl/opengl_three_dimensions.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import numpy as np from manim.mobject.opengl.opengl_surface import OpenGLSurface @@ -11,13 +13,13 @@ class OpenGLSurfaceMesh(OpenGLVGroup): def __init__( self, - uv_surface, - resolution=None, - stroke_width=1, - normal_nudge=1e-2, - depth_test=True, - flat_stroke=False, - **kwargs, + uv_surface: OpenGLSurface, + resolution: tuple[int, int] | None = None, + stroke_width: float = 1, + normal_nudge: float = 1e-2, + depth_test: bool = True, + flat_stroke: bool = False, + **kwargs: Any, ): if not isinstance(uv_surface, OpenGLSurface): raise Exception("uv_surface must be of type OpenGLSurface") @@ -31,7 +33,7 @@ def __init__( **kwargs, ) - def init_points(self): + def init_points(self) -> None: uv_surface = self.uv_surface full_nu, full_nv = uv_surface.resolution diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index b31934e999..d8ced06d65 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -2,12 +2,13 @@ import itertools as it import operator as op -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import reduce, wraps -from typing import Callable +from typing import Any import moderngl import numpy as np +from typing_extensions import Self from manim import config from manim.constants import * @@ -171,6 +172,15 @@ def get_group_class(self): def get_mobject_type_class(): return OpenGLVMobject + @property + def submobjects(self) -> Sequence[OpenGLVMobject]: + return self._submobjects if hasattr(self, "_submobjects") else [] + + @submobjects.setter + def submobjects(self, submobject_list: Iterable[OpenGLVMobject]) -> None: + self.remove(*self.submobjects) + self.add(*submobject_list) + def init_data(self): super().init_data() self.data.pop("rgbas") @@ -594,7 +604,9 @@ def set_points_as_corners(self, points: Iterable[float]) -> OpenGLVMobject: ) return self - def set_points_smoothly(self, points, true_smooth=False): + def set_points_smoothly( + self, points: Point3DLike_Array, true_smooth: bool = False + ) -> Self: self.set_points_as_corners(points) self.make_smooth() return self @@ -1654,7 +1666,7 @@ def construct(self): self.add(circles_group) """ - def __init__(self, *vmobjects, **kwargs): + def __init__(self, *vmobjects: OpenGLVMobject, **kwargs: Any): super().__init__(**kwargs) self.add(*vmobjects) diff --git a/manim/mobject/svg/brace.py b/manim/mobject/svg/brace.py index 3d826f4f01..ea969e0712 100644 --- a/manim/mobject/svg/brace.py +++ b/manim/mobject/svg/brace.py @@ -4,19 +4,21 @@ __all__ = ["Brace", "BraceLabel", "ArcBrace", "BraceText", "BraceBetweenPoints"] -from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import svgelements as se +from typing_extensions import Self from manim._config import config from manim.mobject.geometry.arc import Arc from manim.mobject.geometry.line import Line from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL -from manim.mobject.text.tex_mobject import MathTex, Tex +from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex, Tex +from manim.mobject.text.text_mobject import Text +from ...animation.animation import Animation from ...animation.composition import AnimationGroup from ...animation.fading import FadeIn from ...animation.growing import GrowFromCenter @@ -26,11 +28,9 @@ from ..svg.svg_mobject import VMobjectFromSVGPath if TYPE_CHECKING: - from manim.typing import Point3DLike, Vector3D + from manim.typing import Point3D, Point3DLike, Vector3D, Vector3DLike from manim.utils.color.core import ParsableManimColor -__all__ = ["Brace", "BraceBetweenPoints", "BraceLabel", "ArcBrace"] - class Brace(VMobjectFromSVGPath): """Takes a mobject and draws a brace adjacent to it. @@ -70,14 +70,14 @@ def construct(self): def __init__( self, mobject: Mobject, - direction: Vector3D | None = DOWN, + direction: Vector3DLike = DOWN, buff: float = 0.2, sharpness: float = 2, stroke_width: float = 0, fill_opacity: float = 1.0, background_stroke_width: float = 0, background_stroke_color: ParsableManimColor = BLACK, - **kwargs, + **kwargs: Any, ): path_string_template = ( "m0.01216 0c-0.01152 0-0.01216 6.103e-4 -0.01216 0.01311v0.007762c0.06776 " @@ -130,7 +130,7 @@ def __init__( for mob in mobject, self: mob.rotate(angle, about_point=ORIGIN) - def put_at_tip(self, mob: Mobject, use_next_to: bool = True, **kwargs): + def put_at_tip(self, mob: Mobject, use_next_to: bool = True, **kwargs: Any) -> Self: """Puts the given mobject at the brace tip. Parameters @@ -153,7 +153,7 @@ def put_at_tip(self, mob: Mobject, use_next_to: bool = True, **kwargs): mob.shift(self.get_direction() * shift_distance) return self - def get_text(self, *text, **kwargs): + def get_text(self, *text: str, **kwargs: Any) -> Tex: """Places the text at the brace tip. Parameters @@ -172,7 +172,7 @@ def get_text(self, *text, **kwargs): self.put_at_tip(text_mob, **kwargs) return text_mob - def get_tex(self, *tex, **kwargs): + def get_tex(self, *tex: str, **kwargs: Any) -> MathTex: """Places the tex at the brace tip. Parameters @@ -191,7 +191,7 @@ def get_tex(self, *tex, **kwargs): self.put_at_tip(tex_mob, **kwargs) return tex_mob - def get_tip(self): + def get_tip(self) -> Point3D: """Returns the point at the brace tip.""" # Returns the position of the seventh point in the path, which is the tip. if config["renderer"] == "opengl": @@ -199,7 +199,7 @@ def get_tip(self): return self.points[28] # = 7*4 - def get_direction(self): + def get_direction(self) -> Vector3D: """Returns the direction from the center to the brace tip.""" vect = self.get_tip() - self.get_center() return vect / np.linalg.norm(vect) @@ -233,12 +233,12 @@ def __init__( self, obj: Mobject, text: str, - brace_direction: np.ndarray = DOWN, - label_constructor: type = MathTex, + brace_direction: Vector3DLike = DOWN, + label_constructor: type[SingleStringMathTex | Text] = MathTex, font_size: float = DEFAULT_FONT_SIZE, buff: float = 0.2, - brace_config: dict | None = None, - **kwargs, + brace_config: dict[str, Any] | None = None, + **kwargs: Any, ): self.label_constructor = label_constructor super().__init__(**kwargs) @@ -249,37 +249,94 @@ def __init__( self.brace = Brace(obj, brace_direction, buff, **brace_config) if isinstance(text, (tuple, list)): - self.label = self.label_constructor(*text, font_size=font_size, **kwargs) + self.label: VMobject = self.label_constructor( + *text, font_size=font_size, **kwargs + ) else: self.label = self.label_constructor(str(text), font_size=font_size) self.brace.put_at_tip(self.label) self.add(self.brace, self.label) - def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter): + def creation_anim( + self, + label_anim: type[Animation] = FadeIn, + brace_anim: type[Animation] = GrowFromCenter, + ) -> AnimationGroup: return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) - def shift_brace(self, obj, **kwargs): + def shift_brace(self, obj: Mobject, **kwargs: Any) -> Self: if isinstance(obj, list): obj = self.get_group_class()(*obj) self.brace = Brace(obj, self.brace_direction, **kwargs) self.brace.put_at_tip(self.label) return self - def change_label(self, *text, **kwargs): - self.label = self.label_constructor(*text, **kwargs) - + def change_label(self, *text: str, **kwargs: Any) -> Self: + self.remove(self.label) + self.label = self.label_constructor(*text, **kwargs) # type: ignore[arg-type] self.brace.put_at_tip(self.label) + self.add(self.label) return self - def change_brace_label(self, obj, *text, **kwargs): + def change_brace_label(self, obj: Mobject, *text: str, **kwargs: Any) -> Self: self.shift_brace(obj) self.change_label(*text, **kwargs) return self class BraceText(BraceLabel): - def __init__(self, obj, text, label_constructor=Tex, **kwargs): + """Create a brace with a text label attached. + + Parameters + ---------- + obj + The mobject adjacent to which the brace is placed. + text + The label text. + brace_direction + The direction of the brace. By default ``DOWN``. + label_constructor + A class or function used to construct a mobject representing + the label. By default :class:`~.Text`. + font_size + The font size of the label, passed to the ``label_constructor``. + buff + The buffer between the mobject and the brace. + brace_config + Arguments to be passed to :class:`.Brace`. + kwargs + Additional arguments to be passed to :class:`~.VMobject`. + + + Examples + -------- + .. manim:: BraceTextExample + :save_last_frame: + + class BraceTextExample(Scene): + def construct(self): + s1 = Square().move_to(2*LEFT) + self.add(s1) + br1 = BraceText(s1, "Label") + self.add(br1) + + s2 = Square().move_to(2*RIGHT) + self.add(s2) + br2 = BraceText(s2, "Label") + + br2.change_label("new") + self.add(br2) + self.wait(0.1) + """ + + def __init__( + self, + obj: Mobject, + text: str, + label_constructor: type[SingleStringMathTex | Text] = Text, + **kwargs: Any, + ): super().__init__(obj, text, label_constructor=label_constructor, **kwargs) @@ -317,10 +374,10 @@ def construct(self): def __init__( self, - point_1: Point3DLike | None, - point_2: Point3DLike | None, - direction: Vector3D | None = ORIGIN, - **kwargs, + point_1: Point3DLike, + point_2: Point3DLike, + direction: Vector3DLike = ORIGIN, + **kwargs: Any, ): if all(direction == ORIGIN): line_vector = np.array(point_2) - np.array(point_1) @@ -386,8 +443,8 @@ def construct(self): def __init__( self, arc: Arc | None = None, - direction: Sequence[float] = RIGHT, - **kwargs, + direction: Vector3DLike = RIGHT, + **kwargs: Any, ): if arc is None: arc = Arc(start_angle=-1, angle=2, radius=1) diff --git a/manim/mobject/svg/svg_mobject.py b/manim/mobject/svg/svg_mobject.py index 82c121fce7..03264dd8c7 100644 --- a/manim/mobject/svg/svg_mobject.py +++ b/manim/mobject/svg/svg_mobject.py @@ -4,12 +4,14 @@ import os from pathlib import Path +from typing import Any from xml.etree import ElementTree as ET import numpy as np import svgelements as se from manim import config, logger +from manim.utils.color import ParsableManimColor from ...constants import RIGHT from ...utils.bezier import get_quadratic_approximation_of_cubic @@ -98,17 +100,17 @@ def __init__( should_center: bool = True, height: float | None = 2, width: float | None = None, - color: str | None = None, + color: ParsableManimColor | None = None, opacity: float | None = None, - fill_color: str | None = None, + fill_color: ParsableManimColor | None = None, fill_opacity: float | None = None, - stroke_color: str | None = None, + stroke_color: ParsableManimColor | None = None, stroke_opacity: float | None = None, stroke_width: float | None = None, svg_default: dict | None = None, path_string_config: dict | None = None, use_svg_cache: bool = True, - **kwargs, + **kwargs: Any, ): super().__init__(color=None, stroke_color=None, fill_color=None, **kwargs) @@ -121,10 +123,12 @@ def __init__( self.color = color self.opacity = opacity self.fill_color = fill_color - self.fill_opacity = fill_opacity + self.fill_opacity = fill_opacity # type: ignore[assignment] self.stroke_color = stroke_color - self.stroke_opacity = stroke_opacity - self.stroke_width = stroke_width + self.stroke_opacity = stroke_opacity # type: ignore[assignment] + self.stroke_width = stroke_width # type: ignore[assignment] + if self.stroke_width is None: + self.stroke_width = 0 if svg_default is None: svg_default = { @@ -191,7 +195,7 @@ def generate_mobject(self) -> None: """Parse the SVG and translate its elements to submobjects.""" file_path = self.get_file_path() element_tree = ET.parse(file_path) - new_tree = self.modify_xml_tree(element_tree) + new_tree = self.modify_xml_tree(element_tree) # type: ignore[arg-type] # Create a temporary svg file to dump modified svg to be parsed modified_file_path = file_path.with_name(f"{file_path.stem}_{file_path.suffix}") new_tree.write(modified_file_path) @@ -228,12 +232,12 @@ def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree: "style", ) root = element_tree.getroot() - root_style_dict = {k: v for k, v in root.attrib.items() if k in style_keys} + root_style_dict = {k: v for k, v in root.attrib.items() if k in style_keys} # type: ignore[union-attr] new_root = ET.Element("svg", {}) config_style_node = ET.SubElement(new_root, "g", config_style_dict) root_style_node = ET.SubElement(config_style_node, "g", root_style_dict) - root_style_node.extend(root) + root_style_node.extend(root) # type: ignore[arg-type] return ET.ElementTree(new_root) def generate_config_style_dict(self) -> dict[str, str]: @@ -262,13 +266,13 @@ def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]: svg The parsed SVG file. """ - result = [] + result: list[VMobject] = [] for shape in svg.elements(): # can we combine the two continue cases into one? if isinstance(shape, se.Group): # noqa: SIM114 continue elif isinstance(shape, se.Path): - mob = self.path_to_mobject(shape) + mob: VMobject = self.path_to_mobject(shape) elif isinstance(shape, se.SimpleLine): mob = self.line_to_mobject(shape) elif isinstance(shape, se.Rect): @@ -422,7 +426,7 @@ def polyline_to_mobject(self, polyline: se.Polyline) -> VMobject: return vmobject_class().set_points_as_corners(points) @staticmethod - def text_to_mobject(text: se.Text): + def text_to_mobject(text: se.Text) -> VMobject: """Convert a text element to a vectorized mobject. .. warning:: @@ -435,7 +439,7 @@ def text_to_mobject(text: se.Text): The parsed SVG text. """ logger.warning(f"Unsupported element type: {type(text)}") - return + return # type: ignore[return-value] def move_into_position(self) -> None: """Scale and move the generated mobject into position.""" @@ -480,7 +484,7 @@ def __init__( long_lines: bool = False, should_subdivide_sharp_curves: bool = False, should_remove_null_curves: bool = False, - **kwargs, + **kwargs: Any, ): # Get rid of arcs path_obj.approximate_arcs_with_quads() @@ -492,7 +496,7 @@ def __init__( super().__init__(**kwargs) - def init_points(self) -> None: + def generate_points(self) -> None: # TODO: cache mobject in a re-importable way self.handle_commands() @@ -505,15 +509,16 @@ def init_points(self) -> None: # Get rid of any null curves self.set_points(self.get_points_without_null_curves()) - generate_points = init_points + def init_points(self) -> None: + self.generate_points() def handle_commands(self) -> None: all_points: list[np.ndarray] = [] - last_move = None + last_move: np.ndarray = None curve_start = None last_true_move = None - def move_pen(pt, *, true_move: bool = False): + def move_pen(pt: np.ndarray, *, true_move: bool = False) -> None: nonlocal last_move, curve_start, last_true_move last_move = pt if curve_start is None: @@ -523,17 +528,19 @@ def move_pen(pt, *, true_move: bool = False): if self.n_points_per_curve == 4: - def add_cubic(start, cp1, cp2, end): + def add_cubic( + start: np.ndarray, cp1: np.ndarray, cp2: np.ndarray, end: np.ndarray + ) -> None: nonlocal all_points assert len(all_points) % 4 == 0, len(all_points) all_points += [start, cp1, cp2, end] move_pen(end) - def add_quad(start, cp, end): + def add_quad(start: np.ndarray, cp: np.ndarray, end: np.ndarray) -> None: add_cubic(start, (start + cp + cp) / 3, (cp + cp + end) / 3, end) move_pen(end) - def add_line(start, end): + def add_line(start: np.ndarray, end: np.ndarray) -> None: add_cubic( start, (start + start + end) / 3, (start + end + end) / 3, end ) @@ -541,7 +548,9 @@ def add_line(start, end): else: - def add_cubic(start, cp1, cp2, end): + def add_cubic( + start: np.ndarray, cp1: np.ndarray, cp2: np.ndarray, end: np.ndarray + ) -> None: nonlocal all_points assert len(all_points) % 3 == 0, len(all_points) two_quads = get_quadratic_approximation_of_cubic( @@ -554,13 +563,13 @@ def add_cubic(start, cp1, cp2, end): all_points += two_quads[3:].tolist() move_pen(end) - def add_quad(start, cp, end): + def add_quad(start: np.ndarray, cp: np.ndarray, end: np.ndarray) -> None: nonlocal all_points assert len(all_points) % 3 == 0, len(all_points) all_points += [start, cp, end] move_pen(end) - def add_line(start, end): + def add_line(start: np.ndarray, end: np.ndarray) -> None: add_quad(start, (start + end) / 2, end) move_pen(end) diff --git a/manim/mobject/table.py b/manim/mobject/table.py index 9d65263dfd..432fa1a428 100644 --- a/manim/mobject/table.py +++ b/manim/mobject/table.py @@ -65,8 +65,7 @@ def construct(self): import itertools as it -from collections.abc import Iterable, Sequence -from typing import Callable +from collections.abc import Callable, Iterable, Sequence from manim.mobject.geometry.line import Line from manim.mobject.geometry.polygram import Polygon diff --git a/manim/mobject/text/numbers.py b/manim/mobject/text/numbers.py index 1c74cf5f0a..9e4accb5ef 100644 --- a/manim/mobject/text/numbers.py +++ b/manim/mobject/text/numbers.py @@ -4,10 +4,10 @@ __all__ = ["DecimalNumber", "Integer", "Variable"] -from collections.abc import Sequence from typing import Any import numpy as np +from typing_extensions import Self from manim import config from manim.constants import * @@ -16,10 +16,9 @@ from manim.mobject.text.text_mobject import Text from manim.mobject.types.vectorized_mobject import VMobject from manim.mobject.value_tracker import ValueTracker +from manim.typing import Vector3DLike -string_to_mob_map = {} - -__all__ = ["DecimalNumber", "Integer", "Variable"] +string_to_mob_map: dict[str, VMobject] = {} class DecimalNumber(VMobject, metaclass=ConvertToOpenGL): @@ -86,7 +85,7 @@ def __init__( self, number: float = 0, num_decimal_places: int = 2, - mob_class: VMobject = MathTex, + mob_class: type[SingleStringMathTex] = MathTex, include_sign: bool = False, group_with_commas: bool = True, digit_buff_per_font_unit: float = 0.001, @@ -94,13 +93,13 @@ def __init__( unit: str | None = None, # Aligned to bottom unless it starts with "^" unit_buff_per_font_unit: float = 0, include_background_rectangle: bool = False, - edge_to_fix: Sequence[float] = LEFT, + edge_to_fix: Vector3DLike = LEFT, font_size: float = DEFAULT_FONT_SIZE, stroke_width: float = 0, fill_opacity: float = 1.0, - **kwargs, + **kwargs: Any, ): - super().__init__(**kwargs, stroke_width=stroke_width) + super().__init__(**kwargs, fill_opacity=fill_opacity, stroke_width=stroke_width) self.number = number self.num_decimal_places = num_decimal_places self.include_sign = include_sign @@ -137,12 +136,13 @@ def __init__( self.init_colors() @property - def font_size(self): + def font_size(self) -> float: """The font size of the tex mobject.""" - return self.height / self.initial_height * self._font_size + return_value: float = self.height / self.initial_height * self._font_size + return return_value @font_size.setter - def font_size(self, font_val): + def font_size(self, font_val: float) -> None: if font_val <= 0: raise ValueError("font_size must be greater than 0.") elif self.height > 0: @@ -153,7 +153,7 @@ def font_size(self, font_val): # font_size does not depend on current size. self.scale(font_val / self.font_size) - def _set_submobjects_from_number(self, number): + def _set_submobjects_from_number(self, number: float) -> None: self.number = number self.submobjects = [] @@ -197,12 +197,12 @@ def _set_submobjects_from_number(self, number): self.unit_sign.align_to(self, UP) # track the initial height to enable scaling via font_size - self.initial_height = self.height + self.initial_height: float = self.height if self.include_background_rectangle: self.add_background_rectangle() - def _get_num_string(self, number): + def _get_num_string(self, number: float | complex) -> str: if isinstance(number, complex): formatter = self._get_complex_formatter() else: @@ -215,17 +215,22 @@ def _get_num_string(self, number): return num_string - def _string_to_mob(self, string: str, mob_class: VMobject | None = None, **kwargs): + def _string_to_mob( + self, + string: str, + mob_class: type[SingleStringMathTex] | None = None, + **kwargs: Any, + ) -> VMobject: if mob_class is None: mob_class = self.mob_class if string not in string_to_mob_map: string_to_mob_map[string] = mob_class(string, **kwargs) mob = string_to_mob_map[string].copy() - mob.font_size = self._font_size + mob.font_size = self._font_size # type: ignore[attr-defined] return mob - def _get_formatter(self, **kwargs): + def _get_formatter(self, **kwargs: Any) -> str: """ Configuration is based first off instance attributes, but overwritten by any kew word argument. Relevant @@ -258,7 +263,7 @@ def _get_formatter(self, **kwargs): ], ) - def _get_complex_formatter(self): + def _get_complex_formatter(self) -> str: return "".join( [ self._get_formatter(field_name="0.real"), @@ -267,7 +272,7 @@ def _get_complex_formatter(self): ], ) - def set_value(self, number: float): + def set_value(self, number: float) -> Self: """Set the value of the :class:`~.DecimalNumber` to a new number. Parameters @@ -304,10 +309,10 @@ def set_value(self, number: float): self.init_colors() return self - def get_value(self): + def get_value(self) -> float: return self.number - def increment_value(self, delta_t=1): + def increment_value(self, delta_t: float = 1) -> None: self.set_value(self.get_value() + delta_t) @@ -333,7 +338,7 @@ def __init__( ) -> None: super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs) - def get_value(self): + def get_value(self) -> int: return int(np.round(super().get_value())) @@ -444,9 +449,9 @@ def __init__( self, var: float, label: str | Tex | MathTex | Text | SingleStringMathTex, - var_type: DecimalNumber | Integer = DecimalNumber, + var_type: type[DecimalNumber | Integer] = DecimalNumber, num_decimal_places: int = 2, - **kwargs, + **kwargs: Any, ): self.label = MathTex(label) if isinstance(label, str) else label equals = MathTex("=").next_to(self.label, RIGHT) diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 26334a60d9..b219694f7c 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -26,9 +26,12 @@ import itertools as it import operator as op import re -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from functools import reduce from textwrap import dedent +from typing import Any + +from typing_extensions import Self from manim import config, logger from manim.constants import * @@ -38,8 +41,6 @@ from manim.utils.tex import TexTemplate from manim.utils.tex_file_writing import tex_to_svg_file -tex_string_to_mob_map = {} - class SingleStringMathTex(SVGMobject): """Elementary building block for rendering text with LaTeX. @@ -59,11 +60,11 @@ def __init__( should_center: bool = True, height: float | None = None, organize_left_to_right: bool = False, - tex_environment: str = "align*", + tex_environment: str | None = "align*", tex_template: TexTemplate | None = None, font_size: float = DEFAULT_FONT_SIZE, color: ParsableManimColor | None = None, - **kwargs, + **kwargs: Any, ): if color is None: color = VMobject().color @@ -73,9 +74,8 @@ def __init__( self.tex_environment = tex_environment if tex_template is None: tex_template = config["tex_template"] - self.tex_template = tex_template + self.tex_template: TexTemplate = tex_template - assert isinstance(tex_string, str) self.tex_string = tex_string file_name = tex_to_svg_file( self._get_modified_expression(tex_string), @@ -105,16 +105,16 @@ def __init__( if self.organize_left_to_right: self._organize_submobjects_left_to_right() - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({repr(self.tex_string)})" @property - def font_size(self): + def font_size(self) -> float: """The font size of the tex mobject.""" return self.height / self.initial_height / SCALE_FACTOR_PER_FONT_POINT @font_size.setter - def font_size(self, font_val): + def font_size(self, font_val: float) -> None: if font_val <= 0: raise ValueError("font_size must be greater than 0.") elif self.height > 0: @@ -125,13 +125,13 @@ def font_size(self, font_val): # font_size does not depend on current size. self.scale(font_val / self.font_size) - def _get_modified_expression(self, tex_string): + def _get_modified_expression(self, tex_string: str) -> str: result = tex_string result = result.strip() result = self._modify_special_strings(result) return result - def _modify_special_strings(self, tex): + def _modify_special_strings(self, tex: str) -> str: tex = tex.strip() should_add_filler = reduce( op.or_, @@ -184,7 +184,7 @@ def _modify_special_strings(self, tex): tex = "" return tex - def _remove_stray_braces(self, tex): + def _remove_stray_braces(self, tex: str) -> str: r""" Makes :class:`~.MathTex` resilient to unmatched braces. @@ -202,14 +202,14 @@ def _remove_stray_braces(self, tex): num_rights += 1 return tex - def _organize_submobjects_left_to_right(self): + def _organize_submobjects_left_to_right(self) -> Self: self.sort(lambda p: p[0]) return self - def get_tex_string(self): + def get_tex_string(self) -> str: return self.tex_string - def init_colors(self, propagate_colors=True): + def init_colors(self, propagate_colors: bool = True) -> Self: for submobject in self.submobjects: # needed to preserve original (non-black) # TeX colors of individual submobjects @@ -220,6 +220,7 @@ def init_colors(self, propagate_colors=True): submobject.init_colors() elif config.renderer == RendererType.CAIRO: submobject.init_colors(propagate_colors=propagate_colors) + return self class MathTex(SingleStringMathTex): @@ -255,21 +256,22 @@ def construct(self): def __init__( self, - *tex_strings, + *tex_strings: str, arg_separator: str = " ", substrings_to_isolate: Iterable[str] | None = None, - tex_to_color_map: dict[str, ManimColor] = None, - tex_environment: str = "align*", - **kwargs, + tex_to_color_map: dict[str, ManimColor] | None = None, + tex_environment: str | None = "align*", + **kwargs: Any, ): self.tex_template = kwargs.pop("tex_template", config["tex_template"]) self.arg_separator = arg_separator self.substrings_to_isolate = ( [] if substrings_to_isolate is None else substrings_to_isolate ) - self.tex_to_color_map = tex_to_color_map - if self.tex_to_color_map is None: - self.tex_to_color_map = {} + if tex_to_color_map is None: + self.tex_to_color_map: dict[str, ManimColor] = {} + else: + self.tex_to_color_map = tex_to_color_map self.tex_environment = tex_environment self.brace_notation_split_occurred = False self.tex_strings = self._break_up_tex_strings(tex_strings) @@ -301,12 +303,14 @@ def __init__( if self.organize_left_to_right: self._organize_submobjects_left_to_right() - def _break_up_tex_strings(self, tex_strings): + def _break_up_tex_strings(self, tex_strings: Sequence[str]) -> list[str]: # Separate out anything surrounded in double braces pre_split_length = len(tex_strings) - tex_strings = [re.split("{{(.*?)}}", str(t)) for t in tex_strings] - tex_strings = sum(tex_strings, []) - if len(tex_strings) > pre_split_length: + tex_strings_brace_splitted = [ + re.split("{{(.*?)}}", str(t)) for t in tex_strings + ] + tex_strings_combined = sum(tex_strings_brace_splitted, []) + if len(tex_strings_combined) > pre_split_length: self.brace_notation_split_occurred = True # Separate out any strings specified in the isolate @@ -324,19 +328,19 @@ def _break_up_tex_strings(self, tex_strings): pattern = "|".join(patterns) if pattern: pieces = [] - for s in tex_strings: + for s in tex_strings_combined: pieces.extend(re.split(pattern, s)) else: - pieces = tex_strings + pieces = tex_strings_combined return [p for p in pieces if p] - def _break_up_by_substrings(self): + def _break_up_by_substrings(self) -> Self: """ Reorganize existing submobjects one layer deeper based on the structure of tex_strings (as a list of tex_strings) """ - new_submobjects = [] + new_submobjects: list[VMobject] = [] curr_index = 0 for tex_string in self.tex_strings: sub_tex_mob = SingleStringMathTex( @@ -358,8 +362,10 @@ def _break_up_by_substrings(self): self.submobjects = new_submobjects return self - def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): - def test(tex1, tex2): + def get_parts_by_tex( + self, tex: str, substring: bool = True, case_sensitive: bool = True + ) -> VGroup: + def test(tex1: str, tex2: str) -> bool: if not case_sensitive: tex1 = tex1.lower() tex2 = tex2.lower() @@ -370,19 +376,25 @@ def test(tex1, tex2): return VGroup(*(m for m in self.submobjects if test(tex, m.get_tex_string()))) - def get_part_by_tex(self, tex, **kwargs): + def get_part_by_tex(self, tex: str, **kwargs: Any) -> MathTex | None: all_parts = self.get_parts_by_tex(tex, **kwargs) return all_parts[0] if all_parts else None - def set_color_by_tex(self, tex, color, **kwargs): + def set_color_by_tex( + self, tex: str, color: ParsableManimColor, **kwargs: Any + ) -> Self: parts_to_color = self.get_parts_by_tex(tex, **kwargs) for part in parts_to_color: part.set_color(color) return self def set_opacity_by_tex( - self, tex: str, opacity: float = 0.5, remaining_opacity: float = None, **kwargs - ): + self, + tex: str, + opacity: float = 0.5, + remaining_opacity: float | None = None, + **kwargs: Any, + ) -> Self: """ Sets the opacity of the tex specified. If 'remaining_opacity' is specified, then the remaining tex will be set to that opacity. @@ -403,7 +415,9 @@ def set_opacity_by_tex( part.set_opacity(opacity) return self - def set_color_by_tex_to_color_map(self, texs_to_color_map, **kwargs): + def set_color_by_tex_to_color_map( + self, texs_to_color_map: dict[str, ManimColor], **kwargs: Any + ) -> Self: for texs, color in list(texs_to_color_map.items()): try: # If the given key behaves like tex_strings @@ -415,17 +429,19 @@ def set_color_by_tex_to_color_map(self, texs_to_color_map, **kwargs): self.set_color_by_tex(tex, color, **kwargs) return self - def index_of_part(self, part): + def index_of_part(self, part: MathTex) -> int: split_self = self.split() if part not in split_self: raise ValueError("Trying to get index of part not in MathTex") return split_self.index(part) - def index_of_part_by_tex(self, tex, **kwargs): + def index_of_part_by_tex(self, tex: str, **kwargs: Any) -> int: part = self.get_part_by_tex(tex, **kwargs) + if part is None: + return -1 return self.index_of_part(part) - def sort_alphabetically(self): + def sort_alphabetically(self) -> None: self.submobjects.sort(key=lambda m: m.get_tex_string()) @@ -447,7 +463,11 @@ class Tex(MathTex): """ def __init__( - self, *tex_strings, arg_separator="", tex_environment="center", **kwargs + self, + *tex_strings: str, + arg_separator: str = "", + tex_environment: str | None = "center", + **kwargs: Any, ): super().__init__( *tex_strings, @@ -477,11 +497,11 @@ def construct(self): def __init__( self, - *items, - buff=MED_LARGE_BUFF, - dot_scale_factor=2, - tex_environment=None, - **kwargs, + *items: str, + buff: float = MED_LARGE_BUFF, + dot_scale_factor: float = 2, + tex_environment: str | None = None, + **kwargs: Any, ): self.buff = buff self.dot_scale_factor = dot_scale_factor @@ -496,12 +516,12 @@ def __init__( part.add_to_back(dot) self.arrange(DOWN, aligned_edge=LEFT, buff=self.buff) - def fade_all_but(self, index_or_string, opacity=0.5): + def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None: arg = index_or_string if isinstance(arg, str): part = self.get_part_by_tex(arg) elif isinstance(arg, int): - part = self.submobjects[arg] + part = self.submobjects[arg] # type: ignore[assignment] else: raise TypeError(f"Expected int or string, got {arg}") for other_part in self.submobjects: @@ -531,11 +551,11 @@ def construct(self): def __init__( self, - *text_parts, - include_underline=True, - match_underline_width_to_text=False, - underline_buff=MED_SMALL_BUFF, - **kwargs, + *text_parts: str, + include_underline: bool = True, + match_underline_width_to_text: bool = False, + underline_buff: float = MED_SMALL_BUFF, + **kwargs: Any, ): self.include_underline = include_underline self.match_underline_width_to_text = match_underline_width_to_text diff --git a/manim/mobject/text/text_mobject.py b/manim/mobject/text/text_mobject.py index 5f20dff2a4..c35f874ed5 100644 --- a/manim/mobject/text/text_mobject.py +++ b/manim/mobject/text/text_mobject.py @@ -678,6 +678,7 @@ def _text2hash(self, color: ManimColor): settings += str(self.t2f) + str(self.t2s) + str(self.t2w) + str(self.t2c) settings += str(self.line_spacing) + str(self._font_size) settings += str(self.disable_ligatures) + settings += str(self.gradient) id_str = self.text + settings hasher = hashlib.sha256() hasher.update(id_str.encode()) @@ -1561,7 +1562,7 @@ def register_font(font_file: str | Path): logger.debug("Found file at %s", file_path.absolute()) break else: - error = f"Can't find {font_file}.Tried these : {possible_paths}" + error = f"Can't find {font_file}. Checked paths: {possible_paths}" raise FileNotFoundError(error) try: diff --git a/manim/mobject/three_d/polyhedra.py b/manim/mobject/three_d/polyhedra.py index 8046f6066c..1f72873f7b 100644 --- a/manim/mobject/three_d/polyhedra.py +++ b/manim/mobject/three_d/polyhedra.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Hashable +from typing import TYPE_CHECKING, Any import numpy as np @@ -14,7 +15,7 @@ if TYPE_CHECKING: from manim.mobject.mobject import Mobject - from manim.typing import Point3D + from manim.typing import Point3D, Point3DLike_Array __all__ = [ "Polyhedron", @@ -96,10 +97,10 @@ def construct(self): def __init__( self, - vertex_coords: list[list[float] | np.ndarray], + vertex_coords: Point3DLike_Array, faces_list: list[list[int]], faces_config: dict[str, str | int | float | bool] = {}, - graph_config: dict[str, str | int | float | bool] = {}, + graph_config: dict[str, Any] = {}, ): super().__init__() self.faces_config = dict( @@ -116,7 +117,7 @@ def __init__( ) self.vertex_coords = vertex_coords self.vertex_indices = list(range(len(self.vertex_coords))) - self.layout = dict(enumerate(self.vertex_coords)) + self.layout: dict[Hashable, Any] = dict(enumerate(self.vertex_coords)) self.faces_list = faces_list self.face_coords = [[self.layout[j] for j in i] for i in faces_list] self.edges = self.get_edges(self.faces_list) @@ -129,14 +130,14 @@ def __init__( def get_edges(self, faces_list: list[list[int]]) -> list[tuple[int, int]]: """Creates list of cyclic pairwise tuples.""" - edges = [] + edges: list[tuple[int, int]] = [] for face in faces_list: edges += zip(face, face[1:] + face[:1]) return edges def create_faces( self, - face_coords: list[list[list | np.ndarray]], + face_coords: Point3DLike_Array, ) -> VGroup: """Creates VGroup of faces from a list of face coordinates.""" face_group = VGroup() @@ -144,12 +145,12 @@ def create_faces( face_group.add(Polygon(*face, **self.faces_config)) return face_group - def update_faces(self, m: Mobject): + def update_faces(self, m: Mobject) -> None: face_coords = self.extract_face_coords() new_faces = self.create_faces(face_coords) self.faces.match_points(new_faces) - def extract_face_coords(self) -> list[list[np.ndarray]]: + def extract_face_coords(self) -> Point3DLike_Array: """Extracts the coordinates of the vertices in the graph. Used for updating faces. """ @@ -181,7 +182,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit = edge_length * np.sqrt(2) / 4 super().__init__( vertex_coords=[ @@ -216,7 +217,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit = edge_length * np.sqrt(2) / 2 super().__init__( vertex_coords=[ @@ -262,7 +263,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit_a = edge_length * ((1 + np.sqrt(5)) / 4) unit_b = edge_length * (1 / 2) super().__init__( @@ -327,7 +328,7 @@ def construct(self): self.add(obj) """ - def __init__(self, edge_length: float = 1, **kwargs): + def __init__(self, edge_length: float = 1, **kwargs: Any): unit_a = edge_length * ((1 + np.sqrt(5)) / 4) unit_b = edge_length * ((3 + np.sqrt(5)) / 4) unit_c = edge_length * (1 / 2) @@ -427,7 +428,7 @@ def construct(self): self.add(dots) """ - def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs): + def __init__(self, *points: Point3D, tolerance: float = 1e-5, **kwargs: Any): # Build Convex Hull array = np.array(points) hull = QuickHull(tolerance) diff --git a/manim/mobject/three_d/three_d_utils.py b/manim/mobject/three_d/three_d_utils.py index 0a9ccb8a2d..997efea335 100644 --- a/manim/mobject/three_d/three_d_utils.py +++ b/manim/mobject/three_d/three_d_utils.py @@ -24,35 +24,39 @@ if TYPE_CHECKING: from manim.typing import Point3D, Vector3D + from ..types.vectorized_mobject import VMobject -def get_3d_vmob_gradient_start_and_end_points(vmob) -> tuple[Point3D, Point3D]: + +def get_3d_vmob_gradient_start_and_end_points( + vmob: VMobject, +) -> tuple[Point3D, Point3D]: return ( get_3d_vmob_start_corner(vmob), get_3d_vmob_end_corner(vmob), ) -def get_3d_vmob_start_corner_index(vmob) -> Literal[0]: +def get_3d_vmob_start_corner_index(vmob: VMobject) -> Literal[0]: return 0 -def get_3d_vmob_end_corner_index(vmob) -> int: +def get_3d_vmob_end_corner_index(vmob: VMobject) -> int: return ((len(vmob.points) - 1) // 6) * 3 -def get_3d_vmob_start_corner(vmob) -> Point3D: +def get_3d_vmob_start_corner(vmob: VMobject) -> Point3D: if vmob.get_num_points() == 0: return np.array(ORIGIN) return vmob.points[get_3d_vmob_start_corner_index(vmob)] -def get_3d_vmob_end_corner(vmob) -> Point3D: +def get_3d_vmob_end_corner(vmob: VMobject) -> Point3D: if vmob.get_num_points() == 0: return np.array(ORIGIN) return vmob.points[get_3d_vmob_end_corner_index(vmob)] -def get_3d_vmob_unit_normal(vmob, point_index: int) -> Vector3D: +def get_3d_vmob_unit_normal(vmob: VMobject, point_index: int) -> Vector3D: n_points = vmob.get_num_points() if len(vmob.get_anchors()) <= 2: return np.array(UP) @@ -68,9 +72,9 @@ def get_3d_vmob_unit_normal(vmob, point_index: int) -> Vector3D: return unit_normal -def get_3d_vmob_start_corner_unit_normal(vmob) -> Vector3D: +def get_3d_vmob_start_corner_unit_normal(vmob: VMobject) -> Vector3D: return get_3d_vmob_unit_normal(vmob, get_3d_vmob_start_corner_index(vmob)) -def get_3d_vmob_end_corner_unit_normal(vmob) -> Vector3D: +def get_3d_vmob_end_corner_unit_normal(vmob: VMobject) -> Vector3D: return get_3d_vmob_unit_normal(vmob, get_3d_vmob_end_corner_index(vmob)) diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index 5732ebb98c..161b86dc8f 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -2,9 +2,6 @@ from __future__ import annotations -from manim.typing import Point3DLike, Vector3D -from manim.utils.color import BLUE, BLUE_D, BLUE_E, LIGHT_GREY, WHITE, interpolate_color - __all__ = [ "ThreeDVMobject", "Surface", @@ -19,8 +16,8 @@ "Torus", ] -from collections.abc import Iterable, Sequence -from typing import Any, Callable +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any import numpy as np from typing_extensions import Self @@ -34,12 +31,21 @@ from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject from manim.utils.color import ( + BLUE, + BLUE_D, + BLUE_E, + LIGHT_GREY, + WHITE, ManimColor, ParsableManimColor, + interpolate_color, ) from manim.utils.iterables import tuplify from manim.utils.space_ops import normalize, perpendicular_bisector, z_to_vector +if TYPE_CHECKING: + from manim.typing import Point3D, Point3DLike, Vector3DLike + class ThreeDVMobject(VMobject, metaclass=ConvertToOpenGL): def __init__(self, shade_in_3d: bool = True, **kwargs): @@ -116,19 +122,21 @@ def __init__( ) -> None: self.u_range = u_range self.v_range = v_range - super().__init__(**kwargs) + super().__init__( + fill_color=fill_color, + fill_opacity=fill_opacity, + stroke_color=stroke_color, + stroke_width=stroke_width, + **kwargs, + ) self.resolution = resolution self.surface_piece_config = surface_piece_config - self.fill_color: ManimColor = ManimColor(fill_color) - self.fill_opacity = fill_opacity if checkerboard_colors: self.checkerboard_colors: list[ManimColor] = [ ManimColor(x) for x in checkerboard_colors ] else: self.checkerboard_colors = checkerboard_colors - self.stroke_color: ManimColor = ManimColor(stroke_color) - self.stroke_width = stroke_width self.should_make_jagged = should_make_jagged self.pre_function_handle_to_anchor_scale_factor = ( pre_function_handle_to_anchor_scale_factor @@ -510,6 +518,7 @@ def generate_points(self) -> None: face = Square( side_length=self.side_length, shade_in_3d=True, + joint_type=LineJointType.BEVEL, ) face.flip() face.shift(self.side_length * OUT / 2.0) @@ -517,7 +526,8 @@ def generate_points(self) -> None: self.add(face) - init_points = generate_points + def init_points(self) -> None: + self.generate_points() class Prism(Cube): @@ -967,8 +977,8 @@ def set_start_and_end_attrs( def pointify( self, mob_or_point: Mobject | Point3DLike, - direction: Vector3D = None, - ) -> np.ndarray: + direction: Vector3DLike | None = None, + ) -> Point3D: """Gets a point representing the center of the :class:`Mobjects <.Mobject>`. Parameters @@ -1015,7 +1025,7 @@ def get_end(self) -> np.ndarray: def parallel_to( cls, line: Line3D, - point: Vector3D = ORIGIN, + point: Point3DLike = ORIGIN, length: float = 5, **kwargs, ) -> Line3D: @@ -1051,11 +1061,11 @@ def construct(self): line2 = Line3D.parallel_to(line1, color=YELLOW) self.add(ax, line1, line2) """ - point = np.array(point) + np_point = np.asarray(point) vect = normalize(line.vect) return cls( - point + vect * length / 2, - point - vect * length / 2, + np_point + vect * length / 2, + np_point - vect * length / 2, **kwargs, ) @@ -1063,7 +1073,7 @@ def construct(self): def perpendicular_to( cls, line: Line3D, - point: Vector3D = ORIGIN, + point: Vector3DLike = ORIGIN, length: float = 5, **kwargs, ) -> Line3D: @@ -1099,17 +1109,17 @@ def construct(self): line2 = Line3D.perpendicular_to(line1, color=BLUE) self.add(ax, line1, line2) """ - point = np.array(point) + np_point = np.asarray(point) - norm = np.cross(line.vect, point - line.start) + norm = np.cross(line.vect, np_point - line.start) if all(np.linalg.norm(norm) == np.zeros(3)): raise ValueError("Could not find the perpendicular.") start, end = perpendicular_bisector([line.start, line.end], norm) vect = normalize(end - start) return cls( - point + vect * length / 2, - point - vect * length / 2, + np_point + vect * length / 2, + np_point - vect * length / 2, **kwargs, ) diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index 56029f941e..d2cb1a3f0e 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -5,7 +5,7 @@ __all__ = ["AbstractImageMobject", "ImageMobject", "ImageMobjectFromCamera"] import pathlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np from PIL import Image @@ -14,6 +14,7 @@ from manim.mobject.geometry.shape_matchers import SurroundingRectangle from ... import config +from ...camera.moving_camera import MovingCamera from ...constants import * from ...mobject.mobject import Mobject from ...utils.bezier import interpolate @@ -23,12 +24,12 @@ __all__ = ["ImageMobject", "ImageMobjectFromCamera"] if TYPE_CHECKING: - from typing import Any - import numpy.typing as npt from typing_extensions import Self - from manim.typing import StrPath + from manim.typing import PixelArray, StrPath + + from ...camera.moving_camera import MovingCamera class AbstractImageMobject(Mobject): @@ -57,7 +58,7 @@ def __init__( self.set_resampling_algorithm(resampling_algorithm) super().__init__(**kwargs) - def get_pixel_array(self) -> None: + def get_pixel_array(self) -> PixelArray: raise NotImplementedError() def set_color(self, color, alpha=None, family=True): @@ -205,6 +206,7 @@ def __init__( self.pixel_array[:, :, :3] = ( np.iinfo(self.pixel_array_dtype).max - self.pixel_array[:, :, :3] ) + self.orig_alpha_pixel_array = self.pixel_array[:, :, 3].copy() super().__init__(scale_to_resolution, **kwargs) def get_pixel_array(self): @@ -230,8 +232,7 @@ def set_opacity(self, alpha: float) -> Self: The alpha value of the object, 1 being opaque and 0 being transparent. """ - self.pixel_array[:, :, 3] = int(255 * alpha) - self.fill_opacity = alpha + self.pixel_array[:, :, 3] = self.orig_alpha_pixel_array * alpha self.stroke_opacity = alpha return self @@ -303,7 +304,7 @@ def get_style(self) -> dict[str, Any]: class ImageMobjectFromCamera(AbstractImageMobject): def __init__( self, - camera, + camera: MovingCamera, default_display_frame_config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: diff --git a/manim/mobject/types/point_cloud_mobject.py b/manim/mobject/types/point_cloud_mobject.py index f5953aab8c..f820f49bfc 100644 --- a/manim/mobject/types/point_cloud_mobject.py +++ b/manim/mobject/types/point_cloud_mobject.py @@ -4,7 +4,8 @@ __all__ = ["PMobject", "Mobject1D", "Mobject2D", "PGroup", "PointCloudDot", "Point"] -from typing import TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import numpy as np @@ -29,13 +30,10 @@ __all__ = ["PMobject", "Mobject1D", "Mobject2D", "PGroup", "PointCloudDot", "Point"] if TYPE_CHECKING: - from collections.abc import Callable - from typing import Any - import numpy.typing as npt from typing_extensions import Self - from manim.typing import ManimFloat, Point3DLike, Vector3D + from manim.typing import ManimFloat, Point3DLike class PMobject(Mobject, metaclass=ConvertToOpenGL): @@ -349,7 +347,7 @@ def construct(self): def __init__( self, - center: Vector3D = ORIGIN, + center: Point3DLike = ORIGIN, radius: float = 2.0, stroke_width: int = 2, density: int = DEFAULT_POINT_DENSITY_1D, @@ -406,7 +404,7 @@ def construct(self): """ def __init__( - self, location: Vector3D = ORIGIN, color: ManimColor = BLACK, **kwargs: Any + self, location: Point3DLike = ORIGIN, color: ManimColor = BLACK, **kwargs: Any ) -> None: self.location = location super().__init__(color=color, **kwargs) diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index f5c97e448e..f937d8ab58 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -11,11 +11,10 @@ "DashedVMobject", ] - import itertools as it import sys -from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Callable, Literal +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal import numpy as np from PIL.Image import Image @@ -48,8 +47,6 @@ from manim.utils.space_ops import rotate_vector, shoelace_direction if TYPE_CHECKING: - from typing import Any - import numpy.typing as npt from typing_extensions import Self @@ -66,6 +63,7 @@ Point3DLike_Array, RGBA_Array_Float, Vector3D, + Vector3DLike, Zeros, ) @@ -117,7 +115,7 @@ def __init__( background_stroke_width: float = 0, sheen_factor: float = 0.0, joint_type: LineJointType | None = None, - sheen_direction: Vector3D = UL, + sheen_direction: Vector3DLike = UL, close_new_points: bool = False, pre_function_handle_to_anchor_scale_factor: float = 0.01, make_smooth_after_applying_functions: bool = False, @@ -142,7 +140,7 @@ def __init__( self.joint_type: LineJointType = ( LineJointType.AUTO if joint_type is None else joint_type ) - self.sheen_direction: Vector3D = sheen_direction + self.sheen_direction = sheen_direction self.close_new_points: bool = close_new_points self.pre_function_handle_to_anchor_scale_factor: float = ( pre_function_handle_to_anchor_scale_factor @@ -395,7 +393,7 @@ def set_style( background_stroke_width: float | None = None, background_stroke_opacity: float | None = None, sheen_factor: float | None = None, - sheen_direction: Vector3D | None = None, + sheen_direction: Vector3DLike | None = None, background_image: Image | str | None = None, family: bool = True, ) -> Self: @@ -620,7 +618,7 @@ def get_color(self) -> ManimColor: color = property(get_color, set_color) - def set_sheen_direction(self, direction: Vector3D, family: bool = True) -> Self: + def set_sheen_direction(self, direction: Vector3DLike, family: bool = True) -> Self: """Sets the direction of the applied sheen. Parameters @@ -639,16 +637,16 @@ def set_sheen_direction(self, direction: Vector3D, family: bool = True) -> Self: :meth:`~.VMobject.set_sheen` :meth:`~.VMobject.rotate_sheen_direction` """ - direction = np.array(direction) + direction_copy = np.array(direction) if family: for submob in self.get_family(): - submob.sheen_direction = direction + submob.sheen_direction = direction_copy.copy() else: - self.sheen_direction: Vector3D = direction + self.sheen_direction = direction_copy return self def rotate_sheen_direction( - self, angle: float, axis: Vector3D = OUT, family: bool = True + self, angle: float, axis: Vector3DLike = OUT, family: bool = True ) -> Self: """Rotates the direction of the applied sheen. @@ -681,7 +679,7 @@ def rotate_sheen_direction( return self def set_sheen( - self, factor: float, direction: Vector3D | None = None, family: bool = True + self, factor: float, direction: Vector3DLike | None = None, family: bool = True ) -> Self: """Applies a color gradient from a direction. @@ -1189,7 +1187,7 @@ def apply_function(self, function: MappingFunction) -> Self: def rotate( self, angle: float, - axis: Vector3D = OUT, + axis: Vector3DLike = OUT, about_point: Point3DLike | None = None, **kwargs, ) -> Self: @@ -1916,7 +1914,6 @@ def pointwise_become_partial( return self num_curves = vmobject.get_num_curves() if num_curves == 0: - self.clear_points() return self # The following two lines will compute which Bézier curves of the given Mobject must be processed. diff --git a/manim/mobject/value_tracker.py b/manim/mobject/value_tracker.py index 9d81035e89..a3648cd5ad 100644 --- a/manim/mobject/value_tracker.py +++ b/manim/mobject/value_tracker.py @@ -4,6 +4,7 @@ __all__ = ["ValueTracker", "ComplexValueTracker"] +from typing import TYPE_CHECKING, Any import numpy as np @@ -11,6 +12,11 @@ from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.utils.paths import straight_path +if TYPE_CHECKING: + from typing_extensions import Self + + from manim.typing import PathFuncType + class ValueTracker(Mobject, metaclass=ConvertToOpenGL): """A mobject that can be used for tracking (real-valued) parameters. @@ -69,69 +75,131 @@ def construct(self): """ - def __init__(self, value=0, **kwargs): + def __init__(self, value: float = 0, **kwargs: Any) -> None: super().__init__(**kwargs) self.set(points=np.zeros((1, 3))) self.set_value(value) def get_value(self) -> float: """Get the current value of this ValueTracker.""" - return self.points[0, 0] + value: float = self.points[0, 0] + return value - def set_value(self, value: float): - """Sets a new scalar value to the ValueTracker""" + def set_value(self, value: float) -> Self: + """Sets a new scalar value to the ValueTracker.""" self.points[0, 0] = value return self - def increment_value(self, d_value: float): - """Increments (adds) a scalar value to the ValueTracker""" + def increment_value(self, d_value: float) -> Self: + """Increments (adds) a scalar value to the ValueTracker.""" self.set_value(self.get_value() + d_value) return self - def __bool__(self): - """Return whether the value of this value tracker evaluates as true.""" + def __bool__(self) -> bool: + """Return whether the value of this ValueTracker evaluates as true.""" return bool(self.get_value()) - def __iadd__(self, d_value: float): - """adds ``+=`` syntax to increment the value of the ValueTracker""" + def __add__(self, d_value: float | Mobject) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value plus + ``d_value``. + """ + if isinstance(d_value, Mobject): + raise ValueError( + "Cannot increment ValueTracker by a Mobject. Please provide a scalar value." + ) + return ValueTracker(self.get_value() + d_value) + + def __iadd__(self, d_value: float | Mobject) -> Self: + """adds ``+=`` syntax to increment the value of the ValueTracker.""" + if isinstance(d_value, Mobject): + raise ValueError( + "Cannot increment ValueTracker by a Mobject. Please provide a scalar value." + ) self.increment_value(d_value) return self - def __ifloordiv__(self, d_value: float): - """Set the value of this value tracker to the floor division of the current value by ``d_value``.""" + def __floordiv__(self, d_value: float) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the floor division of the current + tracker's value by ``d_value``. + """ + return ValueTracker(self.get_value() // d_value) + + def __ifloordiv__(self, d_value: float) -> Self: + """Set the value of this ValueTracker to the floor division of the current value by ``d_value``.""" self.set_value(self.get_value() // d_value) return self - def __imod__(self, d_value: float): - """Set the value of this value tracker to the current value modulo ``d_value``.""" + def __mod__(self, d_value: float) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value + modulo ``d_value``. + """ + return ValueTracker(self.get_value() % d_value) + + def __imod__(self, d_value: float) -> Self: + """Set the value of this ValueTracker to the current value modulo ``d_value``.""" self.set_value(self.get_value() % d_value) return self - def __imul__(self, d_value: float): - """Set the value of this value tracker to the product of the current value and ``d_value``.""" + def __mul__(self, d_value: float) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value multiplied by + ``d_value``. + """ + return ValueTracker(self.get_value() * d_value) + + def __imul__(self, d_value: float) -> Self: + """Set the value of this ValueTracker to the product of the current value and ``d_value``.""" self.set_value(self.get_value() * d_value) return self - def __ipow__(self, d_value: float): - """Set the value of this value tracker to the current value raised to the power of ``d_value``.""" + def __pow__(self, d_value: float) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value raised to the + power of ``d_value``. + """ + return ValueTracker(self.get_value() ** d_value) + + def __ipow__(self, d_value: float) -> Self: + """Set the value of this ValueTracker to the current value raised to the power of ``d_value``.""" self.set_value(self.get_value() ** d_value) return self - def __isub__(self, d_value: float): - """adds ``-=`` syntax to decrement the value of the ValueTracker""" + def __sub__(self, d_value: float | Mobject) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value minus + ``d_value``. + """ + if isinstance(d_value, Mobject): + raise ValueError( + "Cannot decrement ValueTracker by a Mobject. Please provide a scalar value." + ) + return ValueTracker(self.get_value() - d_value) + + def __isub__(self, d_value: float | Mobject) -> Self: + """Adds ``-=`` syntax to decrement the value of the ValueTracker.""" + if isinstance(d_value, Mobject): + raise ValueError( + "Cannot decrement ValueTracker by a Mobject. Please provide a scalar value." + ) self.increment_value(-d_value) return self - def __itruediv__(self, d_value: float): - """Sets the value of this value tracker to the current value divided by ``d_value``.""" + def __truediv__(self, d_value: float) -> ValueTracker: + """Return a new :class:`ValueTracker` whose value is the current tracker's value + divided by ``d_value``. + """ + return ValueTracker(self.get_value() / d_value) + + def __itruediv__(self, d_value: float) -> Self: + """Sets the value of this ValueTracker to the current value divided by ``d_value``.""" self.set_value(self.get_value() / d_value) return self - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()): - """ - Turns self into an interpolation between mobject1 - and mobject2. - """ + def interpolate( + self, + mobject1: Mobject, + mobject2: Mobject, + alpha: float, + path_func: PathFuncType = straight_path(), + ) -> Self: + """Turns ``self`` into an interpolation between ``mobject1`` and ``mobject2``.""" self.set(points=path_func(mobject1.points, mobject2.points, alpha)) return self @@ -139,6 +207,8 @@ def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()): class ComplexValueTracker(ValueTracker): """Tracks a complex-valued parameter. + The value is internally stored as a points array [a, b, 0]. This can be accessed directly + to represent the value geometrically, see the usage example. When the value is set through :attr:`animate`, the value will take a straight path from the source point to the destination point. @@ -161,16 +231,12 @@ def construct(self): self.play(tracker.animate.set_value(tracker.get_value() / (-2 + 3j))) """ - def get_value(self): - """Get the current value of this value tracker as a complex number. - - The value is internally stored as a points array [a, b, 0]. This can be accessed directly - to represent the value geometrically, see the usage example. - """ + def get_value(self) -> complex: # type: ignore [override] + """Get the current value of this ComplexValueTracker as a complex number.""" return complex(*self.points[0, :2]) - def set_value(self, z): - """Sets a new complex value to the ComplexValueTracker""" - z = complex(z) + def set_value(self, value: complex | float) -> Self: + """Sets a new complex value to the ComplexValueTracker.""" + z = complex(value) self.points[0, :2] = (z.real, z.imag) return self diff --git a/manim/mobject/vector_field.py b/manim/mobject/vector_field.py index 28f5c6d26f..44b3be6c0d 100644 --- a/manim/mobject/vector_field.py +++ b/manim/mobject/vector_field.py @@ -10,9 +10,8 @@ import itertools as it import random -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from math import ceil, floor -from typing import Callable import numpy as np from PIL import Image diff --git a/manim/renderer/cairo_renderer.py b/manim/renderer/cairo_renderer.py index 7efd7b022f..351203a871 100644 --- a/manim/renderer/cairo_renderer.py +++ b/manim/renderer/cairo_renderer.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np @@ -13,9 +14,7 @@ from ..utils.exceptions import EndSceneEarlyException from ..utils.iterables import list_update -if typing.TYPE_CHECKING: - from typing import Any - +if TYPE_CHECKING: from manim.animation.animation import Animation from manim.scene.scene import Scene @@ -33,11 +32,11 @@ class CairoRenderer: def __init__( self, - file_writer_class=SceneFileWriter, - camera_class=None, - skip_animations=False, - **kwargs, - ): + file_writer_class: type[SceneFileWriter] = SceneFileWriter, + camera_class: type[Camera] | None = None, + skip_animations: bool = False, + **kwargs: Any, + ) -> None: # All of the following are set to EITHER the value passed via kwargs, # OR the value stored in the global config dict at the time of # _instance construction_. @@ -51,7 +50,7 @@ def __init__( self.time = 0 self.static_image = None - def init_scene(self, scene): + def init_scene(self, scene: Scene) -> None: self.file_writer: Any = self._file_writer_class( self, scene.__class__.__name__, @@ -119,12 +118,12 @@ def play( def update_frame( # TODO Description in Docstring self, - scene, - mobjects: typing.Iterable[Mobject] | None = None, + scene: Scene, + mobjects: Iterable[Mobject] | None = None, include_submobjects: bool = True, ignore_skipping: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Update the frame. Parameters @@ -214,8 +213,8 @@ def show_frame(self): def save_static_frame_data( self, scene: Scene, - static_mobjects: typing.Iterable[Mobject], - ) -> typing.Iterable[Mobject] | None: + static_mobjects: Iterable[Mobject], + ) -> Iterable[Mobject] | None: """Compute and save the static frame, that will be reused at each frame to avoid unnecessarily computing static mobjects. @@ -263,7 +262,7 @@ def update_skipping_status(self): self.skip_animations = True raise EndSceneEarlyException() - def scene_finished(self, scene): + def scene_finished(self, scene: Scene) -> None: # If no animations in scene, render an image instead if self.num_plays: self.file_writer.finish() diff --git a/manim/renderer/opengl_renderer.py b/manim/renderer/opengl_renderer.py index 7d4b6a4467..76ecdad54b 100644 --- a/manim/renderer/opengl_renderer.py +++ b/manim/renderer/opengl_renderer.py @@ -4,14 +4,18 @@ import itertools as it import time from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Any import moderngl import numpy as np from PIL import Image from manim import config, logger -from manim.mobject.opengl.opengl_mobject import OpenGLMobject, OpenGLPoint +from manim.mobject.opengl.opengl_mobject import ( + OpenGLMobject, + OpenGLPoint, + _AnimationBuilder, +) from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject from manim.utils.caching import handle_caching_play from manim.utils.color import color_to_rgba @@ -35,6 +39,15 @@ render_opengl_vectorized_mobject_stroke, ) +if TYPE_CHECKING: + from typing_extensions import Self + + from manim.animation.animation import Animation + from manim.mobject.mobject import Mobject + from manim.scene.scene import Scene + from manim.typing import Point3D + + __all__ = ["OpenGLCamera", "OpenGLRenderer"] @@ -102,7 +115,7 @@ def __init__( self.euler_angles = euler_angles self.refresh_rotation_matrix() - def get_position(self): + def get_position(self) -> Point3D: return self.model_matrix[:, 3][:3] def set_position(self, position): @@ -123,7 +136,7 @@ def init_points(self): self.set_height(self.frame_shape[1], stretch=True) self.move_to(self.center_point) - def to_default_state(self): + def to_default_state(self) -> Self: self.center() self.set_height(config["frame_height"]) self.set_width(config["frame_width"]) @@ -166,28 +179,28 @@ def set_euler_angles(self, theta=None, phi=None, gamma=None): self.refresh_rotation_matrix() return self - def set_theta(self, theta): + def set_theta(self, theta: float) -> Self: return self.set_euler_angles(theta=theta) - def set_phi(self, phi): + def set_phi(self, phi: float) -> Self: return self.set_euler_angles(phi=phi) - def set_gamma(self, gamma): + def set_gamma(self, gamma: float) -> Self: return self.set_euler_angles(gamma=gamma) - def increment_theta(self, dtheta): + def increment_theta(self, dtheta: float) -> Self: self.euler_angles[0] += dtheta self.refresh_rotation_matrix() return self - def increment_phi(self, dphi): + def increment_phi(self, dphi: float) -> Self: phi = self.euler_angles[1] new_phi = clip(phi + dphi, -PI / 2, PI / 2) self.euler_angles[1] = new_phi self.refresh_rotation_matrix() return self - def increment_gamma(self, dgamma): + def increment_gamma(self, dgamma: float) -> Self: self.euler_angles[2] += dgamma self.refresh_rotation_matrix() return self @@ -199,15 +212,15 @@ def get_center(self): # Assumes first point is at the center return self.points[0] - def get_width(self): + def get_width(self) -> float: points = self.points return points[2, 0] - points[1, 0] - def get_height(self): + def get_height(self) -> float: points = self.points return points[4, 1] - points[3, 1] - def get_focal_distance(self): + def get_focal_distance(self) -> float: return self.focal_distance * self.get_height() def interpolate(self, *args, **kwargs): @@ -236,12 +249,14 @@ def __init__( self.camera = OpenGLCamera() self.pressed_keys = set() + self.window = None + # Initialize texture map. self.path_to_texture_id = {} self.background_color = config["background_color"] - def init_scene(self, scene): + def init_scene(self, scene: Scene) -> None: self.partial_movie_files = [] self.file_writer: Any = self._file_writer_class( self, @@ -249,32 +264,31 @@ def init_scene(self, scene): ) self.scene = scene self.background_color = config["background_color"] - if not hasattr(self, "window"): - if self.should_create_window(): - from .opengl_renderer_window import Window + if self.should_create_window(): + from .opengl_renderer_window import Window - self.window = Window(self) - self.context = self.window.ctx - self.frame_buffer_object = self.context.detect_framebuffer() - else: - self.window = None - try: - self.context = moderngl.create_context(standalone=True) - except Exception: - self.context = moderngl.create_context( - standalone=True, - backend="egl", - ) - self.frame_buffer_object = self.get_frame_buffer_object(self.context, 0) - self.frame_buffer_object.use() - self.context.enable(moderngl.BLEND) - self.context.wireframe = config["enable_wireframe"] - self.context.blend_func = ( - moderngl.SRC_ALPHA, - moderngl.ONE_MINUS_SRC_ALPHA, - moderngl.ONE, - moderngl.ONE, - ) + self.window = Window(self) + self.context = self.window.ctx + self.frame_buffer_object = self.context.detect_framebuffer() + else: + # self.window = None + try: + self.context = moderngl.create_context(standalone=True) + except Exception: + self.context = moderngl.create_context( + standalone=True, + backend="egl", + ) + self.frame_buffer_object = self.get_frame_buffer_object(self.context, 0) + self.frame_buffer_object.use() + self.context.enable(moderngl.BLEND) + self.context.wireframe = config["enable_wireframe"] + self.context.blend_func = ( + moderngl.SRC_ALPHA, + moderngl.ONE_MINUS_SRC_ALPHA, + moderngl.ONE, + moderngl.ONE, + ) def should_create_window(self): if config["force_window"]: @@ -412,7 +426,12 @@ def update_skipping_status(self) -> None: raise EndSceneEarlyException() @handle_caching_play - def play(self, scene, *args, **kwargs): + def play( + self, + scene: Scene, + *args: Animation | Mobject | _AnimationBuilder, + **kwargs: Any, + ) -> None: # TODO: Handle data locking / unlocking. self.animation_start_time = time.time() self.file_writer.begin_animation(not self.skip_animations) @@ -440,11 +459,13 @@ def play(self, scene, *args, **kwargs): self.time += scene.duration self.num_plays += 1 - def clear_screen(self): + def clear_screen(self) -> None: self.frame_buffer_object.clear(*self.background_color) self.window.swap_buffers() - def render(self, scene, frame_offset, moving_mobjects): + def render( + self, scene: Scene, frame_offset, moving_mobjects: list[Mobject] + ) -> None: self.update_frame(scene) if self.skip_animations: @@ -566,7 +587,9 @@ def get_frame(self): # Returns offset from the bottom left corner in pixels. # top_left flag should be set to True when using a GUI framework # where the (0,0) is at the top left: e.g. PySide6 - def pixel_coords_to_space_coords(self, px, py, relative=False, top_left=False): + def pixel_coords_to_space_coords( + self, px: float, py: float, relative: bool = False, top_left: bool = False + ) -> Point3D: pixel_shape = self.get_pixel_shape() if pixel_shape is None: return np.array([0, 0, 0]) diff --git a/manim/renderer/opengl_renderer_window.py b/manim/renderer/opengl_renderer_window.py index 610f61646b..4472ba6c2a 100644 --- a/manim/renderer/opengl_renderer_window.py +++ b/manim/renderer/opengl_renderer_window.py @@ -1,12 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + import moderngl_window as mglw from moderngl_window.context.pyglet.window import Window as PygletWindow from moderngl_window.timers.clock import Timer -from screeninfo import get_monitors +from screeninfo import Monitor, get_monitors from .. import __version__, config +if TYPE_CHECKING: + from .opengl_renderer import OpenGLRenderer + __all__ = ["Window"] @@ -17,15 +22,19 @@ class Window(PygletWindow): vsync = True cursor = True - def __init__(self, renderer, size=config.window_size, **kwargs): + def __init__( + self, + renderer: OpenGLRenderer, + window_size: str = config.window_size, + **kwargs: Any, + ) -> None: monitors = get_monitors() mon_index = config.window_monitor monitor = monitors[min(mon_index, len(monitors) - 1)] - if size == "default": + if window_size == "default": # make window_width half the width of the monitor # but make it full screen if --fullscreen - window_width = monitor.width if not config.fullscreen: window_width //= 2 @@ -35,8 +44,13 @@ def __init__(self, renderer, size=config.window_size, **kwargs): window_width * config.frame_height // config.frame_width, ) size = (window_width, window_height) + elif len(window_size.split(",")) == 2: + (window_width, window_height) = tuple(map(int, window_size.split(","))) + size = (window_width, window_height) else: - size = tuple(size) + raise ValueError( + "Window_size must be specified as 'width,height' or 'default'.", + ) super().__init__(size=size) @@ -55,13 +69,13 @@ def __init__(self, renderer, size=config.window_size, **kwargs): self.position = initial_position # Delegate event handling to scene. - def on_mouse_motion(self, x, y, dx, dy): + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> None: super().on_mouse_motion(x, y, dx, dy) point = self.renderer.pixel_coords_to_space_coords(x, y) d_point = self.renderer.pixel_coords_to_space_coords(dx, dy, relative=True) self.renderer.scene.on_mouse_motion(point, d_point) - def on_mouse_scroll(self, x, y, x_offset: float, y_offset: float): + def on_mouse_scroll(self, x: int, y: int, x_offset: float, y_offset: float) -> None: super().on_mouse_scroll(x, y, x_offset, y_offset) point = self.renderer.pixel_coords_to_space_coords(x, y) offset = self.renderer.pixel_coords_to_space_coords( @@ -71,28 +85,32 @@ def on_mouse_scroll(self, x, y, x_offset: float, y_offset: float): ) self.renderer.scene.on_mouse_scroll(point, offset) - def on_key_press(self, symbol, modifiers): + def on_key_press(self, symbol: int, modifiers: int) -> bool: self.renderer.pressed_keys.add(symbol) - super().on_key_press(symbol, modifiers) + event_handled: bool = super().on_key_press(symbol, modifiers) self.renderer.scene.on_key_press(symbol, modifiers) + return event_handled - def on_key_release(self, symbol, modifiers): + def on_key_release(self, symbol: int, modifiers: int) -> None: if symbol in self.renderer.pressed_keys: self.renderer.pressed_keys.remove(symbol) super().on_key_release(symbol, modifiers) self.renderer.scene.on_key_release(symbol, modifiers) - def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + def on_mouse_drag( + self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int + ) -> None: super().on_mouse_drag(x, y, dx, dy, buttons, modifiers) point = self.renderer.pixel_coords_to_space_coords(x, y) d_point = self.renderer.pixel_coords_to_space_coords(dx, dy, relative=True) self.renderer.scene.on_mouse_drag(point, d_point, buttons, modifiers) - def find_initial_position(self, size, monitor): + def find_initial_position( + self, size: tuple[int, int], monitor: Monitor + ) -> tuple[int, int]: custom_position = config.window_position window_width, window_height = size - # Position might be specified with a string of the form - # x,y for integers x and y + # Position might be specified with a string of the form x,y for integers x and y if len(custom_position) == 1: raise ValueError( "window_position must specify both Y and X positions (Y/X -> UR). Also accepts LEFT/RIGHT/ORIGIN/UP/DOWN.", @@ -105,20 +123,21 @@ def find_initial_position(self, size, monitor): elif custom_position == "ORIGIN": custom_position = "O" * 2 elif "," in custom_position: - return tuple(map(int, custom_position.split(","))) + pos_y, pos_x = tuple(map(int, custom_position.split(","))) + return (pos_x, pos_y) # Alternatively, it might be specified with a string like # UR, OO, DL, etc. specifying what corner it should go to char_to_n = {"L": 0, "U": 0, "O": 1, "R": 2, "D": 2} - width_diff = monitor.width - window_width - height_diff = monitor.height - window_height + width_diff: int = monitor.width - window_width + height_diff: int = monitor.height - window_height return ( monitor.x + char_to_n[custom_position[1]] * width_diff // 2, -monitor.y + char_to_n[custom_position[0]] * height_diff // 2, ) - def on_mouse_press(self, x, y, button, modifiers): + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> None: super().on_mouse_press(x, y, button, modifiers) point = self.renderer.pixel_coords_to_space_coords(x, y) mouse_button_map = { diff --git a/manim/renderer/shader.py b/manim/renderer/shader.py index a098ed30ca..0cdf62a2d7 100644 --- a/manim/renderer/shader.py +++ b/manim/renderer/shader.py @@ -4,17 +4,30 @@ import inspect import re import textwrap +from collections.abc import Callable, Iterator, Sequence from pathlib import Path +from typing import TYPE_CHECKING, Any import moderngl import numpy as np +import numpy.typing as npt +from typing_extensions import Self, TypeAlias + +if TYPE_CHECKING: + from manim.renderer.opengl_renderer import OpenGLRenderer + + MeshTimeBasedUpdater: TypeAlias = Callable[["Object3D", float], None] + MeshNonTimeBasedUpdater: TypeAlias = Callable[["Object3D"], None] + MeshUpdater: TypeAlias = MeshNonTimeBasedUpdater | MeshTimeBasedUpdater + +from manim.typing import MatrixMN, Point3D from .. import config from ..utils import opengl SHADER_FOLDER = Path(__file__).parent / "shaders" -shader_program_cache: dict = {} -file_path_to_code_map: dict = {} +shader_program_cache: dict[str, moderngl.Program] = {} +file_path_to_code_map: dict[Path, str] = {} __all__ = [ "Object3D", @@ -43,7 +56,9 @@ def get_shader_code_from_file(file_path: Path) -> str: return source -def filter_attributes(unfiltered_attributes, attributes): +def filter_attributes( + unfiltered_attributes: npt.NDArray, attributes: Sequence[str] +) -> npt.NDArray: # Construct attributes for only those needed by the shader. filtered_attributes_dtype = [] for i, dtype_name in enumerate(unfiltered_attributes.dtype.names): @@ -69,28 +84,28 @@ def filter_attributes(unfiltered_attributes, attributes): class Object3D: - def __init__(self, *children): + def __init__(self, *children: Object3D): self.model_matrix = np.eye(4) self.normal_matrix = np.eye(4) - self.children = [] - self.parent = None + self.children: list[Object3D] = [] + self.parent: Object3D | None = None self.add(*children) self.init_updaters() # TODO: Use path_func. - def interpolate(self, start, end, alpha, _): + def interpolate(self, start: Object3D, end: Object3D, alpha: float, _: Any) -> None: self.model_matrix = (1 - alpha) * start.model_matrix + alpha * end.model_matrix self.normal_matrix = ( 1 - alpha ) * start.normal_matrix + alpha * end.normal_matrix - def single_copy(self): + def single_copy(self) -> Object3D: copy = Object3D() copy.model_matrix = self.model_matrix.copy() copy.normal_matrix = self.normal_matrix.copy() return copy - def copy(self): + def copy(self) -> Object3D: node_to_copy = {} bfs = [self] @@ -106,7 +121,7 @@ def copy(self): node_to_copy[node.parent].add(node_copy) return node_to_copy[self] - def add(self, *children): + def add(self, *children: Object3D) -> None: for child in children: if child.parent is not None: raise Exception( @@ -117,7 +132,7 @@ def add(self, *children): for child in children: child.parent = self - def remove(self, *children, current_children_only=True): + def remove(self, *children: Object3D, current_children_only: bool = True) -> None: if current_children_only: for child in children: if child.parent != self: @@ -128,14 +143,14 @@ def remove(self, *children, current_children_only=True): for child in children: child.parent = None - def get_position(self): + def get_position(self) -> Point3D: return self.model_matrix[:, 3][:3] - def set_position(self, position): + def set_position(self, position: Point3D) -> Self: self.model_matrix[:, 3][:3] = position return self - def get_meshes(self): + def get_meshes(self) -> Iterator[Mesh]: dfs = [self] while dfs: parent = dfs.pop() @@ -143,17 +158,17 @@ def get_meshes(self): yield parent dfs.extend(parent.children) - def get_family(self): + def get_family(self) -> Iterator[Object3D]: dfs = [self] while dfs: parent = dfs.pop() yield parent dfs.extend(parent.children) - def align_data_and_family(self, _): + def align_data_and_family(self, _: Any) -> None: pass - def hierarchical_model_matrix(self): + def hierarchical_model_matrix(self) -> MatrixMN: if self.parent is None: return self.model_matrix @@ -164,7 +179,7 @@ def hierarchical_model_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(model_matrices))) - def hierarchical_normal_matrix(self): + def hierarchical_normal_matrix(self) -> MatrixMN: if self.parent is None: return self.normal_matrix[:3, :3] @@ -175,76 +190,93 @@ def hierarchical_normal_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(normal_matrices)))[:3, :3] - def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] + def init_updaters(self) -> None: + self.time_based_updaters: list[MeshTimeBasedUpdater] = [] + self.non_time_updaters: list[MeshNonTimeBasedUpdater] = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0): + def update(self, dt: float = 0) -> Self: if not self.has_updaters or self.updating_suspended: return self - for updater in self.time_based_updaters: - updater(self, dt) - for updater in self.non_time_updaters: - updater(self) + for time_based_updater in self.time_based_updaters: + time_based_updater(self, dt) + for non_time_based_updater in self.non_time_updaters: + non_time_based_updater(self) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> list[MeshTimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> list[MeshUpdater]: return self.time_based_updaters + self.non_time_updaters - def add_updater(self, update_function, index=None, call_updater=True): + def add_updater( + self, + update_function: MeshUpdater, + index: int | None = None, + call_updater: bool = True, + ) -> Self: if "dt" in inspect.signature(update_function).parameters: - updater_list = self.time_based_updaters + self._add_time_based_updater(update_function, index) # type: ignore[arg-type] else: - updater_list = self.non_time_updaters - - if index is None: - updater_list.append(update_function) - else: - updater_list.insert(index, update_function) + self._add_non_time_updater(update_function, index) # type: ignore[arg-type] self.refresh_has_updater_status() if call_updater: self.update() return self - def remove_updater(self, update_function): - for updater_list in [self.time_based_updaters, self.non_time_updaters]: - while update_function in updater_list: - updater_list.remove(update_function) + def _add_time_based_updater( + self, update_function: MeshTimeBasedUpdater, index: int | None = None + ) -> None: + if index is None: + self.time_based_updaters.append(update_function) + else: + self.time_based_updaters.insert(index, update_function) + + def _add_non_time_updater( + self, update_function: MeshNonTimeBasedUpdater, index: int | None = None + ) -> None: + if index is None: + self.non_time_updaters.append(update_function) + else: + self.non_time_updaters.insert(index, update_function) + + def remove_updater(self, update_function: MeshUpdater) -> Self: + while update_function in self.time_based_updaters: + self.time_based_updaters.remove(update_function) # type: ignore[arg-type] + while update_function in self.non_time_updaters: + self.non_time_updaters.remove(update_function) # type: ignore[arg-type] self.refresh_has_updater_status() return self - def clear_updaters(self): + def clear_updaters(self) -> Self: self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() return self - def match_updaters(self, mobject): + def match_updaters(self, mesh: Object3D) -> Self: self.clear_updaters() - for updater in mobject.get_updaters(): + for updater in mesh.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self): + def suspend_updating(self) -> Self: self.updating_suspended = True return self - def resume_updating(self, call_updater=True): + def resume_updating(self, call_updater: bool = True) -> Self: self.updating_suspended = False if call_updater: self.update(dt=0) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = len(self.get_updaters()) > 0 return self @@ -252,23 +284,23 @@ def refresh_has_updater_status(self): class Mesh(Object3D): def __init__( self, - shader=None, - attributes=None, - geometry=None, - material=None, - indices=None, - use_depth_test=True, - primitive=moderngl.TRIANGLES, + shader: Shader | None = None, + attributes: npt.NDArray | None = None, + geometry: Mesh | None = None, + material: Shader | None = None, + indices: npt.NDArray | None = None, + use_depth_test: bool = True, + primitive: int = moderngl.TRIANGLES, ): super().__init__() if shader is not None and attributes is not None: - self.shader = shader + self.shader: Shader = shader self.attributes = attributes self.indices = indices elif geometry is not None and material is not None: self.shader = material self.attributes = geometry.attributes - self.indices = geometry.index + self.indices = geometry.indices else: raise Exception( "Mesh requires either attributes and a Shader or a Geometry and a " @@ -276,10 +308,10 @@ def __init__( ) self.use_depth_test = use_depth_test self.primitive = primitive - self.skip_render = False + self.skip_render: bool = False self.init_updaters() - def single_copy(self): + def single_copy(self) -> Mesh: copy = Mesh( attributes=self.attributes.copy(), shader=self.shader, @@ -293,7 +325,7 @@ def single_copy(self): # TODO: Copy updaters? return copy - def set_uniforms(self, renderer): + def set_uniforms(self, renderer: OpenGLRenderer) -> None: self.shader.set_uniform( "u_model_matrix", opengl.matrix_to_shader_input(self.model_matrix), @@ -304,7 +336,7 @@ def set_uniforms(self, renderer): renderer.camera.projection_matrix, ) - def render(self): + def render(self) -> None: if self.skip_render: return @@ -313,15 +345,17 @@ def render(self): else: self.shader.context.disable(moderngl.DEPTH_TEST) - from moderngl import Attribute - - shader_attributes = [] - for k, v in self.shader.shader_program._members.items(): - if isinstance(v, Attribute): - shader_attributes.append(k) - shader_attributes = filter_attributes(self.attributes, shader_attributes) + shader_attribute_names: list[str] = [] + for member_name, member in self.shader.shader_program._members.items(): + if isinstance(member, moderngl.Attribute): + shader_attribute_names.append(member_name) + filtered_shader_attributes = filter_attributes( + self.attributes, shader_attribute_names + ) - vertex_buffer_object = self.shader.context.buffer(shader_attributes.tobytes()) + vertex_buffer_object = self.shader.context.buffer( + filtered_shader_attributes.tobytes() + ) if self.indices is None: index_buffer_object = None else: @@ -333,7 +367,7 @@ def render(self): vertex_array_object = self.shader.context.simple_vertex_array( self.shader.shader_program, vertex_buffer_object, - *shader_attributes.dtype.names, + *filtered_shader_attributes.dtype.names, index_buffer=index_buffer_object, ) vertex_array_object.render(self.primitive) @@ -346,13 +380,14 @@ def render(self): class Shader: def __init__( self, - context, - name=None, - source=None, + context: moderngl.Context, + name: str | None = None, + source: dict[str, Any] | None = None, ): global shader_program_cache self.context = context self.name = name + self.source = source # See if the program is cached. if ( @@ -360,10 +395,10 @@ def __init__( and shader_program_cache[self.name].ctx == self.context ): self.shader_program = shader_program_cache[self.name] - elif source is not None: + elif self.source is not None: # Generate the shader from inline code if it was passed. - self.shader_program = context.program(**source) - else: + self.shader_program = context.program(**self.source) + elif self.name is not None: # Search for a file containing the shader. source_dict = {} source_dict_key = { @@ -371,18 +406,20 @@ def __init__( "frag": "fragment_shader", "geom": "geometry_shader", } - shader_folder = SHADER_FOLDER / name + shader_folder = SHADER_FOLDER / self.name for shader_file in shader_folder.iterdir(): shader_file_path = shader_folder / shader_file shader_source = get_shader_code_from_file(shader_file_path) source_dict[source_dict_key[shader_file_path.stem]] = shader_source self.shader_program = context.program(**source_dict) + else: + raise Exception("Must either pass shader name or shader source.") # Cache the shader. - if name is not None and name not in shader_program_cache: + if self.name is not None and self.name not in shader_program_cache: shader_program_cache[self.name] = self.shader_program - def set_uniform(self, name, value): + def set_uniform(self, name: str, value: Any) -> None: with contextlib.suppress(KeyError): self.shader_program[name] = value @@ -390,9 +427,9 @@ def set_uniform(self, name, value): class FullScreenQuad(Mesh): def __init__( self, - context, - fragment_shader_source=None, - fragment_shader_name=None, + context: moderngl.Context, + fragment_shader_source: str | None = None, + fragment_shader_name: str | None = None, ): if fragment_shader_source is None and fragment_shader_name is None: raise Exception("Must either pass shader name or shader source.") @@ -439,5 +476,5 @@ def __init__( ) super().__init__(shader, attributes) - def render(self): + def render(self) -> None: super().render() diff --git a/manim/renderer/vectorized_mobject_rendering.py b/manim/renderer/vectorized_mobject_rendering.py index 4245d65b0e..f4c85b05d6 100644 --- a/manim/renderer/vectorized_mobject_rendering.py +++ b/manim/renderer/vectorized_mobject_rendering.py @@ -1,9 +1,18 @@ from __future__ import annotations -import collections +from collections import defaultdict +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from manim.renderer.opengl_renderer import ( + OpenGLRenderer, + OpenGLVMobject, + ) + from manim.typing import MatrixMN + from ..utils import opengl from ..utils.space_ops import cross2d, earclip_triangulation from .shader import Shader @@ -14,9 +23,11 @@ ] -def build_matrix_lists(mob): +def build_matrix_lists( + mob: OpenGLVMobject, +) -> defaultdict[tuple[float, ...], list[OpenGLVMobject]]: root_hierarchical_matrix = mob.hierarchical_model_matrix() - matrix_to_mobject_list = collections.defaultdict(list) + matrix_to_mobject_list = defaultdict(list) if mob.has_points(): matrix_to_mobject_list[tuple(root_hierarchical_matrix.ravel())].append(mob) mobject_to_hierarchical_matrix = {mob: root_hierarchical_matrix} @@ -36,7 +47,9 @@ def build_matrix_lists(mob): return matrix_to_mobject_list -def render_opengl_vectorized_mobject_fill(renderer, mobject): +def render_opengl_vectorized_mobject_fill( + renderer: OpenGLRenderer, mobject: OpenGLVMobject +) -> None: matrix_to_mobject_list = build_matrix_lists(mobject) for matrix_tuple, mobject_list in matrix_to_mobject_list.items(): @@ -44,7 +57,11 @@ def render_opengl_vectorized_mobject_fill(renderer, mobject): render_mobject_fills_with_matrix(renderer, model_matrix, mobject_list) -def render_mobject_fills_with_matrix(renderer, model_matrix, mobjects): +def render_mobject_fills_with_matrix( + renderer: OpenGLRenderer, + model_matrix: MatrixMN, + mobjects: Iterable[OpenGLVMobject], +) -> None: # Precompute the total number of vertices for which to reserve space. # Note that triangulate_mobject() will cache its results. total_size = 0 @@ -84,7 +101,7 @@ def render_mobject_fills_with_matrix(renderer, model_matrix, mobjects): ) fill_shader.set_uniform( "u_projection_matrix", - renderer.scene.camera.projection_matrix, + renderer.camera.projection_matrix, ) vbo = renderer.context.buffer(attributes.tobytes()) @@ -98,7 +115,7 @@ def render_mobject_fills_with_matrix(renderer, model_matrix, mobjects): vbo.release() -def triangulate_mobject(mob): +def triangulate_mobject(mob: OpenGLVMobject) -> np.ndarray: if not mob.needs_new_triangulation: return mob.triangulation @@ -192,14 +209,20 @@ def triangulate_mobject(mob): return attributes -def render_opengl_vectorized_mobject_stroke(renderer, mobject): +def render_opengl_vectorized_mobject_stroke( + renderer: OpenGLRenderer, mobject: OpenGLVMobject +) -> None: matrix_to_mobject_list = build_matrix_lists(mobject) for matrix_tuple, mobject_list in matrix_to_mobject_list.items(): model_matrix = np.array(matrix_tuple).reshape((4, 4)) render_mobject_strokes_with_matrix(renderer, model_matrix, mobject_list) -def render_mobject_strokes_with_matrix(renderer, model_matrix, mobjects): +def render_mobject_strokes_with_matrix( + renderer: OpenGLRenderer, + model_matrix: MatrixMN, + mobjects: Sequence[OpenGLVMobject], +) -> None: # Precompute the total number of vertices for which to reserve space. total_size = 0 for submob in mobjects: @@ -279,7 +302,7 @@ def render_mobject_strokes_with_matrix(renderer, model_matrix, mobjects): renderer.camera.unformatted_view_matrix @ model_matrix, ), ) - shader.set_uniform("u_projection_matrix", renderer.scene.camera.projection_matrix) + shader.set_uniform("u_projection_matrix", renderer.camera.projection_matrix) shader.set_uniform("manim_unit_normal", tuple(-mobjects[0].unit_normal[0])) vbo = renderer.context.buffer(stroke_data.tobytes()) diff --git a/manim/scene/moving_camera_scene.py b/manim/scene/moving_camera_scene.py index eafc992ef5..7487f05b0d 100644 --- a/manim/scene/moving_camera_scene.py +++ b/manim/scene/moving_camera_scene.py @@ -89,8 +89,12 @@ def create_scene(number): __all__ = ["MovingCameraScene"] +from typing import Any + from manim.animation.animation import Animation +from manim.mobject.mobject import Mobject +from ..camera.camera import Camera from ..camera.moving_camera import MovingCamera from ..scene.scene import Scene from ..utils.family import extract_mobject_family_members @@ -111,10 +115,12 @@ class MovingCameraScene(Scene): :class:`.MovingCamera` """ - def __init__(self, camera_class=MovingCamera, **kwargs): + def __init__( + self, camera_class: type[Camera] = MovingCamera, **kwargs: Any + ) -> None: super().__init__(camera_class=camera_class, **kwargs) - def get_moving_mobjects(self, *animations: Animation): + def get_moving_mobjects(self, *animations: Animation) -> list[Mobject]: """ This method returns a list of all of the Mobjects in the Scene that are moving, that are also in the animations passed. @@ -126,7 +132,7 @@ def get_moving_mobjects(self, *animations: Animation): """ moving_mobjects = super().get_moving_mobjects(*animations) all_moving_mobjects = extract_mobject_family_members(moving_mobjects) - movement_indicators = self.renderer.camera.get_mobjects_indicating_movement() + movement_indicators = self.renderer.camera.get_mobjects_indicating_movement() # type: ignore[union-attr] for movement_indicator in movement_indicators: if movement_indicator in all_moving_mobjects: # When one of these is moving, the camera should diff --git a/manim/scene/scene.py b/manim/scene/scene.py index fc3d3ede54..1a9c1d82cb 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -4,6 +4,8 @@ from manim.utils.parameter_parsing import flatten_iterable_parameters +from ..mobject.mobject import _AnimationBuilder + __all__ = ["Scene"] import copy @@ -13,7 +15,8 @@ import random import threading import time -import types +from dataclasses import dataclass +from pathlib import Path from queue import Queue import srt @@ -24,15 +27,21 @@ import dearpygui.dearpygui as dpg dearpygui_imported = True + dpg.create_context() + window = dpg.generate_uuid() except ImportError: dearpygui_imported = False -from typing import TYPE_CHECKING + +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Union import numpy as np from tqdm import tqdm -from watchdog.events import FileSystemEventHandler +from watchdog.events import DirModifiedEvent, FileModifiedEvent, FileSystemEventHandler from watchdog.observers import Observer +from manim import __version__ +from manim.data_structures import MethodWithArgs from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_mobject import OpenGLPoint @@ -40,9 +49,8 @@ from ..animation.animation import Animation, Wait, prepare_animation from ..camera.camera import Camera from ..constants import * -from ..gui.gui import configure_pygui from ..renderer.cairo_renderer import CairoRenderer -from ..renderer.opengl_renderer import OpenGLRenderer +from ..renderer.opengl_renderer import OpenGLCamera, OpenGLMobject, OpenGLRenderer from ..renderer.shader import Object3D from ..utils import opengl, space_ops from ..utils.exceptions import EndSceneEarlyException, RerunSceneException @@ -50,23 +58,83 @@ from ..utils.family_ops import restructure_list_to_exclude_certain_family_members from ..utils.file_ops import open_media_file from ..utils.iterables import list_difference_update, list_update +from ..utils.module_ops import scene_classes_from_file if TYPE_CHECKING: - from collections.abc import Sequence - from typing import Callable + from types import FrameType + + from typing_extensions import Self, TypeAlias + + from manim.typing import Point3D + + SceneInteractAction: TypeAlias = Union[ + MethodWithArgs, "SceneInteractContinue", "SceneInteractRerun" + ] + """The SceneInteractAction type alias is used for elements in the queue + used by :meth:`.Scene.interact()`. + + The elements can be one of the following three: + + - a :class:`~.MethodWithArgs` object, which represents a :class:`Scene` + method to be called along with its args and kwargs, + - a :class:`~.SceneInteractContinue` object, indicating that the scene + interaction is over and the scene will continue rendering after that, or + - a :class:`~.SceneInteractRerun` object, indicating that the scene should + render again. + """ + + +@dataclass +class SceneInteractContinue: + """Object which, when encountered in :meth:`.Scene.interact`, triggers + the end of the scene interaction, continuing with the rest of the + animations, if any. This object can be queued in :attr:`.Scene.queue` + for later use in :meth:`.Scene.interact`. + + Attributes + ---------- + sender : str + The name of the entity which issued the end of the scene interaction, + such as ``"gui"`` or ``"keyboard"``. + """ + + __slots__ = ["sender"] + + sender: str + + +class SceneInteractRerun: + """Object which, when encountered in :meth:`.Scene.interact`, triggers + the rerun of the scene. This object can be queued in :attr:`.Scene.queue` + for later use in :meth:`.Scene.interact`. + + Attributes + ---------- + sender : str + The name of the entity which issued the rerun of the scene, such as + ``"gui"``, ``"keyboard"``, ``"play"`` or ``"file"``. + kwargs : dict[str, Any] + Additional keyword arguments when rerunning the scene. Currently, + only ``"from_animation_number"`` is being used, which determines the + animation from which to start rerunning the scene. + """ - from manim.mobject.mobject import _AnimationBuilder + __slots__ = ["sender", "kwargs"] + + def __init__(self, sender: str, **kwargs: Any) -> None: + self.sender = sender + self.kwargs = kwargs class RerunSceneHandler(FileSystemEventHandler): """A class to handle rerunning a Scene after the input file is modified.""" - def __init__(self, queue): + def __init__(self, queue: Queue[SceneInteractAction]) -> None: super().__init__() self.queue = queue - def on_modified(self, event): - self.queue.put(("rerun_file", [], {})) + def on_modified(self, event: DirModifiedEvent | FileModifiedEvent) -> None: + self.queue.put(SceneInteractRerun("file")) class Scene: @@ -113,24 +181,22 @@ def __init__( self.random_seed = random_seed self.skip_animations = skip_animations - self.animations = None - self.stop_condition = None - self.moving_mobjects = [] - self.static_mobjects = [] - self.time_progression = None - self.duration = None - self.last_t = None - self.queue = Queue() + self.animations: list[Animation] | None = None + self.stop_condition: Callable[[], bool] | None = None + self.moving_mobjects: list[Mobject] = [] + self.static_mobjects: list[Mobject] = [] + self.time_progression: tqdm[float] | None = None + self.duration: float | None = None + self.last_t = 0.0 + self.queue: Queue[SceneInteractAction] = Queue() self.skip_animation_preview = False - self.meshes = [] + self.meshes: list[Object3D] = [] self.camera_target = ORIGIN - self.widgets = [] + self.widgets: list[dict[str, Any]] = [] self.dearpygui_imported = dearpygui_imported - self.updaters = [] - self.point_lights = [] - self.ambient_light = None - self.key_to_function_map = {} - self.mouse_press_callbacks = [] + self.updaters: list[Callable[[float], None]] = [] + self.key_to_function_map: dict[str, Callable[[], None]] = {} + self.mouse_press_callbacks: list[Callable[[], None]] = [] self.interactive_mode = False if config.renderer == RendererType.OPENGL: @@ -141,7 +207,9 @@ def __init__( renderer = OpenGLRenderer() if renderer is None: - self.renderer = CairoRenderer( + self.renderer: CairoRenderer | OpenGLRenderer = CairoRenderer( + # TODO: Is it a suitable approach to make an instance of + # the self.camera_class here? camera_class=self.camera_class, skip_animations=self.skip_animations, ) @@ -149,15 +217,15 @@ def __init__( self.renderer = renderer self.renderer.init_scene(self) - self.mobjects = [] + self.mobjects: list[Mobject] = [] # TODO, remove need for foreground mobjects - self.foreground_mobjects = [] + self.foreground_mobjects: list[Mobject] = [] if self.random_seed is not None: random.seed(self.random_seed) np.random.seed(self.random_seed) @property - def camera(self): + def camera(self) -> Camera | OpenGLCamera: return self.renderer.camera @property @@ -165,7 +233,7 @@ def time(self) -> float: """The time since the start of the scene.""" return self.renderer.time - def __deepcopy__(self, clone_from_id): + def __deepcopy__(self, clone_from_id: dict[int, Any]) -> Scene: cls = self.__class__ result = cls.__new__(cls) clone_from_id[id(self)] = result @@ -175,55 +243,10 @@ def __deepcopy__(self, clone_from_id): if k == "camera_class": setattr(result, k, v) setattr(result, k, copy.deepcopy(v, clone_from_id)) - result.mobject_updater_lists = [] - - # Update updaters - for mobject in self.mobjects: - cloned_updaters = [] - for updater in mobject.updaters: - # Make the cloned updater use the cloned Mobjects as free variables - # rather than the original ones. Analyzing function bytecode with the - # dis module will help in understanding this. - # https://docs.python.org/3/library/dis.html - # TODO: Do the same for function calls recursively. - free_variable_map = inspect.getclosurevars(updater).nonlocals - cloned_co_freevars = [] - cloned_closure = [] - for free_variable_name in updater.__code__.co_freevars: - free_variable_value = free_variable_map[free_variable_name] - - # If the referenced variable has not been cloned, raise. - if id(free_variable_value) not in clone_from_id: - raise Exception( - f"{free_variable_name} is referenced from an updater " - "but is not an attribute of the Scene, which isn't " - "allowed.", - ) - - # Add the cloned object's name to the free variable list. - cloned_co_freevars.append(free_variable_name) - - # Add a cell containing the cloned object's reference to the - # closure list. - cloned_closure.append( - types.CellType(clone_from_id[id(free_variable_value)]), - ) - cloned_updater = types.FunctionType( - updater.__code__.replace(co_freevars=tuple(cloned_co_freevars)), - updater.__globals__, - updater.__name__, - updater.__defaults__, - tuple(cloned_closure), - ) - cloned_updaters.append(cloned_updater) - mobject_clone = clone_from_id[id(mobject)] - mobject_clone.updaters = cloned_updaters - if len(cloned_updaters) > 0: - result.mobject_updater_lists.append((mobject_clone, cloned_updaters)) return result - def render(self, preview: bool = False): + def render(self, preview: bool = False) -> bool: """ Renders this Scene. @@ -239,7 +262,8 @@ def render(self, preview: bool = False): pass except RerunSceneException: self.remove(*self.mobjects) - self.renderer.clear_screen() + # TODO: The CairoRenderer does not have the method clear_screen() + self.renderer.clear_screen() # type: ignore[union-attr] self.renderer.num_plays = 0 return True self.tear_down() @@ -263,7 +287,9 @@ def render(self, preview: bool = False): if config["preview"] or config["show_in_file_browser"]: open_media_file(self.renderer.file_writer) - def setup(self): + return False + + def setup(self) -> None: """ This is meant to be implemented by any scenes which are commonly subclassed, and have some common setup @@ -271,7 +297,7 @@ def setup(self): """ pass - def tear_down(self): + def tear_down(self) -> None: """ This is meant to be implemented by any scenes which are commonly subclassed, and have some common method @@ -279,7 +305,7 @@ def tear_down(self): """ pass - def construct(self): + def construct(self) -> None: """Add content to the Scene. From within :meth:`Scene.construct`, display mobjects on screen by calling @@ -324,10 +350,10 @@ def next_section( """ self.renderer.file_writer.next_section(name, section_type, skip_animations) - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ - def get_attrs(self, *keys: str): + def get_attrs(self, *keys: str) -> list[Any]: """ Gets attributes of a scene given the attribute's identifier/name. @@ -343,7 +369,7 @@ def get_attrs(self, *keys: str): """ return [getattr(self, key) for key in keys] - def update_mobjects(self, dt: float): + def update_mobjects(self, dt: float) -> None: """ Begins updating all mobjects in the Scene. @@ -352,15 +378,15 @@ def update_mobjects(self, dt: float): dt Change in time between updates. Defaults (mostly) to 1/frames_per_second """ - for mobject in self.mobjects: - mobject.update(dt) + for mobj in self.mobjects: + mobj.update(dt) - def update_meshes(self, dt): + def update_meshes(self, dt: float) -> None: for obj in self.meshes: for mesh in obj.get_family(): mesh.update(dt) - def update_self(self, dt: float): + def update_self(self, dt: float) -> None: """Run all scene updater functions. Among all types of update functions (mobject updaters, mesh updaters, @@ -392,7 +418,9 @@ def should_update_mobjects(self) -> bool: This is only called when a single Wait animation is played. """ + assert self.animations is not None wait_animation = self.animations[0] + assert isinstance(wait_animation, Wait) if wait_animation.is_static_wait is None: should_update = ( self.always_update_mobjects @@ -406,7 +434,7 @@ def should_update_mobjects(self) -> bool: wait_animation.is_static_wait = not should_update return not wait_animation.is_static_wait - def get_top_level_mobjects(self): + def get_top_level_mobjects(self) -> list[Mobject]: """ Returns all mobjects which are not submobjects. @@ -419,13 +447,13 @@ def get_top_level_mobjects(self): # of another mobject from the scene families = [m.get_family() for m in self.mobjects] - def is_top_level(mobject): + def is_top_level(mobject: Mobject) -> bool: num_families = sum((mobject in family) for family in families) return num_families == 1 return list(filter(is_top_level, self.mobjects)) - def get_mobject_family_members(self): + def get_mobject_family_members(self) -> list[Mobject]: """ Returns list of family-members of all mobjects in scene. If a Circle() and a VGroup(Rectangle(),Triangle()) were added, @@ -442,13 +470,14 @@ def get_mobject_family_members(self): for mob in self.mobjects: family_members.extend(mob.get_family()) return family_members - elif config.renderer == RendererType.CAIRO: + else: + assert config.renderer == RendererType.CAIRO return extract_mobject_family_members( self.mobjects, use_z_index=self.renderer.camera.use_z_index, ) - def add(self, *mobjects: Mobject): + def add(self, *mobjects: Mobject | OpenGLMobject) -> Self: """ Mobjects will be displayed, from background to foreground in the order with which they are added. @@ -466,26 +495,30 @@ def add(self, *mobjects: Mobject): """ if config.renderer == RendererType.OPENGL: new_mobjects = [] - new_meshes = [] + new_meshes: list[Object3D] = [] for mobject_or_mesh in mobjects: if isinstance(mobject_or_mesh, Object3D): new_meshes.append(mobject_or_mesh) else: new_mobjects.append(mobject_or_mesh) - self.remove(*new_mobjects) - self.mobjects += new_mobjects - self.remove(*new_meshes) + self.remove(*new_mobjects) # type: ignore[arg-type] + self.mobjects += new_mobjects # type: ignore[arg-type] + self.remove(*new_meshes) # type: ignore[arg-type] self.meshes += new_meshes - elif config.renderer == RendererType.CAIRO: - mobjects = [*mobjects, *self.foreground_mobjects] - self.restructure_mobjects(to_remove=mobjects) - self.mobjects += mobjects + else: + assert config.renderer == RendererType.CAIRO + new_and_foreground_mobjects: list[Mobject] = [ + *mobjects, # type: ignore[list-item] + *self.foreground_mobjects, + ] + self.restructure_mobjects(to_remove=new_and_foreground_mobjects) + self.mobjects += new_and_foreground_mobjects if self.moving_mobjects: self.restructure_mobjects( - to_remove=mobjects, + to_remove=new_and_foreground_mobjects, mobject_list_name="moving_mobjects", ) - self.moving_mobjects += mobjects + self.moving_mobjects += new_and_foreground_mobjects return self def add_mobjects_from_animations(self, animations: list[Animation]) -> None: @@ -498,9 +531,9 @@ def add_mobjects_from_animations(self, animations: list[Animation]) -> None: mob = animation.mobject if mob is not None and mob not in curr_mobjects: self.add(mob) - curr_mobjects += mob.get_family() + curr_mobjects += mob.get_family() # type: ignore[arg-type] - def remove(self, *mobjects: Mobject): + def remove(self, *mobjects: Mobject) -> Self: """ Removes mobjects in the passed list of mobjects from the scene and the foreground, by removing them @@ -513,7 +546,8 @@ def remove(self, *mobjects: Mobject): """ if config.renderer == RendererType.OPENGL: mobjects_to_remove = [] - meshes_to_remove = set() + meshes_to_remove: set[Object3D] = set() + mobject_or_mesh: Mobject for mobject_or_mesh in mobjects: if isinstance(mobject_or_mesh, Object3D): meshes_to_remove.add(mobject_or_mesh) @@ -523,11 +557,16 @@ def remove(self, *mobjects: Mobject): self.mobjects, mobjects_to_remove, ) + + def lambda_function(mesh: Object3D) -> bool: + return mesh not in set(meshes_to_remove) + self.meshes = list( - filter(lambda mesh: mesh not in set(meshes_to_remove), self.meshes), + filter(lambda_function, self.meshes), ) return self - elif config.renderer == RendererType.CAIRO: + else: + assert config.renderer == RendererType.CAIRO for list_name in "mobjects", "foreground_mobjects": self.restructure_mobjects(mobjects, list_name, False) return self @@ -553,6 +592,19 @@ def replace(self, old_mobject: Mobject, new_mobject: Mobject) -> None: def replace_in_list( mobj_list: list[Mobject], old_m: Mobject, new_m: Mobject ) -> bool: + # Avoid duplicate references to the same object in self.mobjects + if new_m in mobj_list: + if old_m is new_m: + # In this case, one could say that the old Mobject was already found. + # No replacement is needed, since old_m is new_m, so no action is required. + # This might be unexpected, so raise a warning. + logger.warning( + f"Attempted to replace {type(old_m).__name__} " + "with itself in Scene.mobjects." + ) + return True + mobj_list.remove(new_m) + # We use breadth-first search because some Mobjects get very deep and # we expect top-level elements to be the most common targets for replace. for i in range(0, len(mobj_list)): @@ -630,7 +682,7 @@ def restructure_mobjects( to_remove: Sequence[Mobject], mobject_list_name: str = "mobjects", extract_families: bool = True, - ): + ) -> Scene: """ tl:wr If your scene has a Group(), and you removed a mobject from the Group, @@ -668,7 +720,9 @@ def restructure_mobjects( setattr(self, mobject_list_name, new_list) return self - def get_restructured_mobject_list(self, mobjects: list, to_remove: list): + def get_restructured_mobject_list( + self, mobjects: Iterable[Mobject], to_remove: Iterable[Mobject] + ) -> list[Mobject]: """ Given a list of mobjects and a list of mobjects to be removed, this filters out the removable mobjects from the list of mobjects. @@ -687,9 +741,11 @@ def get_restructured_mobject_list(self, mobjects: list, to_remove: list): list The list of mobjects with the mobjects to remove removed. """ - new_mobjects = [] + new_mobjects: list[Mobject] = [] - def add_safe_mobjects_from_list(list_to_examine, set_to_remove): + def add_safe_mobjects_from_list( + list_to_examine: Iterable[Mobject], set_to_remove: set[Mobject] + ) -> None: for mob in list_to_examine: if mob in set_to_remove: continue @@ -703,7 +759,7 @@ def add_safe_mobjects_from_list(list_to_examine, set_to_remove): return new_mobjects # TODO, remove this, and calls to this - def add_foreground_mobjects(self, *mobjects: Mobject): + def add_foreground_mobjects(self, *mobjects: Mobject) -> Scene: """ Adds mobjects to the foreground, and internally to the list foreground_mobjects, and mobjects. @@ -722,7 +778,7 @@ def add_foreground_mobjects(self, *mobjects: Mobject): self.add(*mobjects) return self - def add_foreground_mobject(self, mobject: Mobject): + def add_foreground_mobject(self, mobject: Mobject) -> Scene: """ Adds a single mobject to the foreground, and internally to the list foreground_mobjects, and mobjects. @@ -739,7 +795,7 @@ def add_foreground_mobject(self, mobject: Mobject): """ return self.add_foreground_mobjects(mobject) - def remove_foreground_mobjects(self, *to_remove: Mobject): + def remove_foreground_mobjects(self, *to_remove: Mobject) -> Scene: """ Removes mobjects from the foreground, and internally from the list foreground_mobjects. @@ -757,7 +813,7 @@ def remove_foreground_mobjects(self, *to_remove: Mobject): self.restructure_mobjects(to_remove, "foreground_mobjects") return self - def remove_foreground_mobject(self, mobject: Mobject): + def remove_foreground_mobject(self, mobject: Mobject) -> Scene: """ Removes a single mobject from the foreground, and internally from the list foreground_mobjects. @@ -774,7 +830,7 @@ def remove_foreground_mobject(self, mobject: Mobject): """ return self.remove_foreground_mobjects(mobject) - def bring_to_front(self, *mobjects: Mobject): + def bring_to_front(self, *mobjects: Mobject) -> Scene: """ Adds the passed mobjects to the scene again, pushing them to he front of the scene. @@ -793,7 +849,7 @@ def bring_to_front(self, *mobjects: Mobject): self.add(*mobjects) return self - def bring_to_back(self, *mobjects: Mobject): + def bring_to_back(self, *mobjects: Mobject) -> Scene: """ Removes the mobject from the scene and adds them to the back of the scene. @@ -813,7 +869,7 @@ def bring_to_back(self, *mobjects: Mobject): self.mobjects = list(mobjects) + self.mobjects return self - def clear(self): + def clear(self) -> Self: """ Removes all mobjects present in self.mobjects and self.foreground_mobjects from the scene. @@ -829,7 +885,38 @@ def clear(self): self.foreground_mobjects = [] return self - def get_moving_mobjects(self, *animations: Animation): + def recursively_unpack_animation_groups( + self, *animations: Animation + ) -> list[Union[Mobject, OpenGLMobject]]: + """ + Unpacks animations + + Parameters + ---------- + *animations + The animations to unpack + + Returns + ------ + list + The list of mobjects in animations + """ + # Imported inside the method to avoid cyclic import + from ..animation.composition import AnimationGroup + + mobjects = [] + for anim in animations: + if isinstance(anim, AnimationGroup): + for sub in anim.animations: + unpacked = self.recursively_unpack_animation_groups(sub) + mobjects.extend(unpacked) + else: + mobjects.append(anim.mobject) + return mobjects + + def get_moving_mobjects( + self, *animations: Animation + ) -> list[Union[Mobject, OpenGLMobject]]: """ Gets all moving mobjects in the passed animation(s). @@ -848,7 +935,9 @@ def get_moving_mobjects(self, *animations: Animation): # as soon as there's one that needs updating of # some kind per frame, return the list from that # point forward. - animation_mobjects = [anim.mobject for anim in animations] + + animation_mobjects = self.recursively_unpack_animation_groups(*animations) + mobjects = self.get_mobject_family_members() for i, mob in enumerate(mobjects): update_possibilities = [ @@ -860,7 +949,9 @@ def get_moving_mobjects(self, *animations: Animation): return mobjects[i:] return [] - def get_moving_and_static_mobjects(self, animations): + def get_moving_and_static_mobjects( + self, animations: Iterable[Animation] + ) -> tuple[list[Mobject], list[Mobject]]: all_mobjects = list_update(self.mobjects, self.foreground_mobjects) all_mobject_families = extract_mobject_family_members( all_mobjects, @@ -881,8 +972,8 @@ def get_moving_and_static_mobjects(self, animations): def compile_animations( self, *args: Animation | Mobject | _AnimationBuilder, - **kwargs, - ): + **kwargs: Any, + ) -> list[Animation]: """ Creates _MethodAnimations from any _AnimationBuilders and updates animation kwargs with kwargs passed to play(). @@ -904,7 +995,7 @@ def compile_animations( # Allow passing a generator to self.play instead of comma separated arguments for arg in arg_anims: try: - animations.append(prepare_animation(arg)) + animations.append(prepare_animation(arg)) # type: ignore[arg-type] except TypeError as e: if inspect.ismethod(arg): raise TypeError( @@ -924,7 +1015,7 @@ def compile_animations( def _get_animation_time_progression( self, animations: list[Animation], duration: float - ): + ) -> tqdm[float]: """ You will hardly use this when making your own animations. This method is for Manim's internal use. @@ -977,10 +1068,10 @@ def _get_animation_time_progression( def get_time_progression( self, run_time: float, - description, + description: str, n_iterations: int | None = None, override_skip_animations: bool = False, - ): + ) -> tqdm[float]: """ You will hardly use this when making your own animations. This method is for Manim's internal use. @@ -1008,7 +1099,7 @@ def get_time_progression( The CommandLine Progress Bar. """ if self.renderer.skip_animations and not override_skip_animations: - times = [run_time] + times: Iterable[float] = [run_time] else: step = 1 / config["frame_rate"] times = np.arange(0, run_time, step) @@ -1026,7 +1117,7 @@ def get_time_progression( def validate_run_time( cls, run_time: float, - method: Callable[[Any, ...], Any], + method: Callable[[Any], Any], parameter_name: str = "run_time", ) -> float: method_name = f"{cls.__name__}.{method.__name__}()" @@ -1051,7 +1142,7 @@ def validate_run_time( return run_time - def get_run_time(self, animations: list[Animation]): + def get_run_time(self, animations: list[Animation]) -> float: """ Gets the total run time for a list of animations. @@ -1073,11 +1164,11 @@ def get_run_time(self, animations: list[Animation]): def play( self, *args: Animation | Mobject | _AnimationBuilder, - subcaption=None, - subcaption_duration=None, - subcaption_offset=0, - **kwargs, - ): + subcaption: str | None = None, + subcaption_duration: float | None = None, + subcaption_offset: float = 0, + **kwargs: Any, + ) -> None: r"""Plays an animation in this scene. Parameters @@ -1105,6 +1196,7 @@ def play( and config.renderer == RendererType.OPENGL and threading.current_thread().name != "MainThread" ): + # TODO: are these actually being used? kwargs.update( { "subcaption": subcaption, @@ -1112,13 +1204,7 @@ def play( "subcaption_offset": subcaption_offset, } ) - self.queue.put( - ( - "play", - args, - kwargs, - ) - ) + self.queue.put(SceneInteractRerun("play", **kwargs)) return start_time = self.time @@ -1142,7 +1228,7 @@ def wait( duration: float = DEFAULT_WAIT_TIME, stop_condition: Callable[[], bool] | None = None, frozen_frame: bool | None = None, - ): + ) -> None: """Plays a "no operation" animation. Parameters @@ -1173,7 +1259,7 @@ def wait( ) ) - def pause(self, duration: float = DEFAULT_WAIT_TIME): + def pause(self, duration: float = DEFAULT_WAIT_TIME) -> None: """Pauses the scene (i.e., displays a frozen frame). This is an alias for :meth:`.wait` with ``frozen_frame`` @@ -1191,7 +1277,9 @@ def pause(self, duration: float = DEFAULT_WAIT_TIME): duration = self.validate_run_time(duration, self.pause, "duration") self.wait(duration=duration, frozen_frame=True) - def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60): + def wait_until( + self, stop_condition: Callable[[], bool], max_time: float = 60 + ) -> None: """Wait until a condition is satisfied, up to a given maximum duration. Parameters @@ -1208,8 +1296,8 @@ def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60): def compile_animation_data( self, *animations: Animation | Mobject | _AnimationBuilder, - **play_kwargs, - ): + **play_kwargs: Any, + ) -> Self | None: """Given a list of animations, compile the corresponding static and moving mobjects, and gather the animation durations. @@ -1255,6 +1343,7 @@ def compile_animation_data( def begin_animations(self) -> None: """Start the animations of the scene.""" + assert self.animations is not None for animation in self.animations: animation._setup_scene(self) animation.begin() @@ -1269,13 +1358,14 @@ def begin_animations(self) -> None: def is_current_animation_frozen_frame(self) -> bool: """Returns whether the current animation produces a static frame (generally a Wait).""" + assert self.animations is not None return ( isinstance(self.animations[0], Wait) and len(self.animations) == 1 and self.animations[0].is_static_wait ) - def play_internal(self, skip_rendering: bool = False): + def play_internal(self, skip_rendering: bool = False) -> None: """ This method is used to prep the animations for rendering, apply the arguments and parameters required to them, @@ -1286,6 +1376,7 @@ def play_internal(self, skip_rendering: bool = False): skip_rendering Whether the rendering should be skipped, by default False """ + assert self.animations is not None self.duration = self.get_run_time(self.animations) self.time_progression = self._get_animation_time_progression( self.animations, @@ -1304,11 +1395,13 @@ def play_internal(self, skip_rendering: bool = False): animation.clean_up_from_scene(self) if not self.renderer.skip_animations: self.update_mobjects(0) - self.renderer.static_image = None + # TODO: The OpenGLRenderer does not have the property static.image. + self.renderer.static_image = None # type: ignore[union-attr] # Closing the progress bar at the end of the play. self.time_progression.close() - def check_interactive_embed_is_valid(self): + def check_interactive_embed_is_valid(self) -> bool: + assert isinstance(self.renderer, OpenGLRenderer) if config["force_window"]: return True if self.skip_animation_preview: @@ -1333,35 +1426,46 @@ def check_interactive_embed_is_valid(self): return False return True - def interactive_embed(self): + def interactive_embed(self) -> None: """Like embed(), but allows for screen interaction.""" + assert isinstance(self.camera, OpenGLCamera) + assert isinstance(self.renderer, OpenGLRenderer) if not self.check_interactive_embed_is_valid(): return self.interactive_mode = True + from IPython.terminal.embed import InteractiveShellEmbed - def ipython(shell, namespace): + def ipython(shell: InteractiveShellEmbed, namespace: dict[str, Any]) -> None: import manim.opengl - def load_module_into_namespace(module, namespace): + def load_module_into_namespace( + module: Any, namespace: dict[str, Any] + ) -> None: for name in dir(module): namespace[name] = getattr(module, name) load_module_into_namespace(manim, namespace) load_module_into_namespace(manim.opengl, namespace) - def embedded_rerun(*args, **kwargs): - self.queue.put(("rerun_keyboard", args, kwargs)) + def embedded_rerun(*args: Any, **kwargs: Any) -> None: + self.queue.put(SceneInteractRerun("keyboard")) shell.exiter() namespace["rerun"] = embedded_rerun shell(local_ns=namespace) - self.queue.put(("exit_keyboard", [], {})) + self.queue.put(SceneInteractContinue("keyboard")) + + def get_embedded_method(method_name: str) -> Callable[..., None]: + method = getattr(self, method_name) - def get_embedded_method(method_name): - return lambda *args, **kwargs: self.queue.put((method_name, args, kwargs)) + def embedded_method(*args: Any, **kwargs: Any) -> None: + self.queue.put(MethodWithArgs(method, args, kwargs)) - local_namespace = inspect.currentframe().f_back.f_locals + return embedded_method + + currentframe: FrameType = inspect.currentframe() # type: ignore[assignment] + local_namespace = currentframe.f_back.f_locals # type: ignore[union-attr] for method in ("play", "wait", "add", "remove"): embedded_method = get_embedded_method(method) # Allow for calling scene methods without prepending 'self.'. @@ -1370,7 +1474,6 @@ def get_embedded_method(method_name): from sqlite3 import connect from IPython.core.getipython import get_ipython - from IPython.terminal.embed import InteractiveShellEmbed from traitlets.config import Config cfg = Config() @@ -1394,19 +1497,21 @@ def get_embedded_method(method_name): if self.dearpygui_imported and config["enable_gui"]: if not dpg.is_dearpygui_running(): gui_thread = threading.Thread( - target=configure_pygui, - args=(self.renderer, self.widgets), + target=self._configure_pygui, kwargs={"update": False}, ) gui_thread.start() else: - configure_pygui(self.renderer, self.widgets, update=True) + self._configure_pygui(update=True) self.camera.model_matrix = self.camera.default_model_matrix self.interact(shell, keyboard_thread) - def interact(self, shell, keyboard_thread): + # from IPython.terminal.embed import InteractiveShellEmbed + + def interact(self, shell: Any, keyboard_thread: threading.Thread) -> None: + assert isinstance(self.renderer, OpenGLRenderer) event_handler = RerunSceneHandler(self.queue) file_observer = Observer() file_observer.schedule(event_handler, config["input_file"], recursive=True) @@ -1417,36 +1522,38 @@ def interact(self, shell, keyboard_thread): assert self.queue.qsize() == 0 last_time = time.time() - while not (self.renderer.window.is_closing or self.quit_interaction): + while not ( + (self.renderer.window is not None and self.renderer.window.is_closing) + or self.quit_interaction + ): if not self.queue.empty(): - tup = self.queue.get_nowait() - if tup[0].startswith("rerun"): + action = self.queue.get_nowait() + if isinstance(action, SceneInteractRerun): # Intentionally skip calling join() on the file thread to save time. - if not tup[0].endswith("keyboard"): + if action.sender != "keyboard": if shell.pt_app: shell.pt_app.app.exit(exception=EOFError) file_observer.unschedule_all() raise RerunSceneException keyboard_thread.join() - kwargs = tup[2] - if "from_animation_number" in kwargs: - config["from_animation_number"] = kwargs[ + if "from_animation_number" in action.kwargs: + config["from_animation_number"] = action.kwargs[ "from_animation_number" ] # # TODO: This option only makes sense if interactive_embed() is run at the # # end of a scene by default. - # if "upto_animation_number" in kwargs: - # config["upto_animation_number"] = kwargs[ + # if "upto_animation_number" in action.kwargs: + # config["upto_animation_number"] = action.kwargs[ # "upto_animation_number" # ] keyboard_thread.join() file_observer.unschedule_all() raise RerunSceneException - elif tup[0].startswith("exit"): + elif isinstance(action, SceneInteractContinue): # Intentionally skip calling join() on the file thread to save time. - if not tup[0].endswith("keyboard") and shell.pt_app: + if action.sender != "keyboard" and shell.pt_app: shell.pt_app.app.exit(exception=EOFError) keyboard_thread.join() # Remove exit_keyboard from the queue if necessary. @@ -1455,8 +1562,7 @@ def interact(self, shell, keyboard_thread): keyboard_thread_needs_join = False break else: - method, args, kwargs = tup - getattr(self, method)(*args, **kwargs) + action.method(*action.args, **action.kwargs) else: self.renderer.animation_start_time = 0 dt = time.time() - last_time @@ -1480,10 +1586,11 @@ def interact(self, shell, keyboard_thread): if self.dearpygui_imported and config["enable_gui"]: dpg.stop_dearpygui() - if self.renderer.window.is_closing: + if self.renderer.window is not None and self.renderer.window.is_closing: self.renderer.window.destroy() - def embed(self): + def embed(self) -> None: + assert isinstance(self.renderer, OpenGLRenderer) if not config["preview"]: logger.warning("Called embed() while no preview window is available.") return @@ -1507,7 +1614,9 @@ def embed(self): # Use the locals of the caller as the local namespace # once embedded, and add a few custom shortcuts. - local_ns = inspect.currentframe().f_back.f_locals + current_frame = inspect.currentframe() + assert isinstance(current_frame, FrameType) + local_ns = current_frame.f_back.f_locals # type: ignore[union-attr] # local_ns["touch"] = self.interact for method in ( "play", @@ -1525,9 +1634,77 @@ def embed(self): # End scene when exiting an embed. raise Exception("Exiting scene.") - def update_to_time(self, t): + def _configure_pygui(self, update: bool = True) -> None: + if not self.dearpygui_imported: + raise RuntimeError("Attempted to use DearPyGUI when it isn't imported.") + if update: + dpg.delete_item(window) + else: + dpg.create_viewport() + dpg.setup_dearpygui() + dpg.show_viewport() + + dpg.set_viewport_title(title=f"Manim Community v{__version__}") + dpg.set_viewport_width(1015) + dpg.set_viewport_height(540) + + def rerun_callback(sender: Any, data: Any) -> None: + self.queue.put(SceneInteractRerun("gui")) + + def continue_callback(sender: Any, data: Any) -> None: + self.queue.put(SceneInteractContinue("gui")) + + def scene_selection_callback(sender: Any, data: Any) -> None: + config["scene_names"] = (dpg.get_value(sender),) + self.queue.put(SceneInteractRerun("gui")) + + scene_classes = scene_classes_from_file( + Path(config["input_file"]), full_list=True + ) # type: ignore[call-overload] + scene_names = [scene_class.__name__ for scene_class in scene_classes] + + with dpg.window( + id=window, + label="Manim GUI", + pos=[config["gui_location"][0], config["gui_location"][1]], + width=1000, + height=500, + ): + dpg.set_global_font_scale(2) + dpg.add_button(label="Rerun", callback=rerun_callback) + dpg.add_button(label="Continue", callback=continue_callback) + dpg.add_combo( + label="Selected scene", + items=scene_names, + callback=scene_selection_callback, + default_value=config["scene_names"][0], + ) + dpg.add_separator() + if len(self.widgets) != 0: + with dpg.collapsing_header( + label=f"{config['scene_names'][0]} widgets", + default_open=True, + ): + for widget_config in self.widgets: + widget_config_copy = widget_config.copy() + name = widget_config_copy["name"] + widget = widget_config_copy["widget"] + if widget != "separator": + del widget_config_copy["name"] + del widget_config_copy["widget"] + getattr(dpg, f"add_{widget}")( + label=name, **widget_config_copy + ) + else: + dpg.add_separator() + + if not update: + dpg.start_dearpygui() + + def update_to_time(self, t: float) -> None: dt = t - self.last_t self.last_t = t + assert self.animations is not None for animation in self.animations: animation.update_mobjects(dt) alpha = t / animation.run_time @@ -1589,8 +1766,8 @@ def add_sound( sound_file: str, time_offset: float = 0, gain: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ This method is used to add a sound to the animation. @@ -1631,7 +1808,9 @@ def construct(self): time = self.time + time_offset self.renderer.file_writer.add_sound(sound_file, time, gain, **kwargs) - def on_mouse_motion(self, point, d_point): + def on_mouse_motion(self, point: Point3D, d_point: Point3D) -> None: + assert isinstance(self.camera, OpenGLCamera) + assert isinstance(self.renderer, OpenGLRenderer) self.mouse_point.move_to(point) if SHIFT_VALUE in self.renderer.pressed_keys: shift = -d_point @@ -1641,13 +1820,15 @@ def on_mouse_motion(self, point, d_point): shift = np.dot(np.transpose(transform), shift) self.camera.shift(shift) - def on_mouse_scroll(self, point, offset): + def on_mouse_scroll(self, point: Point3D, offset: Point3D) -> None: + assert isinstance(self.camera, OpenGLCamera) if not config.use_projection_stroke_shaders: factor = 1 + np.arctan(-2.1 * offset[1]) self.camera.scale(factor, about_point=self.camera_target) self.mouse_scroll_orbit_controls(point, offset) - def on_key_press(self, symbol, modifiers): + def on_key_press(self, symbol: int, modifiers: int) -> None: + assert isinstance(self.camera, OpenGLCamera) try: char = chr(symbol) except OverflowError: @@ -1663,10 +1844,17 @@ def on_key_press(self, symbol, modifiers): if char in self.key_to_function_map: self.key_to_function_map[char]() - def on_key_release(self, symbol, modifiers): + def on_key_release(self, symbol: int, modifiers: int) -> None: pass - def on_mouse_drag(self, point, d_point, buttons, modifiers): + def on_mouse_drag( + self, + point: Point3D, + d_point: Point3D, + buttons: int, + modifiers: int, + ) -> None: + assert isinstance(self.camera, OpenGLCamera) self.mouse_drag_point.move_to(point) if buttons == 1: self.camera.increment_theta(-d_point[0]) @@ -1680,7 +1868,8 @@ def on_mouse_drag(self, point, d_point, buttons, modifiers): self.mouse_drag_orbit_controls(point, d_point, buttons, modifiers) - def mouse_scroll_orbit_controls(self, point, offset): + def mouse_scroll_orbit_controls(self, point: Point3D, offset: Point3D) -> None: + assert isinstance(self.camera, OpenGLCamera) camera_to_target = self.camera_target - self.camera.get_position() camera_to_target *= np.sign(offset[1]) shift_vector = 0.01 * camera_to_target @@ -1688,7 +1877,14 @@ def mouse_scroll_orbit_controls(self, point, offset): opengl.translation_matrix(*shift_vector) @ self.camera.model_matrix ) - def mouse_drag_orbit_controls(self, point, d_point, buttons, modifiers): + def mouse_drag_orbit_controls( + self, + point: Point3D, + d_point: Point3D, + buttons: int, + modifiers: int, + ) -> None: + assert isinstance(self.camera, OpenGLCamera) # Left click drag. if buttons == 1: # Translate to target the origin and rotate around the z axis. @@ -1761,9 +1957,9 @@ def mouse_drag_orbit_controls(self, point, d_point, buttons, modifiers): ) self.camera_target += total_shift_vector - def set_key_function(self, char, func): + def set_key_function(self, char: str, func: Callable[[], Any]) -> None: self.key_to_function_map[char] = func - def on_mouse_press(self, point, button, modifiers): + def on_mouse_press(self, point: Point3D, button: str, modifiers: int) -> None: for func in self.mouse_press_callbacks: func() diff --git a/manim/scene/scene_file_writer.py b/manim/scene/scene_file_writer.py index d256afb736..257e4301c6 100644 --- a/manim/scene/scene_file_writer.py +++ b/manim/scene/scene_file_writer.py @@ -9,7 +9,7 @@ from fractions import Fraction from pathlib import Path from queue import Queue -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, _TemporaryFileWrapper from threading import Thread from typing import TYPE_CHECKING, Any @@ -20,7 +20,6 @@ from pydub import AudioSegment from manim import __version__ -from manim.typing import PixelArray, StrPath from .. import config, logger from .._config.logger_utils import set_file_logger @@ -38,11 +37,15 @@ from .section import DefaultSectionType, Section if TYPE_CHECKING: + from av.container.output import OutputContainer + from av.stream import Stream + from manim.renderer.cairo_renderer import CairoRenderer from manim.renderer.opengl_renderer import OpenGLRenderer + from manim.typing import PixelArray, StrPath -def to_av_frame_rate(fps): +def to_av_frame_rate(fps: float) -> Fraction: epsilon1 = 1e-4 epsilon2 = 0.02 @@ -59,7 +62,9 @@ def to_av_frame_rate(fps): return Fraction(num, denom) -def convert_audio(input_path: Path, output_path: Path, codec_name: str): +def convert_audio( + input_path: Path, output_path: Path | _TemporaryFileWrapper[bytes], codec_name: str +) -> None: with ( av.open(input_path) as input_audio, av.open(output_path, "w") as output_audio, @@ -75,8 +80,7 @@ def convert_audio(input_path: Path, output_path: Path, codec_name: str): class SceneFileWriter: - """ - SceneFileWriter is the object that actually writes the animations + """SceneFileWriter is the object that actually writes the animations played, into video files, using FFMPEG. This is mostly for Manim's internal use. You will rarely, if ever, have to use the methods for this class, unless tinkering with the very @@ -108,14 +112,14 @@ class SceneFileWriter: def __init__( self, renderer: CairoRenderer | OpenGLRenderer, - scene_name: StrPath, + scene_name: str, **kwargs: Any, ) -> None: self.renderer = renderer self.init_output_directories(scene_name) self.init_audio() self.frame_count = 0 - self.partial_movie_files: list[str] = [] + self.partial_movie_files: list[str | None] = [] self.subcaptions: list[srt.Subtitle] = [] self.sections: list[Section] = [] # first section gets automatically created for convenience @@ -124,7 +128,7 @@ def __init__( name="autocreated", type_=DefaultSectionType.NORMAL, skip_animations=False ) - def init_output_directories(self, scene_name: StrPath) -> None: + def init_output_directories(self, scene_name: str) -> None: """Initialise output directories. Notes @@ -231,9 +235,12 @@ def next_section(self, name: str, type_: str, skip_animations: bool) -> None: ), ) - def add_partial_movie_file(self, hash_animation: str): - """Adds a new partial movie file path to `scene.partial_movie_files` and current section from a hash. - This method will compute the path from the hash. In addition to that it adds the new animation to the current section. + def add_partial_movie_file(self, hash_animation: str | None) -> None: + """Adds a new partial movie file path to ``scene.partial_movie_files`` + and current section from a hash. + + This method will compute the path from the hash. In addition to that it + adds the new animation to the current section. Parameters ---------- @@ -256,7 +263,7 @@ def add_partial_movie_file(self, hash_animation: str): self.partial_movie_files.append(new_partial_movie_file) self.sections[-1].partial_movie_files.append(new_partial_movie_file) - def get_resolution_directory(self): + def get_resolution_directory(self) -> str: """Get the name of the resolution directory directly containing the video file. @@ -272,9 +279,11 @@ def get_resolution_directory(self): |--Tex |--texts |--videos - |-- - |--p - |--.mp4 + |-- + |--p + |--partial_movie_files + |--.mp4 + |--.srt Returns ------- @@ -286,11 +295,11 @@ def get_resolution_directory(self): return f"{pixel_height}p{frame_rate}" # Sound - def init_audio(self): + def init_audio(self) -> None: """Preps the writer for adding audio to the movie.""" self.includes_sound = False - def create_audio_segment(self): + def create_audio_segment(self) -> None: """Creates an empty, silent, Audio Segment.""" self.audio_segment = AudioSegment.silent() @@ -299,10 +308,9 @@ def add_audio_segment( new_segment: AudioSegment, time: float | None = None, gain_to_background: float | None = None, - ): - """ - This method adds an audio segment from an - AudioSegment type object and suitable parameters. + ) -> None: + """This method adds an audio segment from an AudioSegment type object + and suitable parameters. Parameters ---------- @@ -310,8 +318,7 @@ def add_audio_segment( The audio segment to add time - the timestamp at which the - sound should be added. + the timestamp at which the sound should be added. gain_to_background The gain of the segment from the background. @@ -341,13 +348,12 @@ def add_audio_segment( def add_sound( self, - sound_file: str, + sound_file: StrPath, time: float | None = None, gain: float | None = None, - **kwargs, - ): - """ - This method adds an audio segment from a sound file. + **kwargs: Any, + ) -> None: + """This method adds an audio segment from a sound file. Parameters ---------- @@ -387,8 +393,7 @@ def add_sound( def begin_animation( self, allow_write: bool = False, file_path: StrPath | None = None ) -> None: - """ - Used internally by manim to stream the animation to FFMPEG for + """Used internally by manim to stream the animation to FFMPEG for displaying or writing to a file. Parameters @@ -400,9 +405,7 @@ def begin_animation( self.open_partial_movie_stream(file_path=file_path) def end_animation(self, allow_write: bool = False) -> None: - """ - Internally used by Manim to stop streaming to - FFMPEG gracefully. + """Internally used by Manim to stop streaming to FFMPEG gracefully. Parameters ---------- @@ -412,7 +415,7 @@ def end_animation(self, allow_write: bool = False) -> None: if write_to_movie() and allow_write: self.close_partial_movie_stream() - def listen_and_write(self): + def listen_and_write(self) -> None: """For internal use only: blocks until new frame is available on the queue.""" while True: num_frames, frame_data = self.queue.get() @@ -422,9 +425,8 @@ def listen_and_write(self): self.encode_and_write_frame(frame_data, num_frames) def encode_and_write_frame(self, frame: PixelArray, num_frames: int) -> None: - """ - For internal use only: takes a given frame in ``np.ndarray`` format and - write it to the stream + """For internal use only: takes a given frame in ``np.ndarray`` format and + writes it to the stream """ for _ in range(num_frames): # Notes: precomputing reusing packets does not work! @@ -438,11 +440,9 @@ def encode_and_write_frame(self, frame: PixelArray, num_frames: int) -> None: self.video_container.mux(packet) def write_frame( - self, frame_or_renderer: np.ndarray | OpenGLRenderer, num_frames: int = 1 - ): - """ - Used internally by Manim to write a frame to - the FFMPEG input buffer. + self, frame_or_renderer: PixelArray | OpenGLRenderer, num_frames: int = 1 + ) -> None: + """Used internally by Manim to write a frame to the FFMPEG input buffer. Parameters ---------- @@ -452,21 +452,27 @@ def write_frame( The number of times to write frame. """ if write_to_movie(): - frame: np.ndarray = ( - frame_or_renderer.get_frame() - if config.renderer == RendererType.OPENGL - else frame_or_renderer - ) + if isinstance(frame_or_renderer, np.ndarray): + frame = frame_or_renderer + else: + frame = ( + frame_or_renderer.get_frame() + if config.renderer == RendererType.OPENGL + else frame_or_renderer + ) msg = (num_frames, frame) self.queue.put(msg) if is_png_format() and not config["dry_run"]: - image: Image = ( - frame_or_renderer.get_image() - if config.renderer == RendererType.OPENGL - else Image.fromarray(frame_or_renderer) - ) + if isinstance(frame_or_renderer, np.ndarray): + image = Image.fromarray(frame_or_renderer) + else: + image = ( + frame_or_renderer.get_image() + if config.renderer == RendererType.OPENGL + else Image.fromarray(frame_or_renderer) + ) target_dir = self.image_file_path.parent / self.image_file_path.stem extension = self.image_file_path.suffix self.output_image( @@ -476,16 +482,17 @@ def write_frame( config["zero_pad"], ) - def output_image(self, image: Image.Image, target_dir, ext, zero_pad: bool): + def output_image( + self, image: Image.Image, target_dir: StrPath, ext: str, zero_pad: bool + ) -> None: if zero_pad: image.save(f"{target_dir}{str(self.frame_count).zfill(zero_pad)}{ext}") else: image.save(f"{target_dir}{self.frame_count}{ext}") self.frame_count += 1 - def save_final_image(self, image: np.ndarray): - """ - The name is a misnomer. This method saves the image + def save_final_image(self, image: Image.Image) -> None: + """The name is a misnomer. This method saves the image passed to it as an in the default image directory. Parameters @@ -502,13 +509,9 @@ def save_final_image(self, image: np.ndarray): self.print_file_ready_message(self.image_file_path) def finish(self) -> None: - """ - Finishes writing to the FFMPEG buffer or writing images - to output directory. - Combines the partial movie files into the - whole scene. - If save_last_frame is True, saves the last - frame in the default image directory. + """Finishes writing to the FFMPEG buffer or writing images to output directory. + Combines the partial movie files into the whole scene. + If save_last_frame is True, saves the last frame in the default image directory. """ if write_to_movie(): self.combine_to_movie() @@ -524,7 +527,7 @@ def finish(self) -> None: if self.subcaptions: self.write_subcaption_file() - def open_partial_movie_stream(self, file_path=None) -> None: + def open_partial_movie_stream(self, file_path: StrPath | None = None) -> None: """Open a container holding a video stream. This is used internally by Manim initialize the container holding @@ -563,8 +566,8 @@ def open_partial_movie_stream(self, file_path=None) -> None: stream.width = config.pixel_width stream.height = config.pixel_height - self.video_container = video_container - self.video_stream = stream + self.video_container: OutputContainer = video_container + self.video_stream: Stream = stream self.queue: Queue[tuple[int, PixelArray | None]] = Queue() self.writer_thread = Thread(target=self.listen_and_write, args=()) @@ -590,7 +593,7 @@ def close_partial_movie_stream(self) -> None: {"path": f"'{self.partial_movie_file_path}'"}, ) - def is_already_cached(self, hash_invocation: str): + def is_already_cached(self, hash_invocation: str) -> bool: """Will check if a file named with `hash_invocation` exists. Parameters @@ -615,9 +618,9 @@ def combine_files( self, input_files: list[str], output_file: Path, - create_gif=False, - includes_sound=False, - ): + create_gif: bool = False, + includes_sound: bool = False, + ) -> None: file_list = self.partial_movie_directory / "partial_movie_file_list.txt" logger.debug( f"Partial movie files to combine ({len(input_files)} files): %(p)s", @@ -651,8 +654,7 @@ def combine_files( if config.transparent and config.movie_file_extension == ".webm": output_stream.pix_fmt = "yuva420p" if create_gif: - """ - The following solution was largely inspired from this comment + """The following solution was largely inspired from this comment https://github.com/imageio/imageio/issues/995#issuecomment-1580533018, and the following code https://github.com/imageio/imageio/blob/65d79140018bb7c64c0692ea72cb4093e8d632a0/imageio/plugins/pyav.py#L927-L996. @@ -716,7 +718,7 @@ def combine_files( partial_movies_input.close() output_container.close() - def combine_to_movie(self): + def combine_to_movie(self) -> None: """Used internally by Manim to combine the separate partial movie files that make up a Scene into a single video file for that Scene. @@ -836,7 +838,7 @@ def combine_to_section_videos(self) -> None: with (self.sections_output_dir / f"{self.output_name}.json").open("w") as file: json.dump(sections_index, file, indent=4) - def clean_cache(self): + def clean_cache(self) -> None: """Will clean the cache by removing the oldest partial_movie_files.""" cached_partial_movies = [ (self.partial_movie_directory / file_name) @@ -858,7 +860,7 @@ def clean_cache(self): " You can change this behaviour by changing max_files_cached in config.", ) - def flush_cache_directory(self): + def flush_cache_directory(self) -> None: """Delete all the cached partial movie files""" cached_partial_movies = [ self.partial_movie_directory / file_name @@ -872,7 +874,7 @@ def flush_cache_directory(self): {"par_dir": self.partial_movie_directory}, ) - def write_subcaption_file(self): + def write_subcaption_file(self) -> None: """Writes the subcaption file.""" if config.output_file is None: return @@ -880,7 +882,7 @@ def write_subcaption_file(self): subcaption_file.write_text(srt.compose(self.subcaptions), encoding="utf-8") logger.info(f"Subcaption file has been written as {subcaption_file}") - def print_file_ready_message(self, file_path): + def print_file_ready_message(self, file_path: StrPath) -> None: """Prints the "File Ready" message to STDOUT.""" config["output_file"] = file_path logger.info("\nFile ready at %(file_path)s\n", {"file_path": f"'{file_path}'"}) diff --git a/manim/scene/section.py b/manim/scene/section.py index af005b52da..99e62c3823 100644 --- a/manim/scene/section.py +++ b/manim/scene/section.py @@ -59,7 +59,9 @@ class Section: :meth:`.OpenGLRenderer.update_skipping_status` """ - def __init__(self, type_: str, video: str | None, name: str, skip_animations: bool): + def __init__( + self, type_: str, video: str | None, name: str, skip_animations: bool + ) -> None: self.type_ = type_ # None when not to be saved -> still keeps section alive self.video: str | None = video @@ -100,5 +102,5 @@ def get_dict(self, sections_dir: Path) -> dict[str, Any]: **video_metadata, ) - def __repr__(self): + def __repr__(self) -> str: return f"
" diff --git a/manim/scene/vector_space_scene.py b/manim/scene/vector_space_scene.py index be75151471..ccc3780200 100644 --- a/manim/scene/vector_space_scene.py +++ b/manim/scene/vector_space_scene.py @@ -4,10 +4,13 @@ __all__ = ["VectorScene", "LinearTransformationScene"] -from typing import Callable +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, cast import numpy as np +from manim.animation.creation import DrawBorderThenFill, Group +from manim.camera.camera import Camera from manim.mobject.geometry.arc import Dot from manim.mobject.geometry.line import Arrow, Line, Vector from manim.mobject.geometry.polygram import Rectangle @@ -41,6 +44,19 @@ from ..utils.rate_functions import rush_from, rush_into from ..utils.space_ops import angle_of_vector +if TYPE_CHECKING: + from typing_extensions import Self + + from manim.typing import ( + MappingFunction, + Point3D, + Point3DLike, + Vector2DLike, + Vector3D, + Vector3DLike, + ) + + X_COLOR = GREEN_C Y_COLOR = RED_C Z_COLOR = BLUE_D @@ -53,11 +69,11 @@ # Also, methods I would have thought of as getters, like coords_to_vector, are # actually doing a lot of animating. class VectorScene(Scene): - def __init__(self, basis_vector_stroke_width=6, **kwargs): + def __init__(self, basis_vector_stroke_width: float = 6.0, **kwargs: Any) -> None: super().__init__(**kwargs) self.basis_vector_stroke_width = basis_vector_stroke_width - def add_plane(self, animate: bool = False, **kwargs): + def add_plane(self, animate: bool = False, **kwargs: Any) -> NumberPlane: """ Adds a NumberPlane object to the background. @@ -79,7 +95,11 @@ def add_plane(self, animate: bool = False, **kwargs): self.add(plane) return plane - def add_axes(self, animate: bool = False, color: bool = WHITE, **kwargs): + def add_axes( + self, + animate: bool = False, + color: ParsableManimColor | Iterable[ParsableManimColor] = WHITE, + ) -> Axes: """ Adds a pair of Axes to the Scene. @@ -96,7 +116,9 @@ def add_axes(self, animate: bool = False, color: bool = WHITE, **kwargs): self.add(axes) return axes - def lock_in_faded_grid(self, dimness: float = 0.7, axes_dimness: float = 0.5): + def lock_in_faded_grid( + self, dimness: float = 0.7, axes_dimness: float = 0.5 + ) -> None: """ This method freezes the NumberPlane and Axes that were already in the background, and adds new, manipulatable ones to the foreground. @@ -116,11 +138,13 @@ def lock_in_faded_grid(self, dimness: float = 0.7, axes_dimness: float = 0.5): axes.fade(axes_dimness) self.add(axes) - self.renderer.update_frame() + # TODO + # error: Missing positional argument "scene" in call to "update_frame" of "CairoRenderer" [call-arg] + self.renderer.update_frame() # type: ignore[call-arg] self.renderer.camera = Camera(self.renderer.get_frame()) self.clear() - def get_vector(self, numerical_vector: np.ndarray | list | tuple, **kwargs): + def get_vector(self, numerical_vector: Vector3DLike, **kwargs: Any) -> Arrow: """ Returns an arrow on the Plane given an input numerical vector. @@ -137,19 +161,21 @@ def get_vector(self, numerical_vector: np.ndarray | list | tuple, **kwargs): The Arrow representing the Vector. """ return Arrow( - self.plane.coords_to_point(0, 0), - self.plane.coords_to_point(*numerical_vector[:2]), + # TODO + # error: "VectorScene" has no attribute "plane" [attr-defined] + self.plane.coords_to_point(0, 0), # type: ignore[attr-defined] + self.plane.coords_to_point(*numerical_vector[:2]), # type: ignore[attr-defined] buff=0, **kwargs, ) def add_vector( self, - vector: Arrow | list | tuple | np.ndarray, - color: str = YELLOW, + vector: Arrow | Vector3DLike, + color: ParsableManimColor | Iterable[ParsableManimColor] = YELLOW, animate: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> Arrow: """ Returns the Vector after adding it to the Plane. @@ -179,13 +205,13 @@ def add_vector( The arrow representing the vector. """ if not isinstance(vector, Arrow): - vector = Vector(vector, color=color, **kwargs) + vector = Vector(np.asarray(vector), color=color, **kwargs) if animate: self.play(GrowArrow(vector)) self.add(vector) return vector - def write_vector_coordinates(self, vector: Arrow, **kwargs): + def write_vector_coordinates(self, vector: Vector, **kwargs: Any) -> Matrix: """ Returns a column matrix indicating the vector coordinates, after writing them to the screen. @@ -203,11 +229,15 @@ def write_vector_coordinates(self, vector: Arrow, **kwargs): :class:`.Matrix` The column matrix representing the vector. """ - coords = vector.coordinate_label(**kwargs) + coords: Matrix = vector.coordinate_label(**kwargs) self.play(Write(coords)) return coords - def get_basis_vectors(self, i_hat_color: str = X_COLOR, j_hat_color: str = Y_COLOR): + def get_basis_vectors( + self, + i_hat_color: ParsableManimColor | Iterable[ParsableManimColor] = X_COLOR, + j_hat_color: ParsableManimColor | Iterable[ParsableManimColor] = Y_COLOR, + ) -> VGroup: """ Returns a VGroup of the Basis Vectors (1,0) and (0,1) @@ -226,12 +256,16 @@ def get_basis_vectors(self, i_hat_color: str = X_COLOR, j_hat_color: str = Y_COL """ return VGroup( *( - Vector(vect, color=color, stroke_width=self.basis_vector_stroke_width) + Vector( + np.asarray(vect), + color=color, + stroke_width=self.basis_vector_stroke_width, + ) for vect, color in [([1, 0], i_hat_color), ([0, 1], j_hat_color)] ) ) - def get_basis_vector_labels(self, **kwargs): + def get_basis_vector_labels(self, **kwargs: Any) -> VGroup: """ Returns naming labels for the basis vectors. @@ -263,13 +297,13 @@ def get_basis_vector_labels(self, **kwargs): def get_vector_label( self, vector: Vector, - label, + label: MathTex | str, at_tip: bool = False, direction: str = "left", rotate: bool = False, - color: str | None = None, + color: ParsableManimColor | None = None, label_scale_factor: float = LARGE_BUFF - 0.2, - ): + ) -> MathTex: """ Returns naming labels for the passed vector. @@ -300,8 +334,11 @@ def get_vector_label( label = "\\vec{\\textbf{%s}}" % label # noqa: UP031 label = MathTex(label) if color is None: - color = vector.get_color() - label.set_color(color) + prepared_color: ParsableManimColor = vector.get_color() + else: + prepared_color = color + label.set_color(prepared_color) + assert isinstance(label, MathTex) label.scale(label_scale_factor) label.add_background_rectangle() @@ -314,16 +351,18 @@ def get_vector_label( if not rotate: label.rotate(-angle, about_point=ORIGIN) if direction == "left": - label.shift(-label.get_bottom() + 0.1 * UP) + temp_shift_1: Vector3D = np.asarray(label.get_bottom()) + label.shift(-temp_shift_1 + 0.1 * UP) else: - label.shift(-label.get_top() + 0.1 * DOWN) + temp_shift_2: Vector3D = np.asarray(label.get_top()) + label.shift(-temp_shift_2 + 0.1 * DOWN) label.rotate(angle, about_point=ORIGIN) label.shift((vector.get_end() - vector.get_start()) / 2) return label def label_vector( - self, vector: Vector, label: MathTex | str, animate: bool = True, **kwargs - ): + self, vector: Vector, label: MathTex | str, animate: bool = True, **kwargs: Any + ) -> MathTex: """ Shortcut method for creating, and animating the addition of a label for the vector. @@ -347,38 +386,38 @@ def label_vector( :class:`~.MathTex` The MathTex of the label. """ - label = self.get_vector_label(vector, label, **kwargs) + mathtex_label = self.get_vector_label(vector, label, **kwargs) if animate: - self.play(Write(label, run_time=1)) - self.add(label) - return label + self.play(Write(mathtex_label, run_time=1)) + self.add(mathtex_label) + return mathtex_label def position_x_coordinate( self, - x_coord, - x_line, - vector, - ): # TODO Write DocStrings for this. + x_coord: MathTex, + x_line: Line, + vector: Vector3DLike, + ) -> MathTex: # TODO Write DocStrings for this. x_coord.next_to(x_line, -np.sign(vector[1]) * UP) x_coord.set_color(X_COLOR) return x_coord def position_y_coordinate( self, - y_coord, - y_line, - vector, - ): # TODO Write DocStrings for this. + y_coord: MathTex, + y_line: Line, + vector: Vector3DLike, + ) -> MathTex: # TODO Write DocStrings for this. y_coord.next_to(y_line, np.sign(vector[0]) * RIGHT) y_coord.set_color(Y_COLOR) return y_coord def coords_to_vector( self, - vector: np.ndarray | list | tuple, - coords_start: np.ndarray | list | tuple = 2 * RIGHT + 2 * UP, + vector: Vector2DLike, + coords_start: Point3DLike = 2 * RIGHT + 2 * UP, clean_up: bool = True, - ): + ) -> None: """ This method writes the vector as a column matrix (henceforth called the label), takes the values in it one by one, and form the corresponding @@ -409,26 +448,29 @@ def coords_to_vector( y_line = Line(x_line.get_end(), arrow.get_end()) x_line.set_color(X_COLOR) y_line.set_color(Y_COLOR) - x_coord, y_coord = array.get_mob_matrix().flatten() + mob_matrix = array.get_mob_matrix() + x_coord = mob_matrix[0][0] + y_coord = mob_matrix[1][0] self.play(Write(array, run_time=1)) self.wait() self.play( ApplyFunction( - lambda x: self.position_x_coordinate(x, x_line, vector), + lambda x: self.position_x_coordinate(x, x_line, vector), # type: ignore[arg-type] x_coord, ), ) self.play(Create(x_line)) animations = [ ApplyFunction( - lambda y: self.position_y_coordinate(y, y_line, vector), + lambda y: self.position_y_coordinate(y, y_line, vector), # type: ignore[arg-type] y_coord, ), FadeOut(array.get_brackets()), ] self.play(*animations) - y_coord, _ = (anim.mobject for anim in animations) + # TODO: Can we delete the line below? I don't think it have any purpose. + # y_coord, _ = (anim.mobject for anim in animations) self.play(Create(y_line)) self.play(Create(arrow)) self.wait() @@ -438,10 +480,10 @@ def coords_to_vector( def vector_to_coords( self, - vector: np.ndarray | list | tuple, + vector: Vector3DLike, integer_labels: bool = True, clean_up: bool = True, - ): + ) -> tuple[Matrix, Line, Line]: """ This method displays vector as a Vector() based vector, and then shows the corresponding lines that make up the x and y components of the vector. @@ -475,7 +517,7 @@ def vector_to_coords( y_line = Line(x_line.get_end(), arrow.get_end()) x_line.set_color(X_COLOR) y_line.set_color(Y_COLOR) - x_coord, y_coord = array.get_entries() + x_coord, y_coord = cast(VGroup, array.get_entries()) x_coord_start = self.position_x_coordinate(x_coord.copy(), x_line, vector) y_coord_start = self.position_y_coordinate(y_coord.copy(), y_line, vector) brackets = array.get_brackets() @@ -499,7 +541,7 @@ def vector_to_coords( self.add(*starting_mobjects) return array, x_line, y_line - def show_ghost_movement(self, vector: Arrow | list | tuple | np.ndarray): + def show_ghost_movement(self, vector: Arrow | Vector2DLike | Vector3DLike) -> None: """ This method plays an animation that partially shows the entire plane moving in the direction of a particular vector. This is useful when you wish to @@ -513,20 +555,26 @@ def show_ghost_movement(self, vector: Arrow | list | tuple | np.ndarray): """ if isinstance(vector, Arrow): vector = vector.get_end() - vector.get_start() - elif len(vector) == 2: - vector = np.append(np.array(vector), 0.0) - x_max = int(config["frame_x_radius"] + abs(vector[0])) - y_max = int(config["frame_y_radius"] + abs(vector[1])) + else: + vector = np.asarray(vector) + if len(vector) == 2: + vector = np.append(np.array(vector), 0.0) + vector_cleaned: Vector3D = vector + + x_max = int(config["frame_x_radius"] + abs(vector_cleaned[0])) + y_max = int(config["frame_y_radius"] + abs(vector_cleaned[1])) + # TODO: + # I think that this should be a VGroup instead of a VMobject. dots = VMobject( - *( + *( # type: ignore[arg-type] Dot(x * RIGHT + y * UP) for x in range(-x_max, x_max) for y in range(-y_max, y_max) ) ) dots.set_fill(BLACK, opacity=0) - dots_halfway = dots.copy().shift(vector / 2).set_fill(WHITE, 1) - dots_end = dots.copy().shift(vector) + dots_halfway = dots.copy().shift(vector_cleaned / 2).set_fill(WHITE, 1) + dots_end = dots.copy().shift(vector_cleaned) self.play(Transform(dots, dots_halfway, rate_func=rush_into)) self.play(Transform(dots, dots_end, rate_func=rush_from)) @@ -585,16 +633,16 @@ def __init__( self, include_background_plane: bool = True, include_foreground_plane: bool = True, - background_plane_kwargs: dict | None = None, - foreground_plane_kwargs: dict | None = None, + background_plane_kwargs: dict[str, Any] | None = None, + foreground_plane_kwargs: dict[str, Any] | None = None, show_coordinates: bool = False, show_basis_vectors: bool = True, basis_vector_stroke_width: float = 6, i_hat_color: ParsableManimColor = X_COLOR, j_hat_color: ParsableManimColor = Y_COLOR, leave_ghost_vectors: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.include_background_plane = include_background_plane @@ -605,7 +653,7 @@ def __init__( self.i_hat_color = ManimColor(i_hat_color) self.j_hat_color = ManimColor(j_hat_color) self.leave_ghost_vectors = leave_ghost_vectors - self.background_plane_kwargs = { + self.background_plane_kwargs: dict[str, Any] = { "color": GREY, "axis_config": { "color": GREY, @@ -618,7 +666,7 @@ def __init__( self.ghost_vectors = VGroup() - self.foreground_plane_kwargs = { + self.foreground_plane_kwargs: dict[str, Any] = { "x_range": np.array([-config["frame_width"], config["frame_width"], 1.0]), "y_range": np.array([-config["frame_width"], config["frame_width"], 1.0]), "faded_line_ratio": 1, @@ -630,22 +678,25 @@ def __init__( ) @staticmethod - def update_default_configs(default_configs, passed_configs): + def update_default_configs( + default_configs: Iterable[dict[str, Any]], + passed_configs: Iterable[dict[str, Any] | None], + ) -> None: for default_config, passed_config in zip(default_configs, passed_configs): if passed_config is not None: update_dict_recursively(default_config, passed_config) - def setup(self): + def setup(self) -> None: # The has_already_setup attr is to not break all the old Scenes if hasattr(self, "has_already_setup"): return self.has_already_setup = True - self.background_mobjects = [] - self.foreground_mobjects = [] - self.transformable_mobjects = [] - self.moving_vectors = [] - self.transformable_labels = [] - self.moving_mobjects = [] + self.background_mobjects: list[Mobject] = [] + self.foreground_mobjects: list[Mobject] = [] + self.transformable_mobjects: list[Mobject] = [] + self.moving_vectors: list[Mobject] = [] + self.transformable_labels: list[MathTex] = [] + self.moving_mobjects: list[Mobject] = [] self.background_plane = NumberPlane(**self.background_plane_kwargs) @@ -665,7 +716,9 @@ def setup(self): self.i_hat, self.j_hat = self.basis_vectors self.add(self.basis_vectors) - def add_special_mobjects(self, mob_list: list, *mobs_to_add: Mobject): + def add_special_mobjects( + self, mob_list: list[Mobject], *mobs_to_add: Mobject + ) -> None: """ Adds mobjects to a separate list that can be tracked, if these mobjects have some extra importance. @@ -685,7 +738,7 @@ def add_special_mobjects(self, mob_list: list, *mobs_to_add: Mobject): mob_list.append(mobject) self.add(mobject) - def add_background_mobject(self, *mobjects: Mobject): + def add_background_mobject(self, *mobjects: Mobject) -> None: """ Adds the mobjects to the special list self.background_mobjects. @@ -697,8 +750,9 @@ def add_background_mobject(self, *mobjects: Mobject): """ self.add_special_mobjects(self.background_mobjects, *mobjects) - # TODO, this conflicts with Scene.add_fore - def add_foreground_mobject(self, *mobjects: Mobject): + # TODO, this conflicts with Scene.add_foreground_mobject + # Please be aware that there is also the method Scene.add_foreground_mobjects. + def add_foreground_mobject(self, *mobjects: Mobject) -> None: # type: ignore[override] """ Adds the mobjects to the special list self.foreground_mobjects. @@ -710,7 +764,7 @@ def add_foreground_mobject(self, *mobjects: Mobject): """ self.add_special_mobjects(self.foreground_mobjects, *mobjects) - def add_transformable_mobject(self, *mobjects: Mobject): + def add_transformable_mobject(self, *mobjects: Mobject) -> None: """ Adds the mobjects to the special list self.transformable_mobjects. @@ -724,7 +778,7 @@ def add_transformable_mobject(self, *mobjects: Mobject): def add_moving_mobject( self, mobject: Mobject, target_mobject: Mobject | None = None - ): + ) -> None: """ Adds the mobject to the special list self.moving_mobject, and adds a property @@ -751,8 +805,11 @@ def get_ghost_vectors(self) -> VGroup: return self.ghost_vectors def get_unit_square( - self, color: str = YELLOW, opacity: float = 0.3, stroke_width: float = 3 - ): + self, + color: ParsableManimColor | Iterable[ParsableManimColor] = YELLOW, + opacity: float = 0.3, + stroke_width: float = 3, + ) -> Rectangle: """ Returns a unit square for the current NumberPlane. @@ -783,7 +840,7 @@ def get_unit_square( square.move_to(self.plane.coords_to_point(0, 0), DL) return square - def add_unit_square(self, animate: bool = False, **kwargs): + def add_unit_square(self, animate: bool = False, **kwargs: Any) -> Self: """ Adds a unit square to the scene via self.get_unit_square. @@ -814,8 +871,12 @@ def add_unit_square(self, animate: bool = False, **kwargs): return self def add_vector( - self, vector: Arrow | list | tuple | np.ndarray, color: str = YELLOW, **kwargs - ): + self, + vector: Arrow | list | tuple | np.ndarray, + color: ParsableManimColor = YELLOW, + animate: bool = False, + **kwargs: Any, + ) -> Arrow: """ Adds a vector to the scene, and puts it in the special list self.moving_vectors. @@ -839,11 +900,11 @@ def add_vector( Arrow The arrow representing the vector. """ - vector = super().add_vector(vector, color=color, **kwargs) + vector = super().add_vector(vector, color=color, animate=animate, **kwargs) self.moving_vectors.append(vector) return vector - def write_vector_coordinates(self, vector: Arrow, **kwargs): + def write_vector_coordinates(self, vector: Vector, **kwargs: Any) -> Matrix: """ Returns a column matrix indicating the vector coordinates, after writing them to the screen, and adding them to the @@ -872,8 +933,8 @@ def add_transformable_label( label: MathTex | str, transformation_name: str | MathTex = "L", new_label: str | MathTex | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> MathTex: """ Method for creating, and animating the addition of a transformable label for the vector. @@ -900,26 +961,27 @@ def add_transformable_label( :class:`~.MathTex` The MathTex of the label. """ + # TODO: Clear up types in this function. This is currently a mess. label_mob = self.label_vector(vector, label, **kwargs) if new_label: - label_mob.target_text = new_label + label_mob.target_text = new_label # type: ignore[attr-defined] else: - label_mob.target_text = ( + label_mob.target_text = ( # type: ignore[attr-defined] f"{transformation_name}({label_mob.get_tex_string()})" ) - label_mob.vector = vector - label_mob.kwargs = kwargs - if "animate" in label_mob.kwargs: - label_mob.kwargs.pop("animate") + label_mob.vector = vector # type: ignore[attr-defined] + label_mob.kwargs = kwargs # type: ignore[attr-defined] + if "animate" in label_mob.kwargs: # type: ignore[operator] + label_mob.kwargs.pop("animate") # type: ignore[attr-defined] self.transformable_labels.append(label_mob) - return label_mob + return cast(MathTex, label_mob) def add_title( self, title: str | MathTex | Tex, scale_factor: float = 1.5, animate: bool = False, - ): + ) -> Self: """ Adds a title, after scaling it, adding a background rectangle, moving it to the top and adding it to foreground_mobjects adding @@ -951,7 +1013,9 @@ def add_title( self.title = title return self - def get_matrix_transformation(self, matrix: np.ndarray | list | tuple): + def get_matrix_transformation( + self, matrix: np.ndarray | list | tuple + ) -> Callable[[Point3D], Point3D]: """ Returns a function corresponding to the linear transformation represented by the matrix passed. @@ -965,7 +1029,7 @@ def get_matrix_transformation(self, matrix: np.ndarray | list | tuple): def get_transposed_matrix_transformation( self, transposed_matrix: np.ndarray | list | tuple - ): + ) -> Callable[[Point3D], Point3D]: """ Returns a function corresponding to the linear transformation represented by the transposed @@ -985,7 +1049,7 @@ def get_transposed_matrix_transformation( raise ValueError("Matrix has bad dimensions") return lambda point: np.dot(point, transposed_matrix) - def get_piece_movement(self, pieces: list | tuple | np.ndarray): + def get_piece_movement(self, pieces: Iterable[Mobject]) -> Transform: """ This method returns an animation that moves an arbitrary mobject in "pieces" to its corresponding .target value. @@ -1013,7 +1077,7 @@ def get_piece_movement(self, pieces: list | tuple | np.ndarray): self.add(self.ghost_vectors[-1]) return Transform(start, target, lag_ratio=0) - def get_moving_mobject_movement(self, func: Callable[[np.ndarray], np.ndarray]): + def get_moving_mobject_movement(self, func: MappingFunction) -> Transform: """ This method returns an animation that moves a mobject in "self.moving_mobjects" to its corresponding .target value. @@ -1034,11 +1098,12 @@ def get_moving_mobject_movement(self, func: Callable[[np.ndarray], np.ndarray]): for m in self.moving_mobjects: if m.target is None: m.target = m.copy() - target_point = func(m.get_center()) + temp: Point3D = m.get_center() + target_point = func(temp) m.target.move_to(target_point) return self.get_piece_movement(self.moving_mobjects) - def get_vector_movement(self, func: Callable[[np.ndarray], np.ndarray]): + def get_vector_movement(self, func: MappingFunction) -> Transform: """ This method returns an animation that moves a mobject in "self.moving_vectors" to its corresponding .target value. @@ -1058,12 +1123,12 @@ def get_vector_movement(self, func: Callable[[np.ndarray], np.ndarray]): """ for v in self.moving_vectors: v.target = Vector(func(v.get_end()), color=v.get_color()) - norm = np.linalg.norm(v.target.get_end()) + norm = float(np.linalg.norm(v.target.get_end())) if norm < 0.1: v.target.get_tip().scale(norm) return self.get_piece_movement(self.moving_vectors) - def get_transformable_label_movement(self): + def get_transformable_label_movement(self) -> Transform: """ This method returns an animation that moves all labels in "self.transformable_labels" to its corresponding .target . @@ -1074,12 +1139,17 @@ def get_transformable_label_movement(self): The animation of the movement. """ for label in self.transformable_labels: + # TODO: This location and lines 933 and 335 are the only locations in + # the code where the target_text property is referenced. + target_text: MathTex | str = label.target_text # type: ignore[assignment] label.target = self.get_vector_label( - label.vector.target, label.target_text, **label.kwargs + label.vector.target, # type: ignore[attr-defined] + target_text, + **label.kwargs, # type: ignore[arg-type] ) return self.get_piece_movement(self.transformable_labels) - def apply_matrix(self, matrix: np.ndarray | list | tuple, **kwargs): + def apply_matrix(self, matrix: np.ndarray | list | tuple, **kwargs: Any) -> None: """ Applies the transformation represented by the given matrix to the number plane, and each vector/similar @@ -1094,7 +1164,7 @@ def apply_matrix(self, matrix: np.ndarray | list | tuple, **kwargs): """ self.apply_transposed_matrix(np.array(matrix).T, **kwargs) - def apply_inverse(self, matrix: np.ndarray | list | tuple, **kwargs): + def apply_inverse(self, matrix: np.ndarray | list | tuple, **kwargs: Any) -> None: """ This method applies the linear transformation represented by the inverse of the passed matrix @@ -1110,8 +1180,8 @@ def apply_inverse(self, matrix: np.ndarray | list | tuple, **kwargs): self.apply_matrix(np.linalg.inv(matrix), **kwargs) def apply_transposed_matrix( - self, transposed_matrix: np.ndarray | list | tuple, **kwargs - ): + self, transposed_matrix: np.ndarray | list | tuple, **kwargs: Any + ) -> None: """ Applies the transformation represented by the given transposed matrix to the number plane, @@ -1132,7 +1202,9 @@ def apply_transposed_matrix( kwargs["path_arc"] = net_rotation self.apply_function(func, **kwargs) - def apply_inverse_transpose(self, t_matrix: np.ndarray | list | tuple, **kwargs): + def apply_inverse_transpose( + self, t_matrix: np.ndarray | list | tuple, **kwargs: Any + ) -> None: """ Applies the inverse of the transformation represented by the given transposed matrix to the number plane and each @@ -1149,8 +1221,8 @@ def apply_inverse_transpose(self, t_matrix: np.ndarray | list | tuple, **kwargs) self.apply_transposed_matrix(t_inv, **kwargs) def apply_nonlinear_transformation( - self, function: Callable[[np.ndarray], np.ndarray], **kwargs - ): + self, function: Callable[[np.ndarray], np.ndarray], **kwargs: Any + ) -> None: """ Applies the non-linear transformation represented by the given function to the number plane and each @@ -1168,10 +1240,10 @@ def apply_nonlinear_transformation( def apply_function( self, - function: Callable[[np.ndarray], np.ndarray], - added_anims: list = [], - **kwargs, - ): + function: MappingFunction, + added_anims: list[Animation] = [], + **kwargs: Any, + ) -> None: """ Applies the given function to each of the mobjects in self.transformable_mobjects, and plays the animation showing @@ -1194,7 +1266,7 @@ def apply_function( kwargs["run_time"] = 3 anims = ( [ - ApplyPointwiseFunction(function, t_mob) + ApplyPointwiseFunction(function, t_mob) # type: ignore[arg-type] for t_mob in self.transformable_mobjects ] + [ diff --git a/manim/scene/zoomed_scene.py b/manim/scene/zoomed_scene.py index 361c4eaf55..57c89b1ad6 100644 --- a/manim/scene/zoomed_scene.py +++ b/manim/scene/zoomed_scene.py @@ -49,44 +49,49 @@ def construct(self): __all__ = ["ZoomedScene"] +from typing import TYPE_CHECKING, Any from ..animation.transform import ApplyMethod +from ..camera.camera import Camera from ..camera.moving_camera import MovingCamera from ..camera.multi_camera import MultiCamera from ..constants import * from ..mobject.types.image_mobject import ImageMobjectFromCamera +from ..renderer.opengl_renderer import OpenGLCamera from ..scene.moving_camera_scene import MovingCameraScene +if TYPE_CHECKING: + from manim.typing import Point3DLike, Vector3D + # Note, any scenes from old videos using ZoomedScene will almost certainly # break, as it was restructured. class ZoomedScene(MovingCameraScene): - """ - This is a Scene with special configurations made for when + """This is a Scene with special configurations made for when a particular part of the scene must be zoomed in on and displayed separately. """ def __init__( self, - camera_class=MultiCamera, - zoomed_display_height=3, - zoomed_display_width=3, - zoomed_display_center=None, - zoomed_display_corner=UP + RIGHT, - zoomed_display_corner_buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER, - zoomed_camera_config={ + camera_class: type[Camera] = MultiCamera, + zoomed_display_height: float = 3, + zoomed_display_width: float = 3, + zoomed_display_center: Point3DLike | None = None, + zoomed_display_corner: Vector3D = UP + RIGHT, + zoomed_display_corner_buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + zoomed_camera_config: dict[str, Any] = { "default_frame_stroke_width": 2, "background_opacity": 1, }, - zoomed_camera_image_mobject_config={}, - zoomed_camera_frame_starting_position=ORIGIN, - zoom_factor=0.15, - image_frame_stroke_width=3, - zoom_activated=False, - **kwargs, - ): + zoomed_camera_image_mobject_config: dict[str, Any] = {}, + zoomed_camera_frame_starting_position: Point3DLike = ORIGIN, + zoom_factor: float = 0.15, + image_frame_stroke_width: float = 3, + zoom_activated: bool = False, + **kwargs: Any, + ) -> None: self.zoomed_display_height = zoomed_display_height self.zoomed_display_width = zoomed_display_width self.zoomed_display_center = zoomed_display_center @@ -102,9 +107,8 @@ def __init__( self.zoom_activated = zoom_activated super().__init__(camera_class=camera_class, **kwargs) - def setup(self): - """ - This method is used internally by Manim to + def setup(self) -> None: + """This method is used internally by Manim to setup the scene for proper use. """ super().setup() @@ -132,10 +136,8 @@ def setup(self): self.zoomed_camera = zoomed_camera self.zoomed_display = zoomed_display - def activate_zooming(self, animate: bool = False): - """ - This method is used to activate the zooming for - the zoomed_camera. + def activate_zooming(self, animate: bool = False) -> None: + """This method is used to activate the zooming for the zoomed_camera. Parameters ---------- @@ -144,7 +146,7 @@ def activate_zooming(self, animate: bool = False): of the zoomed camera. """ self.zoom_activated = True - self.renderer.camera.add_image_mobject_from_camera(self.zoomed_display) + self.renderer.camera.add_image_mobject_from_camera(self.zoomed_display) # type: ignore[union-attr] if animate: self.play(self.get_zoom_in_animation()) self.play(self.get_zoomed_display_pop_out_animation()) @@ -153,9 +155,8 @@ def activate_zooming(self, animate: bool = False): self.zoomed_display, ) - def get_zoom_in_animation(self, run_time: float = 2, **kwargs): - """ - Returns the animation of camera zooming in. + def get_zoom_in_animation(self, run_time: float = 2, **kwargs: Any) -> ApplyMethod: + """Returns the animation of camera zooming in. Parameters ---------- @@ -170,8 +171,11 @@ def get_zoom_in_animation(self, run_time: float = 2, **kwargs): The animation of the camera zooming in. """ frame = self.zoomed_camera.frame - full_frame_height = self.camera.frame_height - full_frame_width = self.camera.frame_width + if isinstance(self.camera, OpenGLCamera): + full_frame_width, full_frame_height = self.camera.frame_shape + else: + full_frame_height = self.camera.frame_height + full_frame_width = self.camera.frame_width frame.save_state() frame.stretch_to_fit_width(full_frame_width) frame.stretch_to_fit_height(full_frame_height) @@ -179,11 +183,9 @@ def get_zoom_in_animation(self, run_time: float = 2, **kwargs): frame.set_stroke(width=0) return ApplyMethod(frame.restore, run_time=run_time, **kwargs) - def get_zoomed_display_pop_out_animation(self, **kwargs): - """ - This is the animation of the popping out of the - mini-display that shows the content of the zoomed - camera. + def get_zoomed_display_pop_out_animation(self, **kwargs: Any) -> ApplyMethod: + """This is the animation of the popping out of the mini-display that + shows the content of the zoomed camera. Returns ------- @@ -195,15 +197,18 @@ def get_zoomed_display_pop_out_animation(self, **kwargs): display.replace(self.zoomed_camera.frame, stretch=True) return ApplyMethod(display.restore) - def get_zoom_factor(self): - """ - Returns the Zoom factor of the Zoomed camera. - Defined as the ratio between the height of the - zoomed camera and the height of the zoomed mini - display. + def get_zoom_factor(self) -> float: + """Returns the Zoom factor of the Zoomed camera. + + Defined as the ratio between the height of the zoomed camera and + the height of the zoomed mini display. + Returns ------- float The zoom factor. """ - return self.zoomed_camera.frame.height / self.zoomed_display.height + zoom_factor: float = ( + self.zoomed_camera.frame.height / self.zoomed_display.height + ) + return zoom_factor diff --git a/manim/typing.py b/manim/typing.py index 660b4a1821..5682ee4eed 100644 --- a/manim/typing.py +++ b/manim/typing.py @@ -20,9 +20,9 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from os import PathLike -from typing import Callable, Union +from typing import Union import numpy as np import numpy.typing as npt @@ -61,11 +61,17 @@ "PointND_Array", "PointNDLike_Array", "Vector2D", + "Vector2DLike", "Vector2D_Array", + "Vector2DLike_Array", "Vector3D", + "Vector3DLike", "Vector3D_Array", + "Vector3DLike_Array", "VectorND", + "VectorNDLike", "VectorND_Array", + "VectorNDLike_Array", "RowVector", "ColVector", "MatrixMN", @@ -312,29 +318,23 @@ This represents anything which can be converted to a :class:`Point2D` NumPy array. - -Normally, a function or method which expects a `Point2D` as a -parameter can handle being passed a `Point3D` instead. """ Point2D_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (M, 2)`` -A NumPy array representing a sequence of `Point2D` objects: +A NumPy array representing a sequence of :class:`Point2D` objects: ``[[float, float], ...]``. """ Point2DLike_Array: TypeAlias = Union[Point2D_Array, Sequence[Point2DLike]] """``shape: (M, 2)`` -An array of `Point2DLike` objects: ``[[float, float], ...]``. +An array of :class:`Point2DLike` objects: ``[[float, float], ...]``. This represents anything which can be converted to a :class:`Point2D_Array` NumPy array. -Normally, a function or method which expects a `Point2D_Array` as a -parameter can handle being passed a `Point3D_Array` instead. - Please refer to the documentation of the function you are using for further type information. """ @@ -357,14 +357,14 @@ Point3D_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (M, 3)`` -A NumPy array representing a sequence of `Point3D` objects: +A NumPy array representing a sequence of :class:`Point3D` objects: ``[[float, float, float], ...]``. """ Point3DLike_Array: TypeAlias = Union[Point3D_Array, Sequence[Point3DLike]] """``shape: (M, 3)`` -An array of `Point3D` objects: ``[[float, float, float], ...]``. +An array of :class:`Point3DLike` objects: ``[[float, float, float], ...]``. This represents anything which can be converted to a :class:`Point3D_Array` NumPy array. @@ -391,14 +391,14 @@ PointND_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (M, N)`` -A NumPy array representing a sequence of `PointND` objects: +A NumPy array representing a sequence of :class:`PointND` objects: ``[[float, ...], ...]``. """ PointNDLike_Array: TypeAlias = Union[PointND_Array, Sequence[PointNDLike]] """``shape: (M, N)`` -An array of `PointND` objects: ``[[float, ...], ...]``. +An array of :class:`PointNDLike` objects: ``[[float, ...], ...]``. This represents anything which can be converted to a :class:`PointND_Array` NumPy array. @@ -416,10 +416,20 @@ Vector2D: TypeAlias = npt.NDArray[PointDType] """``shape: (2,)`` +A NumPy array representing a 2-dimensional vector: ``[float, float]``. + +.. caution:: + Do not confuse with the :class:`~.Vector` or :class:`~.Arrow` + VMobjects! +""" + +Vector2DLike: TypeAlias = Union[npt.NDArray[PointDType], tuple[float, float]] +"""``shape: (2,)`` + A 2-dimensional vector: ``[float, float]``. -Normally, a function or method which expects a `Vector2D` as a -parameter can handle being passed a `Vector3D` instead. +This represents anything which can be converted to a :class:`Vector2D` NumPy +array. .. caution:: Do not confuse with the :class:`~.Vector` or :class:`~.Arrow` @@ -429,17 +439,37 @@ Vector2D_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (M, 2)`` -An array of `Vector2D` objects: ``[[float, float], ...]``. +A NumPy array representing a sequence of :class:`Vector2D` objects: +``[[float, float], ...]``. +""" + +Vector2DLike_Array: TypeAlias = Union[Vector2D_Array, Sequence[Vector2DLike]] +"""``shape: (M, 2)`` + +An array of :class:`Vector2DLike` objects: ``[[float, float], ...]``. -Normally, a function or method which expects a `Vector2D_Array` as a -parameter can handle being passed a `Vector3D_Array` instead. +This represents anything which can be converted to a :class:`Vector2D_Array` +NumPy array. """ Vector3D: TypeAlias = npt.NDArray[PointDType] """``shape: (3,)`` +A NumPy array representing a 3-dimensional vector: ``[float, float, float]``. + +.. caution:: + Do not confuse with the :class:`~.Vector` or :class:`~.Arrow3D` + VMobjects! +""" + +Vector3DLike: TypeAlias = Union[npt.NDArray[PointDType], tuple[float, float, float]] +"""``shape: (3,)`` + A 3-dimensional vector: ``[float, float, float]``. +This represents anything which can be converted to a :class:`Vector3D` NumPy +array. + .. caution:: Do not confuse with the :class:`~.Vector` or :class:`~.Arrow3D` VMobjects! @@ -448,14 +478,38 @@ Vector3D_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (M, 3)`` -An array of `Vector3D` objects: ``[[float, float, float], ...]``. +An NumPy array representing a sequence of :class:`Vector3D` objects: +``[[float, float, float], ...]``. +""" + +Vector3DLike_Array: TypeAlias = Union[npt.NDArray[PointDType], Sequence[Vector3DLike]] +"""``shape: (M, 3)`` + +An array of :class:`Vector3DLike` objects: ``[[float, float, float], ...]``. + +This represents anything which can be converted to a :class:`Vector3D_Array` +NumPy array. """ VectorND: TypeAlias = npt.NDArray[PointDType] """``shape (N,)`` +A NumPy array representing an :math:`N`-dimensional vector: ``[float, ...]``. + +.. caution:: + Do not confuse with the :class:`~.Vector` VMobject! This type alias + is named "VectorND" instead of "Vector" to avoid potential name + collisions. +""" + +VectorNDLike: TypeAlias = Union[npt.NDArray[PointDType], Sequence[float]] +"""``shape (N,)`` + An :math:`N`-dimensional vector: ``[float, ...]``. +This represents anything which can be converted to a :class:`VectorND` NumPy +array. + .. caution:: Do not confuse with the :class:`~.Vector` VMobject! This type alias is named "VectorND" instead of "Vector" to avoid potential name @@ -465,7 +519,17 @@ VectorND_Array: TypeAlias = npt.NDArray[PointDType] """``shape (M, N)`` -An array of `VectorND` objects: ``[[float, ...], ...]``. +A NumPy array representing a sequence of :class:`VectorND` objects: +``[[float, ...], ...]``. +""" + +VectorNDLike_Array: TypeAlias = Union[npt.NDArray[PointDType], Sequence[VectorNDLike]] +"""``shape (M, N)`` + +An array of :class:`VectorNDLike` objects: ``[[float, ...], ...]``. + +This represents anything which can be converted to a :class:`VectorND_Array` +NumPy array. """ RowVector: TypeAlias = npt.NDArray[PointDType] @@ -495,7 +559,7 @@ Zeros: TypeAlias = MatrixMN """``shape: (M, N)`` -A `MatrixMN` filled with zeros, typically created with +A :class:`MatrixMN` filled with zeros, typically created with ``numpy.zeros((M, N))``. """ @@ -508,7 +572,7 @@ QuadraticBezierPoints: TypeAlias = Point3D_Array """``shape: (3, 3)`` -A `Point3D_Array` of three 3D control points for a single quadratic Bézier +A :class:`Point3D_Array` of three 3D control points for a single quadratic Bézier curve: ``[[float, float, float], [float, float, float], [float, float, float]]``. """ @@ -518,7 +582,7 @@ ] """``shape: (3, 3)`` -A `Point3DLike_Array` of three 3D control points for a single quadratic Bézier +A :class:`Point3DLike_Array` of three 3D control points for a single quadratic Bézier curve: ``[[float, float, float], [float, float, float], [float, float, float]]``. @@ -529,7 +593,7 @@ QuadraticBezierPoints_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (N, 3, 3)`` -A NumPy array containing :math:`N` `QuadraticBezierPoints` objects: +A NumPy array containing :math:`N` :class:`QuadraticBezierPoints` objects: ``[[[float, float, float], [float, float, float], [float, float, float]], ...]``. """ @@ -538,7 +602,7 @@ ] """``shape: (N, 3, 3)`` -A sequence of :math:`N` `QuadraticBezierPointsLike` objects: +A sequence of :math:`N` :class:`QuadraticBezierPointsLike` objects: ``[[[float, float, float], [float, float, float], [float, float, float]], ...]``. This represents anything which can be converted to a @@ -548,7 +612,7 @@ QuadraticBezierPath: TypeAlias = Point3D_Array """``shape: (3*N, 3)`` -A `Point3D_Array` of :math:`3N` points, where each one of the +A :class:`Point3D_Array` of :math:`3N` points, where each one of the :math:`N` consecutive blocks of 3 points represents a quadratic Bézier curve: ``[[float, float, float], ...], ...]``. @@ -560,7 +624,7 @@ QuadraticBezierPathLike: TypeAlias = Point3DLike_Array """``shape: (3*N, 3)`` -A `Point3DLike_Array` of :math:`3N` points, where each one of the +A :class:`Point3DLike_Array` of :math:`3N` points, where each one of the :math:`N` consecutive blocks of 3 points represents a quadratic Bézier curve: ``[[float, float, float], ...], ...]``. @@ -575,7 +639,7 @@ QuadraticSpline: TypeAlias = QuadraticBezierPath """``shape: (3*N, 3)`` -A special case of `QuadraticBezierPath` where all the :math:`N` +A special case of :class:`QuadraticBezierPath` where all the :math:`N` quadratic Bézier curves are connected, forming a quadratic spline: ``[[float, float, float], ...], ...]``. @@ -586,7 +650,7 @@ QuadraticSplineLike: TypeAlias = QuadraticBezierPathLike """``shape: (3*N, 3)`` -A special case of `QuadraticBezierPathLike` where all the :math:`N` +A special case of :class:`QuadraticBezierPathLike` where all the :math:`N` quadratic Bézier curves are connected, forming a quadratic spline: ``[[float, float, float], ...], ...]``. @@ -600,7 +664,7 @@ CubicBezierPoints: TypeAlias = Point3D_Array """``shape: (4, 3)`` -A `Point3D_Array` of four 3D control points for a single cubic Bézier curve: +A :class:`Point3D_Array` of four 3D control points for a single cubic Bézier curve: ``[[float, float, float], [float, float, float], [float, float, float], [float, float, float]]``. """ @@ -609,7 +673,7 @@ ] """``shape: (4, 3)`` -A `Point3DLike_Array` of 4 control points for a single cubic Bézier curve: +A :class:`Point3DLike_Array` of 4 control points for a single cubic Bézier curve: ``[[float, float, float], [float, float, float], [float, float, float], [float, float, float]]``. This represents anything which can be converted to a :class:`CubicBezierPoints` @@ -619,7 +683,7 @@ CubicBezierPoints_Array: TypeAlias = npt.NDArray[PointDType] """``shape: (N, 4, 3)`` -A NumPy array containing :math:`N` `CubicBezierPoints` objects: +A NumPy array containing :math:`N` :class:`CubicBezierPoints` objects: ``[[[float, float, float], [float, float, float], [float, float, float], [float, float, float]], ...]``. """ @@ -628,7 +692,7 @@ ] """``shape: (N, 4, 3)`` -A sequence of :math:`N` `CubicBezierPointsLike` objects: +A sequence of :math:`N` :class:`CubicBezierPointsLike` objects: ``[[[float, float, float], [float, float, float], [float, float, float], [float, float, float]], ...]``. This represents anything which can be converted to a @@ -638,7 +702,7 @@ CubicBezierPath: TypeAlias = Point3D_Array """``shape: (4*N, 3)`` -A `Point3D_Array` of :math:`4N` points, where each one of the +A :class:`Point3D_Array` of :math:`4N` points, where each one of the :math:`N` consecutive blocks of 4 points represents a cubic Bézier curve: ``[[float, float, float], ...], ...]``. @@ -650,7 +714,7 @@ CubicBezierPathLike: TypeAlias = Point3DLike_Array """``shape: (4*N, 3)`` -A `Point3DLike_Array` of :math:`4N` points, where each one of the +A :class:`Point3DLike_Array` of :math:`4N` points, where each one of the :math:`N` consecutive blocks of 4 points represents a cubic Bézier curve: ``[[float, float, float], ...], ...]``. @@ -665,7 +729,7 @@ CubicSpline: TypeAlias = CubicBezierPath """``shape: (4*N, 3)`` -A special case of `CubicBezierPath` where all the :math:`N` cubic +A special case of :class:`CubicBezierPath` where all the :math:`N` cubic Bézier curves are connected, forming a quadratic spline: ``[[float, float, float], ...], ...]``. @@ -676,7 +740,7 @@ CubicSplineLike: TypeAlias = CubicBezierPathLike """``shape: (4*N, 3)`` -A special case of `CubicBezierPath` where all the :math:`N` cubic +A special case of :class:`CubicBezierPath` where all the :math:`N` cubic Bézier curves are connected, forming a quadratic spline: ``[[float, float, float], ...], ...]``. @@ -690,7 +754,7 @@ BezierPoints: TypeAlias = Point3D_Array r"""``shape: (PPC, 3)`` -A `Point3D_Array` of :math:`\text{PPC}` control points +A :class:`Point3D_Array` of :math:`\text{PPC}` control points (:math:`\text{PPC: Points Per Curve} = n + 1`) for a single :math:`n`-th degree Bézier curve: ``[[float, float, float], ...]``. @@ -702,7 +766,7 @@ BezierPointsLike: TypeAlias = Point3DLike_Array r"""``shape: (PPC, 3)`` -A `Point3DLike_Array` of :math:`\text{PPC}` control points +A :class:`Point3DLike_Array` of :math:`\text{PPC}` control points (:math:`\text{PPC: Points Per Curve} = n + 1`) for a single :math:`n`-th degree Bézier curve: ``[[float, float, float], ...]``. @@ -717,8 +781,8 @@ BezierPoints_Array: TypeAlias = npt.NDArray[PointDType] r"""``shape: (N, PPC, 3)`` -A NumPy array of :math:`N` `BezierPoints` objects containing -:math:`\text{PPC}` `Point3D` objects each +A NumPy array of :math:`N` :class:`BezierPoints` objects containing +:math:`\text{PPC}` :class:`Point3D` objects each (:math:`\text{PPC: Points Per Curve} = n + 1`): ``[[[float, float, float], ...], ...]``. @@ -731,8 +795,8 @@ ] r"""``shape: (N, PPC, 3)`` -A sequence of :math:`N` `BezierPointsLike` objects containing -:math:`\text{PPC}` `Point3DLike` objects each +A sequence of :math:`N` :class:`BezierPointsLike` objects containing +:math:`\text{PPC}` :class:`Point3DLike` objects each (:math:`\text{PPC: Points Per Curve} = n + 1`): ``[[[float, float, float], ...], ...]``. @@ -746,7 +810,7 @@ BezierPath: TypeAlias = Point3D_Array r"""``shape: (PPC*N, 3)`` -A `Point3D_Array` of :math:`\text{PPC} \cdot N` points, where each +A :class:`Point3D_Array` of :math:`\text{PPC} \cdot N` points, where each one of the :math:`N` consecutive blocks of :math:`\text{PPC}` control points (:math:`\text{PPC: Points Per Curve} = n + 1`) represents a Bézier curve of :math:`n`-th degree: @@ -759,7 +823,7 @@ BezierPathLike: TypeAlias = Point3DLike_Array r"""``shape: (PPC*N, 3)`` -A `Point3DLike_Array` of :math:`\text{PPC} \cdot N` points, where each +A :class:`Point3DLike_Array` of :math:`\text{PPC} \cdot N` points, where each one of the :math:`N` consecutive blocks of :math:`\text{PPC}` control points (:math:`\text{PPC: Points Per Curve} = n + 1`) represents a Bézier curve of :math:`n`-th degree: @@ -775,8 +839,8 @@ Spline: TypeAlias = BezierPath r"""``shape: (PPC*N, 3)`` -A special case of `BezierPath` where all the :math:`N` Bézier curves -consisting of :math:`\text{PPC}` `Point3D` objects +A special case of :class:`BezierPath` where all the :math:`N` Bézier curves +consisting of :math:`\text{PPC}` :class:`Point3D` objects (:math:`\text{PPC: Points Per Curve} = n + 1`) are connected, forming an :math:`n`-th degree spline: ``[[float, float, float], ...], ...]``. @@ -788,8 +852,8 @@ SplineLike: TypeAlias = BezierPathLike r"""``shape: (PPC*N, 3)`` -A special case of `BezierPathLike` where all the :math:`N` Bézier curves -consisting of :math:`\text{PPC}` `Point3D` objects +A special case of :class:`BezierPathLike` where all the :math:`N` Bézier curves +consisting of :math:`\text{PPC}` :class:`Point3D` objects (:math:`\text{PPC: Points Per Curve} = n + 1`) are connected, forming an :math:`n`-th degree spline: ``[[float, float, float], ...], ...]``. @@ -851,29 +915,29 @@ Every value in the array is an integer from 0 to 255. Every pixel is represented either by a single integer indicating its -lightness (for greyscale images), an `RGB_Array_Int` or an +lightness (for greyscale images), an :class:`RGB_Array_Int` or an `RGBA_Array_Int`. """ GrayscalePixelArray: TypeAlias = PixelArray """``shape: (height, width)`` -A 100% opaque grayscale `PixelArray`, where every pixel value is a +A 100% opaque grayscale :class:`PixelArray`, where every pixel value is a `ManimInt` indicating its lightness (black -> gray -> white). """ RGBPixelArray: TypeAlias = PixelArray """``shape: (height, width, 3)`` -A 100% opaque `PixelArray` in color, where every pixel value is an +A 100% opaque :class:`PixelArray` in color, where every pixel value is an `RGB_Array_Int` object. """ RGBAPixelArray: TypeAlias = PixelArray """``shape: (height, width, 4)`` -A `PixelArray` in color where pixels can be transparent. Every pixel -value is an `RGBA_Array_Int` object. +A :class:`PixelArray` in color where pixels can be transparent. Every pixel +value is an :class:`RGBA_Array_Int` object. """ diff --git a/manim/utils/bezier.py b/manim/utils/bezier.py index 28958309a5..b50c7e236c 100644 --- a/manim/utils/bezier.py +++ b/manim/utils/bezier.py @@ -20,9 +20,9 @@ ] -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import reduce -from typing import TYPE_CHECKING, Callable, overload +from typing import TYPE_CHECKING, overload import numpy as np diff --git a/manim/utils/caching.py b/manim/utils/caching.py index 177bb11c5c..0c339c914f 100644 --- a/manim/utils/caching.py +++ b/manim/utils/caching.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from .. import config, logger from ..utils.hashing import get_hash_from_play_call @@ -8,8 +9,6 @@ __all__ = ["handle_caching_play"] if TYPE_CHECKING: - from typing import Any - from manim.renderer.opengl_renderer import OpenGLRenderer from manim.scene.scene import Scene diff --git a/manim/utils/color/core.py b/manim/utils/color/core.py index 82f0b8cf46..af25992e59 100644 --- a/manim/utils/color/core.py +++ b/manim/utils/color/core.py @@ -1501,17 +1501,124 @@ def random_bright_color() -> ManimColor: def random_color() -> ManimColor: """Return a random :class:`ManimColor`. - .. warning:: - This operation is very expensive. Please keep in mind the performance loss. - Returns ------- ManimColor A random :class:`ManimColor`. """ - import manim.utils.color.manim_colors as manim_colors + return RandomColorGenerator._random_color() + + +class RandomColorGenerator: + _singleton: RandomColorGenerator | None = None + """A generator for producing random colors from a given list of Manim colors, + optionally in a reproducible sequence using a seed value. + + When initialized with a specific seed, this class produces a deterministic + sequence of :class:`.ManimColor` instances. If no seed is provided, the selection is + non-deterministic using Python’s global random state. + + Parameters + ---------- + seed + A seed value to initialize the internal random number generator. + If ``None`` (the default), colors are chosen using the global random state. + + sample_colors + A custom list of Manim colors to sample from. Defaults to the full Manim + color palette. + + Examples + -------- + Without a seed (non-deterministic):: + + >>> from manim import RandomColorGenerator, ManimColor, RED, GREEN, BLUE + >>> rnd = RandomColorGenerator() + >>> isinstance(rnd.next(), ManimColor) + True + + With a seed (deterministic sequence):: + + >>> rnd = RandomColorGenerator(42) + >>> rnd.next() + ManimColor('#ECE7E2') + >>> rnd.next() + ManimColor('#BBBBBB') + >>> rnd.next() + ManimColor('#BBBBBB') + + Re-initializing with the same seed gives the same sequence:: + + >>> rnd2 = RandomColorGenerator(42) + >>> rnd2.next() + ManimColor('#ECE7E2') + >>> rnd2.next() + ManimColor('#BBBBBB') + >>> rnd2.next() + ManimColor('#BBBBBB') + + Using a custom color list:: + + >>> custom_palette = [RED, GREEN, BLUE] + >>> rnd_custom = RandomColorGenerator(1, sample_colors=custom_palette) + >>> rnd_custom.next() in custom_palette + True + >>> rnd_custom.next() in custom_palette + True + + Without a seed and custom palette (non-deterministic):: + + >>> rnd_nodet = RandomColorGenerator(sample_colors=[RED]) + >>> rnd_nodet.next() + ManimColor('#FC6255') + """ + + def __init__( + self, + seed: int | None = None, + sample_colors: list[ManimColor] | None = None, + ) -> None: + self.choice = random.choice if seed is None else random.Random(seed).choice + + from manim.utils.color.manim_colors import _all_manim_colors - return random.choice(manim_colors._all_manim_colors) + self.colors = _all_manim_colors if sample_colors is None else sample_colors + + def next(self) -> ManimColor: + """Returns the next color from the configured color list. + + Returns + ------- + ManimColor + A randomly selected color from the specified color list. + + Examples + -------- + Usage:: + + >>> from manim import RandomColorGenerator, RED + >>> rnd = RandomColorGenerator(sample_colors=[RED]) + >>> rnd.next() + ManimColor('#FC6255') + """ + return self.choice(self.colors) + + @classmethod + def _random_color(cls) -> ManimColor: + """Internal method to generate a random color using the singleton instance of + `RandomColorGenerator`. + It will be used by proxy method `random_color` publicly available + and makes it backwards compatible. + + Returns + ------- + ManimColor: + A randomly selected color from the configured color list of + the singleton instance. + """ + if cls._singleton is None: + cls._singleton = cls() + return cls._singleton.next() def get_shaded_rgb( @@ -1567,6 +1674,7 @@ def get_shaded_rgb( "average_color", "random_bright_color", "random_color", + "RandomColorGenerator", "get_shaded_rgb", "HSV", "RGBA", diff --git a/manim/utils/deprecation.py b/manim/utils/deprecation.py index d8bb9fb97c..b7a5febffa 100644 --- a/manim/utils/deprecation.py +++ b/manim/utils/deprecation.py @@ -8,8 +8,8 @@ import inspect import logging import re -from collections.abc import Iterable -from typing import Any, Callable, TypeVar, overload +from collections.abc import Callable, Iterable +from typing import Any, TypeVar, overload from decorator import decorate, decorator @@ -253,7 +253,7 @@ def deprecate(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: # The following line raises this mypy error: # Accessing "__init__" on an instance is unsound, since instance.__init__ # could be from an incompatible subclass [misc] - func.__init__ = decorate(func.__init__, deprecate) # type: ignore[misc] + func.__init__ = decorate(func.__init__, deprecate) # type: ignore[method-assign] return func func = decorate(func, deprecate) diff --git a/manim/utils/docbuild/autocolor_directive.py b/manim/utils/docbuild/autocolor_directive.py index 476dcb1326..108d97ef11 100644 --- a/manim/utils/docbuild/autocolor_directive.py +++ b/manim/utils/docbuild/autocolor_directive.py @@ -79,15 +79,18 @@ def run(self) -> list[nodes.Element]: for base_i in range(0, len(color_elements), num_color_cols): row = nodes.row() - for member_name, hex_code, font_color in color_elements[ - base_i : base_i + num_color_cols - ]: - col1 = nodes.literal(text=member_name) - col2 = nodes.raw( - "", - f'
{hex_code}
', - format="html", - ) + for idx in range(base_i, base_i + num_color_cols): + if idx < len(color_elements): + member_name, hex_code, font_color = color_elements[idx] + col1 = nodes.literal(text=member_name) + col2 = nodes.raw( + "", + f'
{hex_code}
', + format="html", + ) + else: + col1 = nodes.literal(text="") + col2 = nodes.raw("", "", format="html") row += nodes.entry("", col1) row += nodes.entry("", col2) tbody += row diff --git a/manim/utils/docbuild/manim_directive.py b/manim/utils/docbuild/manim_directive.py index b94b7386c9..bf4fb554ce 100644 --- a/manim/utils/docbuild/manim_directive.py +++ b/manim/utils/docbuild/manim_directive.py @@ -268,11 +268,11 @@ def run(self) -> list[nodes.Element]: ] source_block = "\n".join(source_block_in) - config.media_dir = (Path(setup.confdir) / "media").absolute() # type: ignore[attr-defined,assignment] + config.media_dir = (Path(setup.confdir) / "media").absolute() # type: ignore[attr-defined] config.images_dir = "{media_dir}/images" config.video_dir = "{media_dir}/videos/{quality}" output_file = f"{clsname}-{classnamedict[clsname]}" - config.assets_dir = Path("_static") # type: ignore[assignment] + config.assets_dir = Path("_static") config.progress_bar = "none" config.verbosity = "WARNING" @@ -399,7 +399,11 @@ def _delete_rendering_times(*args: tuple[Any]) -> None: def setup(app: Sphinx) -> SetupMetadata: - app.add_node(SkipManimNode, html=(visit, depart)) + app.add_node( + SkipManimNode, + html=(visit, depart), + latex=(lambda a, b: None, lambda a, b: None), + ) setup.app = app # type: ignore[attr-defined] setup.config = app.config # type: ignore[attr-defined] diff --git a/manim/utils/docbuild/module_parsing.py b/manim/utils/docbuild/module_parsing.py index 78769eb565..6e5a6bb2bf 100644 --- a/manim/utils/docbuild/module_parsing.py +++ b/manim/utils/docbuild/module_parsing.py @@ -177,8 +177,8 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict, TypeVarDict]: # TODO: ast.TypeAlias does not exist before Python 3.12, and that # could be the reason why MyPy does not recognize these as # attributes of node. - alias_name = node.name.id if is_type_alias else node.target.id # type: ignore[attr-defined] - definition_node = node.value # type: ignore[attr-defined] + alias_name = node.name.id if is_type_alias else node.target.id + definition_node = node.value # If the definition is a Union, replace with vertical bar notation. # Instead of "Union[Type1, Type2]", we'll have "Type1 | Type2". diff --git a/manim/utils/hashing.py b/manim/utils/hashing.py index ca5d840cf6..be680aef61 100644 --- a/manim/utils/hashing.py +++ b/manim/utils/hashing.py @@ -2,25 +2,24 @@ from __future__ import annotations -import collections import copy import inspect import json -import typing import zlib +from collections.abc import Callable, Hashable, Iterable from time import perf_counter from types import FunctionType, MappingProxyType, MethodType, ModuleType -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from manim._config import config, logger -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from manim.animation.animation import Animation from manim.camera.camera import Camera from manim.mobject.mobject import Mobject - from manim.opengl.opengl_renderer import OpenGLCamera + from manim.renderer.opengl_renderer import OpenGLCamera from manim.scene.scene import Scene __all__ = ["KEYS_TO_FILTER_OUT", "get_hash_from_play_call", "get_json"] @@ -117,7 +116,7 @@ def mark_as_processed(cls, obj: Any) -> None: def _handle_already_processed( cls, obj, - default_function: typing.Callable[[Any], Any], + default_function: Callable[[Any], Any], ): if isinstance( obj, @@ -131,7 +130,7 @@ def _handle_already_processed( # It makes no sense (and it'd slower) to memoize objects of these primitive # types. Hence, we simply return the object. return obj - if isinstance(obj, collections.abc.Hashable): + if isinstance(obj, Hashable): try: return cls._return(obj, hash, default_function) except TypeError: @@ -144,8 +143,8 @@ def _handle_already_processed( @classmethod def _return( cls, - obj: typing.Any, - obj_to_membership_sign: typing.Callable[[Any], int], + obj: Any, + obj_to_membership_sign: Callable[[Any], int], default_func, memoizing=True, ) -> str | Any: @@ -234,7 +233,7 @@ def default(self, obj: Any): # Serialize it with only the type of the object. You can change this to whatever string when debugging the serialization process. return str(type(obj)) - def _cleaned_iterable(self, iterable: typing.Iterable[Any]): + def _cleaned_iterable(self, iterable: Iterable[Any]): """Check for circular reference at each iterable that will go through the JSONEncoder, as well as key of the wrong format. If a key with a bad format is found (i.e not a int, string, or float), it gets replaced byt its hash using the same process implemented here. @@ -325,8 +324,8 @@ def get_json(obj: dict): def get_hash_from_play_call( scene_object: Scene, camera_object: Camera | OpenGLCamera, - animations_list: typing.Iterable[Animation], - current_mobjects_list: typing.Iterable[Mobject], + animations_list: Iterable[Animation], + current_mobjects_list: Iterable[Mobject], ) -> str: """Take the list of animations and a list of mobjects and output their hashes. This is meant to be used for `scene.play` function. diff --git a/manim/utils/ipython_magic.py b/manim/utils/ipython_magic.py index ce6c93d552..1d62bbc6f4 100644 --- a/manim/utils/ipython_magic.py +++ b/manim/utils/ipython_magic.py @@ -129,7 +129,7 @@ def construct(self): args = main(modified_args, standalone_mode=False, prog_name="manim") assert isinstance(local_ns, dict) with tempconfig(local_ns.get("config", {})): - config.digest_args(args) # type: ignore[arg-type] + config.digest_args(args) renderer = None if config.renderer == RendererType.OPENGL: diff --git a/manim/utils/iterables.py b/manim/utils/iterables.py index 678750deb2..165037a824 100644 --- a/manim/utils/iterables.py +++ b/manim/utils/iterables.py @@ -20,6 +20,7 @@ import itertools as it from collections.abc import ( + Callable, Collection, Generator, Hashable, @@ -27,7 +28,7 @@ Reversible, Sequence, ) -from typing import TYPE_CHECKING, Callable, TypeVar, overload +from typing import TYPE_CHECKING, TypeVar, overload import numpy as np diff --git a/manim/utils/module_ops.py b/manim/utils/module_ops.py index 1b03e374f4..e4d921d8b3 100644 --- a/manim/utils/module_ops.py +++ b/manim/utils/module_ops.py @@ -7,7 +7,7 @@ import types import warnings from pathlib import Path -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload from manim._config import config, console, logger from manim.constants import ( @@ -19,8 +19,6 @@ from manim.scene.scene_file_writer import SceneFileWriter if TYPE_CHECKING: - from typing import Any - from manim.scene.scene import Scene __all__ = ["scene_classes_from_file"] diff --git a/manim/utils/opengl.py b/manim/utils/opengl.py index 877cbc2e8f..0cb8d6c867 100644 --- a/manim/utils/opengl.py +++ b/manim/utils/opengl.py @@ -15,12 +15,6 @@ from manim.typing import MatrixMN, Point3D -if TYPE_CHECKING: - from typing_extensions import TypeAlias - - from manim.typing import MatrixMN - - depth = 20 __all__ = [ diff --git a/manim/utils/paths.py b/manim/utils/paths.py index aa39732888..5f073ed428 100644 --- a/manim/utils/paths.py +++ b/manim/utils/paths.py @@ -16,10 +16,15 @@ from ..constants import OUT from ..utils.bezier import interpolate -from ..utils.space_ops import rotation_matrix +from ..utils.space_ops import normalize, rotation_matrix if TYPE_CHECKING: - from manim.typing import PathFuncType, Point3D_Array, Vector3D + from manim.typing import ( + PathFuncType, + Point3D_Array, + Point3DLike_Array, + Vector3DLike, + ) STRAIGHT_PATH_THRESHOLD = 0.01 @@ -72,7 +77,7 @@ def construct(self): def path_along_circles( - arc_angle: float, circles_centers: np.ndarray, axis: Vector3D = OUT + arc_angle: float, circles_centers: Point3DLike_Array, axis: Vector3DLike = OUT ) -> PathFuncType: """This function transforms each point by moving it roughly along a circle, each with its own specified center. @@ -132,9 +137,7 @@ def construct(self): self.wait() """ - if np.linalg.norm(axis) == 0: - axis = OUT - unit_axis = axis / np.linalg.norm(axis) + unit_axis = normalize(axis, fall_back=OUT) def path( start_points: Point3D_Array, end_points: Point3D_Array, alpha: float @@ -152,7 +155,7 @@ def path( return path -def path_along_arc(arc_angle: float, axis: Vector3D = OUT) -> PathFuncType: +def path_along_arc(arc_angle: float, axis: Vector3DLike = OUT) -> PathFuncType: """This function transforms each point by moving it along a circular arc. Parameters @@ -204,9 +207,7 @@ def construct(self): """ if abs(arc_angle) < STRAIGHT_PATH_THRESHOLD: return straight_path() - if np.linalg.norm(axis) == 0: - axis = OUT - unit_axis = axis / np.linalg.norm(axis) + unit_axis = normalize(axis, fall_back=OUT) def path( start_points: Point3D_Array, end_points: Point3D_Array, alpha: float @@ -313,7 +314,7 @@ def construct(self): return path_along_arc(np.pi) -def spiral_path(angle: float, axis: Vector3D = OUT) -> PathFuncType: +def spiral_path(angle: float, axis: Vector3DLike = OUT) -> PathFuncType: """This function transforms each point by moving along a spiral to its destination. Parameters @@ -365,9 +366,7 @@ def construct(self): """ if abs(angle) < STRAIGHT_PATH_THRESHOLD: return straight_path() - if np.linalg.norm(axis) == 0: - axis = OUT - unit_axis = axis / np.linalg.norm(axis) + unit_axis = normalize(axis, fall_back=OUT) def path( start_points: Point3D_Array, end_points: Point3D_Array, alpha: float diff --git a/manim/utils/simple_functions.py b/manim/utils/simple_functions.py index 3735960654..792b7764d6 100644 --- a/manim/utils/simple_functions.py +++ b/manim/utils/simple_functions.py @@ -10,8 +10,9 @@ ] +from collections.abc import Callable from functools import lru_cache -from typing import Any, Callable, Protocol, TypeVar +from typing import Any, Protocol, TypeVar import numpy as np from scipy import special diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index 963c0811ee..2bbecaa499 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -3,8 +3,8 @@ from __future__ import annotations import itertools as it -from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING import numpy as np from mapbox_earcut import triangulate_float32 as earcut @@ -28,7 +28,8 @@ Vector2D, Vector2D_Array, Vector3D, - Vector3D_Array, + Vector3DLike, + Vector3DLike_Array, ) __all__ = [ @@ -70,7 +71,7 @@ def norm_squared(v: float) -> float: return val -def cross(v1: Vector3D, v2: Vector3D) -> Vector3D: +def cross(v1: Vector3DLike, v2: Vector3DLike) -> Vector3D: return np.array( [ v1[1] * v2[2] - v1[2] * v2[1], @@ -178,8 +179,8 @@ def quaternion_conjugate(quaternion: Sequence[float]) -> np.ndarray: def rotate_vector( - vector: np.ndarray, angle: float, axis: np.ndarray = OUT -) -> np.ndarray: + vector: Vector3DLike, angle: float, axis: Vector3DLike = OUT +) -> Vector3D: """Function for rotating a vector. Parameters @@ -245,7 +246,7 @@ def rotation_matrix_from_quaternion(quat: np.ndarray) -> np.ndarray: return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) -def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray: +def rotation_matrix_transpose(angle: float, axis: Vector3DLike) -> np.ndarray: if all(np.array(axis)[:2] == np.zeros(2)): return rotation_about_z(angle * np.sign(axis[2])).T return rotation_matrix(angle, axis).T @@ -253,12 +254,12 @@ def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray: def rotation_matrix( angle: float, - axis: np.ndarray, + axis: Vector3DLike, homogeneous: bool = False, ) -> np.ndarray: """Rotation in R^3 about a specified axis of rotation.""" inhomogeneous_rotation_matrix = Rotation.from_rotvec( - angle * normalize(np.array(axis)) + angle * normalize(axis) ).as_matrix() if not homogeneous: return inhomogeneous_rotation_matrix @@ -388,7 +389,7 @@ def normalize_along_axis(array: np.ndarray, axis: np.ndarray) -> np.ndarray: return array -def get_unit_normal(v1: Vector3D, v2: Vector3D, tol: float = 1e-6) -> Vector3D: +def get_unit_normal(v1: Vector3DLike, v2: Vector3DLike, tol: float = 1e-6) -> Vector3D: """Gets the unit normal of the vectors. Parameters @@ -405,18 +406,21 @@ def get_unit_normal(v1: Vector3D, v2: Vector3D, tol: float = 1e-6) -> Vector3D: np.ndarray The normal of the two vectors. """ + np_v1 = np.asarray(v1) + np_v2 = np.asarray(v2) + # Instead of normalizing v1 and v2, just divide by the greatest # of all their absolute components, which is just enough - div1, div2 = max(np.abs(v1)), max(np.abs(v2)) + div1, div2 = max(np.abs(np_v1)), max(np.abs(np_v2)) if div1 == 0.0: if div2 == 0.0: return DOWN - u = v2 / div2 + u = np_v2 / div2 elif div2 == 0.0: - u = v1 / div1 + u = np_v1 / div1 else: # Normal scenario: v1 and v2 are both non-null - u1, u2 = v1 / div1, v2 / div2 + u1, u2 = np_v1 / div1, np_v2 / div2 cp = cross(u1, u2) cp_norm = np.sqrt(norm_squared(cp)) if cp_norm > tol: @@ -590,9 +594,9 @@ def line_intersection( def find_intersection( p0s: Point3DLike_Array, - v0s: Vector3D_Array, + v0s: Vector3DLike_Array, p1s: Point3DLike_Array, - v1s: Vector3D_Array, + v1s: Vector3DLike_Array, threshold: float = 1e-5, ) -> list[Point3D]: """ @@ -605,7 +609,7 @@ def find_intersection( # algorithm from https://en.wikipedia.org/wiki/Skew_lines#Nearest_points result = [] - for p0, v0, p1, v1 in zip(*[p0s, v0s, p1s, v1s]): + for p0, v0, p1, v1 in zip(p0s, v0s, p1s, v1s): normal = cross(v1, cross(v0, v1)) denom = max(np.dot(v0, normal), threshold) result += [p0 + np.dot(p1 - p0, normal) / denom * v0] diff --git a/manim/utils/testing/_test_class_makers.py b/manim/utils/testing/_test_class_makers.py index 084aab487b..b7b53306d3 100644 --- a/manim/utils/testing/_test_class_makers.py +++ b/manim/utils/testing/_test_class_makers.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from manim.renderer.cairo_renderer import CairoRenderer from manim.renderer.opengl_renderer import OpenGLRenderer @@ -46,16 +47,16 @@ class DummySceneFileWriter(SceneFileWriter): def __init__( self, renderer: CairoRenderer | OpenGLRenderer, - scene_name: StrPath, + scene_name: str, **kwargs: Any, ) -> None: super().__init__(renderer, scene_name, **kwargs) self.i = 0 - def init_output_directories(self, scene_name: StrPath) -> None: + def init_output_directories(self, scene_name: str) -> None: pass - def add_partial_movie_file(self, hash_animation: str) -> None: + def add_partial_movie_file(self, hash_animation: str | None) -> None: pass def begin_animation( diff --git a/manim/utils/testing/frames_comparison.py b/manim/utils/testing/frames_comparison.py index 7b5f23f8c8..01061d6860 100644 --- a/manim/utils/testing/frames_comparison.py +++ b/manim/utils/testing/frames_comparison.py @@ -2,8 +2,9 @@ import functools import inspect +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any import cairo import pytest diff --git a/manim/utils/tex_file_writing.py b/manim/utils/tex_file_writing.py index 8c2c9c8c00..4c61bf4fa5 100644 --- a/manim/utils/tex_file_writing.py +++ b/manim/utils/tex_file_writing.py @@ -288,7 +288,7 @@ def print_all_tex_errors(log_file: Path, tex_compiler: str, tex_file: Path) -> N index for index, line in enumerate(tex_compilation_log) if line.startswith("!") ] if error_indices: - with tex_file.open() as f: + with tex_file.open(encoding="utf-8") as f: tex = f.readlines() for error_index in error_indices: print_tex_error(tex_compilation_log, error_index, tex) diff --git a/mypy.ini b/mypy.ini index 4b5f718509..d0dcd812d6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -48,58 +48,107 @@ warn_return_any = True # # disable_recursive_aliases = True -[mypy-manim._config.*] +[mypy-manim._config.utils] ignore_errors = True -disable_error_code = return-value -[mypy-manim._config.logger_utils] -ignore_errors = False +[mypy-manim.animation.animation] +ignore_errors = True -[mypy-manim.animation.*] +[mypy-manim.animation.creation] ignore_errors = True -[mypy-manim.camera.*] +[mypy-manim.animation.rotation] ignore_errors = True -[mypy-manim.cli.*] -ignore_errors = False +[mypy-manim.animation.speedmodifier] +ignore_errors = True -[mypy-manim.cli.cfg.*] -ignore_errors = False +[mypy-manim.animation.transform_matching_parts] +ignore_errors = True -[mypy-manim.gui.*] +[mypy-manim.animation.transform] ignore_errors = True -[mypy-manim.mobject.*] +[mypy-manim.animation.updaters.mobject_update_utils] ignore_errors = True -[mypy-manim.mobject.text.code_mobject] -ignore_errors = False +[mypy-manim.camera.mapping_camera] +ignore_errors = True -[mypy-manim.mobject.geometry.*] +[mypy-manim.camera.moving_camera] ignore_errors = True -[mypy-manim.renderer.*] +[mypy-manim.mobject.graphing.coordinate_systems] ignore_errors = True -[mypy-manim.scene.*] +[mypy-manim.mobject.graph] ignore_errors = True -[mypy-manim.utils.hashing.*] +[mypy-manim.mobject.logo] ignore_errors = True -[mypy-manim.utils.color.*] -ignore_errors = False +[mypy-manim.mobject.mobject] +ignore_errors = True -[mypy-manim.utils.iterables] -warn_return_any = False +[mypy-manim.mobject.opengl.opengl_compatibility] +ignore_errors = True +[mypy-manim.mobject.opengl.opengl_geometry] +ignore_errors = True -# ---------------- We can't properly type this ------------------------ +[mypy-manim.mobject.opengl.opengl_image_mobject] +ignore_errors = True + +[mypy-manim.mobject.opengl.opengl_mobject] +ignore_errors = True + +[mypy-manim.mobject.opengl.opengl_point_cloud_mobject] +ignore_errors = True + +[mypy-manim.mobject.opengl.opengl_surface] +ignore_errors = True + +[mypy-manim.mobject.opengl.opengl_vectorized_mobject] +ignore_errors = True + +[mypy-manim.mobject.table] +ignore_errors = True -[mypy-manim.grpc.*] +[mypy-manim.mobject.text.text_mobject] ignore_errors = True +[mypy-manim.mobject.three_d.three_dimensions] +ignore_errors = True + +[mypy-manim.mobject.types.image_mobject] +ignore_errors = True + +[mypy-manim.mobject.types.point_cloud_mobject] +ignore_errors = True + +[mypy-manim.mobject.types.vectorized_mobject] +ignore_errors = True + +[mypy-manim.mobject.vector_field] +ignore_errors = True + +[mypy-manim.renderer.cairo_renderer] +ignore_errors = True + +[mypy-manim.renderer.opengl_renderer] +ignore_errors = True + +[mypy-manim.renderer.shader_wrapper] +ignore_errors = True + +[mypy-manim.scene.three_d_scene] +ignore_errors = True + +[mypy-manim.utils.hashing] +ignore_errors = True + + + # ---------------- Stubless imported Modules -------------------------- # We should be able to create stubs for this or type hint it diff --git a/tests/module/animation/test_transform.py b/tests/module/animation/test_transform.py new file mode 100644 index 0000000000..d1b9c43aef --- /dev/null +++ b/tests/module/animation/test_transform.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from manim import Circle, ReplacementTransform, Scene, Square, VGroup + + +def test_no_duplicate_references(): + scene = Scene() + c = Circle() + sq = Square() + scene.add(c, sq) + + scene.play(ReplacementTransform(c, sq)) + assert len(scene.mobjects) == 1 + assert scene.mobjects[0] is sq + + +def test_duplicate_references_in_group(): + scene = Scene() + c = Circle() + sq = Square() + vg = VGroup(c, sq) + scene.add(vg) + + scene.play(ReplacementTransform(c, sq)) + submobs = vg.submobjects + assert len(submobs) == 1 + assert submobs[0] is sq diff --git a/tests/module/mobject/test_value_tracker.py b/tests/module/mobject/test_value_tracker.py index 824902a9b0..e690d131f3 100644 --- a/tests/module/mobject/test_value_tracker.py +++ b/tests/module/mobject/test_value_tracker.py @@ -39,6 +39,13 @@ def test_value_tracker_bool(): assert tracker +def test_value_tracker_add(): + """Test ValueTracker.__add__()""" + tracker = ValueTracker(0.0) + tracker2 = tracker + 10.0 + assert tracker2.get_value() == 10.0 + + def test_value_tracker_iadd(): """Test ValueTracker.__iadd__()""" tracker = ValueTracker(0.0) @@ -46,6 +53,13 @@ def test_value_tracker_iadd(): assert tracker.get_value() == 10.0 +def test_value_tracker_floordiv(): + """Test ValueTracker.__floordiv__()""" + tracker = ValueTracker(5.0) + tracker2 = tracker // 2.0 + assert tracker2.get_value() == 2.0 + + def test_value_tracker_ifloordiv(): """Test ValueTracker.__ifloordiv__()""" tracker = ValueTracker(5.0) @@ -53,6 +67,13 @@ def test_value_tracker_ifloordiv(): assert tracker.get_value() == 2.0 +def test_value_tracker_mod(): + """Test ValueTracker.__mod__()""" + tracker = ValueTracker(20.0) + tracker2 = tracker % 3.0 + assert tracker2.get_value() == 2.0 + + def test_value_tracker_imod(): """Test ValueTracker.__imod__()""" tracker = ValueTracker(20.0) @@ -60,6 +81,13 @@ def test_value_tracker_imod(): assert tracker.get_value() == 2.0 +def test_value_tracker_mul(): + """Test ValueTracker.__mul__()""" + tracker = ValueTracker(3.0) + tracker2 = tracker * 4.0 + assert tracker2.get_value() == 12.0 + + def test_value_tracker_imul(): """Test ValueTracker.__imul__()""" tracker = ValueTracker(3.0) @@ -67,6 +95,13 @@ def test_value_tracker_imul(): assert tracker.get_value() == 12.0 +def test_value_tracker_pow(): + """Test ValueTracker.__pow__()""" + tracker = ValueTracker(3.0) + tracker2 = tracker**3.0 + assert tracker2.get_value() == 27.0 + + def test_value_tracker_ipow(): """Test ValueTracker.__ipow__()""" tracker = ValueTracker(3.0) @@ -74,6 +109,13 @@ def test_value_tracker_ipow(): assert tracker.get_value() == 27.0 +def test_value_tracker_sub(): + """Test ValueTracker.__sub__()""" + tracker = ValueTracker(20.0) + tracker2 = tracker - 10.0 + assert tracker2.get_value() == 10.0 + + def test_value_tracker_isub(): """Test ValueTracker.__isub__()""" tracker = ValueTracker(20.0) @@ -81,6 +123,13 @@ def test_value_tracker_isub(): assert tracker.get_value() == 10.0 +def test_value_tracker_truediv(): + """Test ValueTracker.__truediv__()""" + tracker = ValueTracker(5.0) + tracker2 = tracker / 2.0 + assert tracker2.get_value() == 2.5 + + def test_value_tracker_itruediv(): """Test ValueTracker.__itruediv__()""" tracker = ValueTracker(5.0) diff --git a/tests/test_graphical_units/control_data/geometry/negative_z_index_AnimationGroup.npz b/tests/test_graphical_units/control_data/geometry/negative_z_index_AnimationGroup.npz new file mode 100644 index 0000000000..80ba71a0eb Binary files /dev/null and b/tests/test_graphical_units/control_data/geometry/negative_z_index_AnimationGroup.npz differ diff --git a/tests/test_graphical_units/control_data/geometry/negative_z_index_LaggedStart.npz b/tests/test_graphical_units/control_data/geometry/negative_z_index_LaggedStart.npz new file mode 100644 index 0000000000..13aba250f3 Binary files /dev/null and b/tests/test_graphical_units/control_data/geometry/negative_z_index_LaggedStart.npz differ diff --git a/tests/test_graphical_units/test_geometry.py b/tests/test_graphical_units/test_geometry.py index fef2ca0951..7bc65561cb 100644 --- a/tests/test_graphical_units/test_geometry.py +++ b/tests/test_graphical_units/test_geometry.py @@ -174,6 +174,21 @@ def test_ZIndex(scene): scene.play(ApplyMethod(triangle.shift, 2 * UP)) +@frames_comparison(last_frame=False) +def test_negative_z_index_AnimationGroup(scene): + # https://github.com/ManimCommunity/manim/issues/3334 + s = Square().set_z_index(-1) + scene.play(AnimationGroup(GrowFromCenter(s))) + + +@frames_comparison(last_frame=False) +def test_negative_z_index_LaggedStart(scene): + # https://github.com/ManimCommunity/manim/issues/3914 + background = Rectangle(z_index=-1) + line = Line(2 * LEFT, 2 * RIGHT, color=RED_D, z_index=-1) + scene.play(LaggedStart(FadeIn(background), FadeIn(line), lag_ratio=0.5)) + + @frames_comparison def test_Angle(scene): l1 = Line(ORIGIN, RIGHT) diff --git a/tests/test_graphical_units/test_indication.py b/tests/test_graphical_units/test_indication.py index 592cf5fb38..a6e66ac9cc 100644 --- a/tests/test_graphical_units/test_indication.py +++ b/tests/test_graphical_units/test_indication.py @@ -63,5 +63,5 @@ def test_Wiggle_custom_about_points(): scale_about_point=[1.0, 2.0, 3.0], rotate_about_point=[4.0, 5.0, 6.0], ) - assert wiggle.get_scale_about_point() == [1.0, 2.0, 3.0] - assert wiggle.get_rotate_about_point() == [4.0, 5.0, 6.0] + assert np.all(wiggle.get_scale_about_point() == [1.0, 2.0, 3.0]) + assert np.all(wiggle.get_rotate_about_point() == [4.0, 5.0, 6.0])