Skip to content

Commit c887b51

Browse files
Add type annotations to composition.py (#4366)
Co-authored-by: Francisco Manríquez Novoa <[email protected]>
1 parent ded54e4 commit c887b51

File tree

3 files changed

+36
-35
lines changed

3 files changed

+36
-35
lines changed

manim/animation/animation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __new__(
129129

130130
def __init__(
131131
self,
132-
mobject: Mobject | None,
132+
mobject: Mobject | OpenGLMobject | None,
133133
lag_ratio: float = DEFAULT_ANIMATION_LAG_RATIO,
134134
run_time: float = DEFAULT_ANIMATION_RUN_TIME,
135135
rate_func: Callable[[float], float] = smooth,
@@ -266,7 +266,7 @@ def create_starting_mobject(self) -> Mobject:
266266
# Keep track of where the mobject starts
267267
return self.mobject.copy()
268268

269-
def get_all_mobjects(self) -> Sequence[Mobject]:
269+
def get_all_mobjects(self) -> Sequence[Mobject | OpenGLMobject]:
270270
"""Get all mobjects involved in the animation.
271271
272272
Ordering must match the ordering of arguments to interpolate_submobject

manim/animation/composition.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22

33
from __future__ import annotations
44

5-
import types
65
from collections.abc import Iterable, Sequence
7-
from typing import TYPE_CHECKING, Callable
6+
from typing import TYPE_CHECKING, Any, Callable
87

98
import numpy as np
109

1110
from manim._config import config
1211
from manim.animation.animation import Animation, prepare_animation
1312
from manim.constants import RendererType
1413
from manim.mobject.mobject import Group, Mobject
15-
from manim.mobject.opengl.opengl_mobject import OpenGLGroup
14+
from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject
1615
from manim.scene.scene import Scene
1716
from manim.utils.iterables import remove_list_redundancies
1817
from manim.utils.parameter_parsing import flatten_iterable_parameters
@@ -54,31 +53,34 @@ class AnimationGroup(Animation):
5453

5554
def __init__(
5655
self,
57-
*animations: Animation | Iterable[Animation] | types.GeneratorType[Animation],
58-
group: Group | VGroup | OpenGLGroup | OpenGLVGroup = None,
56+
*animations: Animation | Iterable[Animation],
57+
group: Group | VGroup | OpenGLGroup | OpenGLVGroup | None = None,
5958
run_time: float | None = None,
6059
rate_func: Callable[[float], float] = linear,
6160
lag_ratio: float = 0,
62-
**kwargs,
63-
) -> None:
61+
**kwargs: Any,
62+
):
6463
arg_anim = flatten_iterable_parameters(animations)
6564
self.animations = [prepare_animation(anim) for anim in arg_anim]
6665
self.rate_func = rate_func
67-
self.group = group
68-
if self.group is None:
66+
if group is None:
6967
mobjects = remove_list_redundancies(
7068
[anim.mobject for anim in self.animations if not anim.is_introducer()],
7169
)
7270
if config["renderer"] == RendererType.OPENGL:
73-
self.group = OpenGLGroup(*mobjects)
71+
self.group: Group | VGroup | OpenGLGroup | OpenGLVGroup = OpenGLGroup(
72+
*mobjects
73+
)
7474
else:
7575
self.group = Group(*mobjects)
76+
else:
77+
self.group = group
7678
super().__init__(
7779
self.group, rate_func=self.rate_func, lag_ratio=lag_ratio, **kwargs
7880
)
7981
self.run_time: float = self.init_run_time(run_time)
8082

81-
def get_all_mobjects(self) -> Sequence[Mobject]:
83+
def get_all_mobjects(self) -> Sequence[Mobject | OpenGLMobject]:
8284
return list(self.group)
8385

8486
def begin(self) -> None:
@@ -93,7 +95,7 @@ def begin(self) -> None:
9395
for anim in self.animations:
9496
anim.begin()
9597

96-
def _setup_scene(self, scene) -> None:
98+
def _setup_scene(self, scene: Scene) -> None:
9799
for anim in self.animations:
98100
anim._setup_scene(scene)
99101

@@ -118,7 +120,7 @@ def update_mobjects(self, dt: float) -> None:
118120
]:
119121
anim.update_mobjects(dt)
120122

121-
def init_run_time(self, run_time) -> float:
123+
def init_run_time(self, run_time: float | None) -> float:
122124
"""Calculates the run time of the animation, if different from ``run_time``.
123125
124126
Parameters
@@ -146,9 +148,9 @@ def build_animations_with_timings(self) -> None:
146148
run_times = np.array([anim.run_time for anim in self.animations])
147149
num_animations = run_times.shape[0]
148150
dtype = [("anim", "O"), ("start", "f8"), ("end", "f8")]
149-
self.anims_with_timings = np.zeros(num_animations, dtype=dtype)
150-
self.anims_begun = np.zeros(num_animations, dtype=bool)
151-
self.anims_finished = np.zeros(num_animations, dtype=bool)
151+
self.anims_with_timings: np.ndarray = np.zeros(num_animations, dtype=dtype)
152+
self.anims_begun: np.ndarray = np.zeros(num_animations, dtype=bool)
153+
self.anims_finished: np.ndarray = np.zeros(num_animations, dtype=bool)
152154
if num_animations == 0:
153155
return
154156

@@ -228,7 +230,7 @@ def construct(self):
228230
))
229231
"""
230232

231-
def __init__(self, *animations: Animation, lag_ratio: float = 1, **kwargs) -> None:
233+
def __init__(self, *animations: Animation, lag_ratio: float = 1, **kwargs: Any):
232234
super().__init__(*animations, lag_ratio=lag_ratio, **kwargs)
233235

234236
def begin(self) -> None:
@@ -247,7 +249,7 @@ def update_mobjects(self, dt: float) -> None:
247249
if self.active_animation:
248250
self.active_animation.update_mobjects(dt)
249251

250-
def _setup_scene(self, scene) -> None:
252+
def _setup_scene(self, scene: Scene | None) -> None:
251253
if scene is None:
252254
return
253255
if self.is_introducer():
@@ -339,7 +341,7 @@ def __init__(
339341
self,
340342
*animations: Animation,
341343
lag_ratio: float = DEFAULT_LAGGED_START_LAG_RATIO,
342-
**kwargs,
344+
**kwargs: Any,
343345
):
344346
super().__init__(*animations, lag_ratio=lag_ratio, **kwargs)
345347

@@ -384,20 +386,22 @@ def construct(self):
384386

385387
def __init__(
386388
self,
387-
AnimationClass: Callable[..., Animation],
389+
animation_class: type[Animation],
388390
mobject: Mobject,
389-
arg_creator: Callable[[Mobject], str] = None,
391+
arg_creator: Callable[[Mobject], Iterable[Any]] | None = None,
390392
run_time: float = 2,
391-
**kwargs,
392-
) -> None:
393-
args_list = []
394-
for submob in mobject:
395-
if arg_creator:
396-
args_list.append(arg_creator(submob))
397-
else:
398-
args_list.append((submob,))
393+
**kwargs: Any,
394+
):
395+
if arg_creator is None:
396+
397+
def identity(mob: Mobject) -> Mobject:
398+
return mob
399+
400+
arg_creator = identity
401+
402+
args_list = [arg_creator(submob) for submob in mobject]
399403
anim_kwargs = dict(kwargs)
400404
if "lag_ratio" in anim_kwargs:
401405
anim_kwargs.pop("lag_ratio")
402-
animations = [AnimationClass(*args, **anim_kwargs) for args in args_list]
406+
animations = [animation_class(*args, **anim_kwargs) for args in args_list]
403407
super().__init__(*animations, run_time=run_time, **kwargs)

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ ignore_errors = True
5454
[mypy-manim.animation.animation]
5555
ignore_errors = True
5656

57-
[mypy-manim.animation.composition]
58-
ignore_errors = True
59-
6057
[mypy-manim.animation.creation]
6158
ignore_errors = True
6259

0 commit comments

Comments
 (0)