Skip to content

Commit ded54e4

Browse files
Add type annotations to indication.py (#4367)
Co-authored-by: Francisco Manríquez Novoa <[email protected]>
1 parent 04503ad commit ded54e4

File tree

4 files changed

+62
-47
lines changed

4 files changed

+62
-47
lines changed

manim/animation/indication.py

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,16 @@ def construct(self):
4040
]
4141

4242
from collections.abc import Iterable
43-
from typing import Callable
43+
from typing import Any
4444

4545
import numpy as np
46+
from typing_extensions import Self
4647

4748
from manim.mobject.geometry.arc import Circle, Dot
4849
from manim.mobject.geometry.line import Line
4950
from manim.mobject.geometry.polygram import Rectangle
5051
from manim.mobject.geometry.shape_matchers import SurroundingRectangle
52+
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
5153
from manim.scene.scene import Scene
5254

5355
from .. import config
@@ -61,9 +63,10 @@ def construct(self):
6163
from ..constants import *
6264
from ..mobject.mobject import Mobject
6365
from ..mobject.types.vectorized_mobject import VGroup, VMobject
66+
from ..typing import Point3D, Point3DLike, Vector3DLike
6467
from ..utils.bezier import interpolate, inverse_interpolate
6568
from ..utils.color import GREY, YELLOW, ParsableManimColor
66-
from ..utils.rate_functions import smooth, there_and_back, wiggle
69+
from ..utils.rate_functions import RateFunction, smooth, there_and_back, wiggle
6770
from ..utils.space_ops import normalize
6871

6972

@@ -95,12 +98,12 @@ def construct(self):
9598

9699
def __init__(
97100
self,
98-
focus_point: np.ndarray | Mobject,
101+
focus_point: Point3DLike | Mobject,
99102
opacity: float = 0.2,
100-
color: str = GREY,
103+
color: ParsableManimColor = GREY,
101104
run_time: float = 2,
102-
**kwargs,
103-
) -> None:
105+
**kwargs: Any,
106+
):
104107
self.focus_point = focus_point
105108
self.color = color
106109
self.opacity = opacity
@@ -151,15 +154,15 @@ def __init__(
151154
self,
152155
mobject: Mobject,
153156
scale_factor: float = 1.2,
154-
color: str = YELLOW,
155-
rate_func: Callable[[float, float | None], np.ndarray] = there_and_back,
156-
**kwargs,
157-
) -> None:
157+
color: ParsableManimColor = YELLOW,
158+
rate_func: RateFunction = there_and_back,
159+
**kwargs: Any,
160+
):
158161
self.color = color
159162
self.scale_factor = scale_factor
160163
super().__init__(mobject, rate_func=rate_func, **kwargs)
161164

162-
def create_target(self) -> Mobject:
165+
def create_target(self) -> Mobject | OpenGLMobject:
163166
target = self.mobject.copy()
164167
target.scale(self.scale_factor)
165168
target.set_color(self.color)
@@ -219,20 +222,20 @@ def construct(self):
219222

