Skip to content

Commit 9c9ebf4

Browse files
henrikmidtibychopan050pre-commit-ci[bot]
authored
Add type annotations to dot_cloud.py, vectorized_mobject_rendering.py and opengl_three_dimensions.py (#4359)
Co-authored-by: Francisco Manríquez Novoa <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d18dc8f commit 9c9ebf4

File tree

7 files changed

+90
-36
lines changed

7 files changed

+90
-36
lines changed

manim/mobject/opengl/dot_cloud.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,33 @@
22

33
__all__ = ["TrueDot", "DotCloud"]
44

5+
from typing import Any
6+
57
import numpy as np
8+
from typing_extensions import Self
69

710
from manim.constants import ORIGIN, RIGHT, UP
811
from manim.mobject.opengl.opengl_point_cloud_mobject import OpenGLPMobject
9-
from manim.utils.color import YELLOW
12+
from manim.typing import Point3DLike
13+
from manim.utils.color import YELLOW, ParsableManimColor
1014

1115

1216
class DotCloud(OpenGLPMobject):
1317
def __init__(
14-
self, color=YELLOW, stroke_width=2.0, radius=2.0, density=10, **kwargs
18+
self,
19+
color: ParsableManimColor = YELLOW,
20+
stroke_width: float = 2.0,
21+
radius: float = 2.0,
22+
density: float = 10,
23+
**kwargs: Any,
1524
):
1625
self.radius = radius
1726
self.epsilon = 1.0 / density
1827
super().__init__(
1928
stroke_width=stroke_width, density=density, color=color, **kwargs
2029
)
2130

22-
def init_points(self):
31+
def init_points(self) -> None:
2332
self.points = np.array(
2433
[
2534
r * (np.cos(theta) * RIGHT + np.sin(theta) * UP)
@@ -34,14 +43,16 @@ def init_points(self):
3443
dtype=np.float32,
3544
)
3645

37-
def make_3d(self, gloss=0.5, shadow=0.2):
46+
def make_3d(self, gloss: float = 0.5, shadow: float = 0.2) -> Self:
3847
self.set_gloss(gloss)
3948
self.set_shadow(shadow)
4049
self.apply_depth_test()
4150
return self
4251

4352

4453
class TrueDot(DotCloud):
45-
def __init__(self, center=ORIGIN, stroke_width=2.0, **kwargs):
54+
def __init__(
55+
self, center: Point3DLike = ORIGIN, stroke_width: float = 2.0, **kwargs: Any
56+
):
4657
self.radius = stroke_width
4758
super().__init__(points=[center], stroke_width=stroke_width, **kwargs)

manim/mobject/opengl/opengl_point_cloud_mobject.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@
44

55
import moderngl
66
import numpy as np
7+
from typing_extensions import Self
78

89
from manim.constants import *
910
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
1011
from manim.utils.bezier import interpolate
11-
from manim.utils.color import BLACK, WHITE, YELLOW, color_gradient, color_to_rgba
12+
from manim.utils.color import (
13+
BLACK,
14+
WHITE,
15+
YELLOW,
16+
ParsableManimColor,
17+
color_gradient,
18+
color_to_rgba,
19+
)
1220
from manim.utils.config_ops import _Uniforms
1321
from manim.utils.iterables import resize_with_interpolation
1422

@@ -27,15 +35,19 @@ class OpenGLPMobject(OpenGLMobject):
2735
point_radius = _Uniforms()
2836

2937
def __init__(
30-
self, stroke_width=2.0, color=YELLOW, render_primitive=moderngl.POINTS, **kwargs
38+
self,
39+
stroke_width: float = 2.0,
40+
color: ParsableManimColor = YELLOW,
41+
render_primitive: int = moderngl.POINTS,
42+
**kwargs,
3143
):
3244
self.stroke_width = stroke_width
3345
super().__init__(color=color, render_primitive=render_primitive, **kwargs)
3446
self.point_radius = (
3547
self.stroke_width * OpenGLPMobject.OPENGL_POINT_RADIUS_SCALE_FACTOR
3648
)
3749

38-
def reset_points(self):
50+
def reset_points(self) -> Self:
3951
self.rgbas = np.zeros((1, 4))
4052
self.points = np.zeros((0, 3))
4153
return self

manim/mobject/opengl/opengl_surface.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from manim.constants import *
1111
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
12+
from manim.typing import Point3D_Array, Vector3D_Array
1213
from manim.utils.bezier import integer_interpolate, interpolate
1314
from manim.utils.color import *
1415
from manim.utils.config_ops import _Data, _Uniforms
@@ -160,12 +161,14 @@ def compute_triangle_indices(self):
160161
def get_triangle_indices(self):
161162
return self.triangle_indices
162163

163-
def get_surface_points_and_nudged_points(self):
164+
def get_surface_points_and_nudged_points(
165+
self,
166+
) -> tuple[Point3D_Array, Point3D_Array, Point3D_Array]:
164167
points = self.points
165168
k = len(points) // 3
166169
return points[:k], points[k : 2 * k], points[2 * k :]
167170

168-
def get_unit_normals(self):
171+
def get_unit_normals(self) -> Vector3D_Array:
169172
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
170173
normals = np.cross(
171174
(du_points - s_points) / self.epsilon,

manim/mobject/opengl/opengl_three_dimensions.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import numpy as np
46

57
from manim.mobject.opengl.opengl_surface import OpenGLSurface
@@ -11,13 +13,13 @@
1113
class OpenGLSurfaceMesh(OpenGLVGroup):
1214
def __init__(
1315
self,
14-
uv_surface,
15-
resolution=None,
16-
stroke_width=1,
17-
normal_nudge=1e-2,
18-
depth_test=True,
19-
flat_stroke=False,
20-
**kwargs,
16+
uv_surface: OpenGLSurface,
17+
resolution: tuple[int, int] | None = None,
18+
stroke_width: float = 1,
19+
normal_nudge: float = 1e-2,
20+
depth_test: bool = True,
21+
flat_stroke: bool = False,
22+
**kwargs: Any,
2123
):
2224
if not isinstance(uv_surface, OpenGLSurface):
2325
raise Exception("uv_surface must be of type OpenGLSurface")
@@ -31,7 +33,7 @@ def __init__(
3133
**kwargs,
3234
)
3335

34-
def init_points(self):
36+
def init_points(self) -> None:
3537
uv_surface = self.uv_surface
3638

3739
full_nu, full_nv = uv_surface.resolution

manim/mobject/opengl/opengl_vectorized_mobject.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import operator as op
55
from collections.abc import Iterable, Sequence
66
from functools import reduce, wraps
7-
from typing import Callable
7+
from typing import Any, Callable
88

99
import moderngl
1010
import numpy as np
11+
from typing_extensions import Self
1112

1213
from manim import config
1314
from manim.constants import *
@@ -171,6 +172,15 @@ def get_group_class(self):
171172
def get_mobject_type_class():
172173
return OpenGLVMobject
173174

175+
@property
176+
def submobjects(self) -> Sequence[OpenGLVMobject]:
177+
return self._submobjects if hasattr(self, "_submobjects") else []
178+
179+
@submobjects.setter
180+
def submobjects(self, submobject_list: Iterable[OpenGLVMobject]) -> None:
181+
self.remove(*self.submobjects)
182+
self.add(*submobject_list)
183+
174184
def init_data(self):
175185
super().init_data()
176186
self.data.pop("rgbas")
@@ -594,7 +604,9 @@ def set_points_as_corners(self, points: Iterable[float]) -> OpenGLVMobject:
594604
)
595605
return self
596606

597-
def set_points_smoothly(self, points, true_smooth=False):
607+
def set_points_smoothly(
608+
self, points: Point3DLike_Array, true_smooth: bool = False
609+
) -> Self:
598610
self.set_points_as_corners(points)
599611
self.make_smooth()
600612
return self
@@ -1654,7 +1666,7 @@ def construct(self):
16541666
self.add(circles_group)
16551667
"""
16561668

1657-
def __init__(self, *vmobjects, **kwargs):
1669+
def __init__(self, *vmobjects: OpenGLVMobject, **kwargs: Any):
16581670
super().__init__(**kwargs)
16591671
self.add(*vmobjects)
16601672

manim/renderer/vectorized_mobject_rendering.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
from __future__ import annotations
22

33
import collections
4+
from collections.abc import Iterable, Sequence
5+
from typing import TYPE_CHECKING
46

57
import numpy as np
68

9+
if TYPE_CHECKING:
10+
from manim.renderer.opengl_renderer import (
11+
OpenGLRenderer,
12+
OpenGLVMobject,
13+
)
14+
from manim.typing import MatrixMN
15+
716
from ..utils import opengl
817
from ..utils.space_ops import cross2d, earclip_triangulation
918
from .shader import Shader
@@ -14,7 +23,9 @@
1423
]
1524

1625

17-
def build_matrix_lists(mob):
26+
def build_matrix_lists(
27+
mob: OpenGLVMobject,
28+
) -> collections.defaultdict[tuple[float, ...], list[OpenGLVMobject]]:
1829
root_hierarchical_matrix = mob.hierarchical_model_matrix()
1930
matrix_to_mobject_list = collections.defaultdict(list)
2031
if mob.has_points():
@@ -36,15 +47,21 @@ def build_matrix_lists(mob):
3647
return matrix_to_mobject_list
3748

3849

39-
def render_opengl_vectorized_mobject_fill(renderer, mobject):
50+
def render_opengl_vectorized_mobject_fill(
51+
renderer: OpenGLRenderer, mobject: OpenGLVMobject
52+
) -> None:
4053
matrix_to_mobject_list = build_matrix_lists(mobject)
4154

4255
for matrix_tuple, mobject_list in matrix_to_mobject_list.items():
4356
model_matrix = np.array(matrix_tuple).reshape((4, 4))
4457
render_mobject_fills_with_matrix(renderer, model_matrix, mobject_list)
4558

4659

47-
def render_mobject_fills_with_matrix(renderer, model_matrix, mobjects):
60+
def render_mobject_fills_with_matrix(
61+
renderer: OpenGLRenderer,
62+
model_matrix: MatrixMN,
63+
mobjects: Iterable[OpenGLVMobject],
64+
) -> None:
4865
# Precompute the total number of vertices for which to reserve space.
4966
# Note that triangulate_mobject() will cache its results.
5067
total_size = 0
@@ -98,7 +115,7 @@ def render_mobject_fills_with_matrix(renderer, model_matrix, mobjects):
98115
vbo.release()
99116

100117

101-
def triangulate_mobject(mob):
118+
def triangulate_mobject(mob: OpenGLVMobject) -> np.ndarray:
102119
if not mob.needs_new_triangulation:
103120
return mob.triangulation
104121

@@ -192,14 +209,20 @@ def triangulate_mobject(mob):
192209
return attributes
193210

194211

195-
def render_opengl_vectorized_mobject_stroke(renderer, mobject):
212+
def render_opengl_vectorized_mobject_stroke(
213+
renderer: OpenGLRenderer, mobject: OpenGLVMobject
214+
) -> None:
196215
matrix_to_mobject_list = build_matrix_lists(mobject)
197216
for matrix_tuple, mobject_list in matrix_to_mobject_list.items():
198217
model_matrix = np.array(matrix_tuple).reshape((4, 4))
199218
render_mobject_strokes_with_matrix(renderer, model_matrix, mobject_list)
200219

201220

202-
def render_mobject_strokes_with_matrix(renderer, model_matrix, mobjects):
221+
def render_mobject_strokes_with_matrix(
222+
renderer: OpenGLRenderer,
223+
model_matrix: MatrixMN,
224+
mobjects: Sequence[OpenGLVMobject],
225+
) -> None:
203226
# Precompute the total number of vertices for which to reserve space.
204227
total_size = 0
205228
for submob in mobjects:

mypy.ini

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ ignore_errors = True
102102
[mypy-manim.mobject.mobject]
103103
ignore_errors = True
104104

105-
[mypy-manim.mobject.opengl.dot_cloud]
106-
ignore_errors = True
107-
108105
[mypy-manim.mobject.opengl.opengl_compatibility]
109106
ignore_errors = True
110107

@@ -123,9 +120,6 @@ ignore_errors = True
123120
[mypy-manim.mobject.opengl.opengl_surface]
124121
ignore_errors = True
125122

126-
[mypy-manim.mobject.opengl.opengl_three_dimensions]
127-
ignore_errors = True
128-
129123
[mypy-manim.mobject.opengl.opengl_vectorized_mobject]
130124
ignore_errors = True
131125

@@ -162,9 +156,6 @@ ignore_errors = True
162156
[mypy-manim.renderer.shader_wrapper]
163157
ignore_errors = True
164158

165-
[mypy-manim.renderer.vectorized_mobject_rendering]
166-
ignore_errors = True
167-
168159
[mypy-manim.scene.three_d_scene]
169160
ignore_errors = True
170161

0 commit comments

Comments
 (0)