220223
def __init__(
221224
self,
222-
point: np.ndarray | Mobject,
225+
point: Point3DLike | Mobject,
223226
line_length: float = 0.2,
224227
num_lines: int = 12,
225228
flash_radius: float = 0.1,
226229
line_stroke_width: int = 3,
227-
color: str = YELLOW,
230+
color: ParsableManimColor = YELLOW,
228231
time_width: float = 1,
229232
run_time: float = 1.0,
230-
**kwargs,
231-
) -> None:
233+
**kwargs: Any,
234+
):
232235
if isinstance(point, Mobject):
233-
self.point = point.get_center()
236+
self.point: Point3D = point.get_center()
234237
else:
235-
self.point = point
238+
self.point = np.asarray(point)
236239
self.color = color
237240
self.line_length = line_length
238241
self.num_lines = num_lines
@@ -303,7 +306,9 @@ def construct(self):
303306
304307
"""
305308

306-
def __init__(self, mobject: VMobject, time_width: float = 0.1, **kwargs) -> None:
309+
def __init__(
310+
self, mobject: VMobject, time_width: float = 0.1, **kwargs: Any
311+
) -> None:
307312
self.time_width = time_width
308313
super().__init__(mobject, remover=True, introducer=True, **kwargs)
309314

@@ -322,7 +327,14 @@ def clean_up_from_scene(self, scene: Scene) -> None:
322327

323328

324329
class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup):
325-
def __init__(self, vmobject, n_segments=10, time_width=0.1, remover=True, **kwargs):
330+
def __init__(
331+
self,
332+
vmobject: VMobject,
333+
n_segments: int = 10,
334+
time_width: float = 0.1,
335+
remover: bool = True,
336+
**kwargs: Any,
337+
):
326338
self.n_segments = n_segments
327339
self.time_width = time_width
328340
self.remover = remover
@@ -389,19 +401,19 @@ def construct(self):
389401
def __init__(
390402
self,
391403
mobject: Mobject,
392-
direction: np.ndarray = UP,
404+
direction: Vector3DLike = UP,
393405
amplitude: float = 0.2,
394-
wave_func: Callable[[float], float] = smooth,
406+
wave_func: RateFunction = smooth,
395407
time_width: float = 1,
396408
ripples: int = 1,
397409
run_time: float = 2,
398-
**kwargs,
399-
) -> None:
410+
**kwargs: Any,
411+
):
400412
x_min = mobject.get_left()[0]
401413
x_max = mobject.get_right()[0]
402414
vect = amplitude * normalize(direction)
403415

404-
def wave(t):
416+
def wave(t: float) -> float:
405417
# Creates a wave with n ripples from a simple rate_func
406418
# This wave is build up as follows:
407419
# The time is split into 2*ripples phases. In every phase the amplitude
@@ -467,7 +479,8 @@ def homotopy(
467479
relative_x = inverse_interpolate(x_min, x_max, x)
468480
wave_phase = inverse_interpolate(lower, upper, relative_x)
469481
nudge = wave(wave_phase) * vect
470-
return np.array([x, y, z]) + nudge
482+
return_value: tuple[float, float, float] = np.array([x, y, z]) + nudge
483+
return return_value
471484

472485
super().__init__(homotopy, mobject, run_time=run_time, **kwargs)
473486

@@ -511,24 +524,28 @@ def __init__(
511524
scale_value: float = 1.1,
512525
rotation_angle: float = 0.01 * TAU,
513526
n_wiggles: int = 6,
514-
scale_about_point: np.ndarray | None = None,
515-
rotate_about_point: np.ndarray | None = None,
527+
scale_about_point: Point3DLike | None = None,
528+
rotate_about_point: Point3DLike | None = None,
516529
run_time: float = 2,
517-
**kwargs,
518-
) -> None:
530+
**kwargs: Any,
531+
):
519532
self.scale_value = scale_value
520533
self.rotation_angle = rotation_angle
521534
self.n_wiggles = n_wiggles
522535
self.scale_about_point = scale_about_point
536+
if scale_about_point is not None:
537+
self.scale_about_point = np.array(scale_about_point)
523538
self.rotate_about_point = rotate_about_point
539+
if rotate_about_point is not None:
540+
self.rotate_about_point = np.array(rotate_about_point)
524541
super().__init__(mobject, run_time=run_time, **kwargs)
525542

526-
def get_scale_about_point(self) -> np.ndarray:
543+
def get_scale_about_point(self) -> Point3D:
527544
if self.scale_about_point is None:
528545
return self.mobject.get_center()
529546
return self.scale_about_point
530547

531-
def get_rotate_about_point(self) -> np.ndarray:
548+
def get_rotate_about_point(self) -> Point3D:
532549
if self.rotate_about_point is None:
533550
return self.mobject.get_center()
534551
return self.rotate_about_point
@@ -538,7 +555,7 @@ def interpolate_submobject(
538555
submobject: Mobject,
539556
starting_submobject: Mobject,
540557
alpha: float,
541-
) -> None:
558+
) -> Self:
542559
submobject.points[:, :] = starting_submobject.points
543560
submobject.scale(
544561
interpolate(1, self.scale_value, there_and_back(alpha)),
@@ -548,6 +565,7 @@ def interpolate_submobject(
548565
wiggle(alpha, self.n_wiggles) * self.rotation_angle,
549566
about_point=self.get_rotate_about_point(),
550567
)
568+
return self
551569

552570

553571
class Circumscribe(Succession):
@@ -595,18 +613,18 @@ def construct(self):
595613
def __init__(
596614
self,
597615
mobject: Mobject,
598-
shape: type = Rectangle,
599-
fade_in=False,
600-
fade_out=False,
601-
time_width=0.3,
616+
shape: type[Rectangle] | type[Circle] = Rectangle,
617+
fade_in: bool = False,
618+
fade_out: bool = False,
619+
time_width: float = 0.3,
602620
buff: float = SMALL_BUFF,
603621
color: ParsableManimColor = YELLOW,
604-
run_time=1,
605-
stroke_width=DEFAULT_STROKE_WIDTH,
606-
**kwargs,
622+
run_time: float = 1,
623+
stroke_width: float = DEFAULT_STROKE_WIDTH,
624+
**kwargs: Any,
607625
):
608626
if shape is Rectangle:
609-
frame = SurroundingRectangle(
627+
frame: SurroundingRectangle | Circle = SurroundingRectangle(
610628
mobject,
611629
color=color,
612630
buff=buff,
@@ -685,7 +703,7 @@ def __init__(
685703
time_off: float = 0.5,
686704
blinks: int = 1,
687705
hide_at_end: bool = False,
688-
**kwargs,
706+
**kwargs: Any,
689707
):
690708
animations = [
691709
UpdateFromFunc(

manim/animation/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def begin(self) -> None:
209209
self.mobject.align_data(self.target_copy)
210210
super().begin()
211211

212-
def create_target(self) -> Mobject:
212+
def create_target(self) -> Mobject | OpenGLMobject:
213213
# Has no meaningful effect here, but may be useful
214214
# in subclasses
215215
return self.target_mobject

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@ ignore_errors = True
6363
[mypy-manim.animation.growing]
6464
ignore_errors = True
6565

66-
[mypy-manim.animation.indication]
67-
ignore_errors = True
68-
6966
[mypy-manim.animation.movement]
7067
ignore_errors = True
7168

tests/test_graphical_units/test_indication.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ def test_Wiggle_custom_about_points():
6363
scale_about_point=[1.0, 2.0, 3.0],
6464
rotate_about_point=[4.0, 5.0, 6.0],
6565
)
66-
assert wiggle.get_scale_about_point() == [1.0, 2.0, 3.0]
67-
assert wiggle.get_rotate_about_point() == [4.0, 5.0, 6.0]
66+
assert np.all(wiggle.get_scale_about_point() == [1.0, 2.0, 3.0])
67+
assert np.all(wiggle.get_rotate_about_point() == [4.0, 5.0, 6.0])

0 commit comments

Comments
 (0)