Skip to content

Commit 7cd7111

Browse files
committed
Merge remote-tracking branch 'upstream/main' into RefactorExcludeDots
# Conflicts: # manim/mobject/text/text_mobject.py
2 parents f4f7387 + 2d3aa0d commit 7cd7111

21 files changed

+483
-396
lines changed

manim/camera/camera.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,19 @@
1414

1515
import cairo
1616
import numpy as np
17+
import numpy.typing as npt
1718
from PIL import Image
1819
from scipy.spatial.distance import pdist
1920
from typing_extensions import Self
2021

21-
from manim.typing import MatrixMN, PixelArray, Point3D, Point3D_Array
22+
from manim.typing import (
23+
FloatRGBA_Array,
24+
FloatRGBALike_Array,
25+
ManimInt,
26+
PixelArray,
27+
Point3D,
28+
Point3D_Array,
29+
)
2230

2331
from .. import config, logger
2432
from ..constants import *
@@ -211,8 +219,8 @@ def type_or_raise(
211219
type[Mobject], Callable[[list[Mobject], PixelArray], Any]
212220
] = {
213221
VMobject: self.display_multiple_vectorized_mobjects, # type: ignore[dict-item]
214-
PMobject: self.display_multiple_point_cloud_mobjects,
215-
AbstractImageMobject: self.display_multiple_image_mobjects,
222+
PMobject: self.display_multiple_point_cloud_mobjects, # type: ignore[dict-item]
223+
AbstractImageMobject: self.display_multiple_image_mobjects, # type: ignore[dict-item]
216224
Mobject: lambda batch, pa: batch, # Do nothing
217225
}
218226
# We have to check each type in turn because we are dealing with
@@ -723,7 +731,7 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self
723731
return self
724732

725733
def set_cairo_context_color(
726-
self, ctx: cairo.Context, rgbas: MatrixMN, vmobject: VMobject
734+
self, ctx: cairo.Context, rgbas: FloatRGBALike_Array, vmobject: VMobject
727735
) -> Self:
728736
"""Sets the color of the cairo context
729737
@@ -818,7 +826,7 @@ def apply_stroke(
818826

819827
def get_stroke_rgbas(
820828
self, vmobject: VMobject, background: bool = False
821-
) -> PixelArray:
829+
) -> FloatRGBA_Array:
822830
"""Gets the RGBA array for the stroke of the passed
823831
VMobject.
824832
@@ -837,7 +845,7 @@ def get_stroke_rgbas(
837845
"""
838846
return vmobject.get_stroke_rgbas(background)
839847

840-
def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray:
848+
def get_fill_rgbas(self, vmobject: VMobject) -> FloatRGBA_Array:
841849
"""Returns the RGBA array of the fill of the passed VMobject
842850
843851
Parameters
@@ -898,7 +906,7 @@ def display_multiple_background_colored_vmobjects(
898906
# As a result, the other methods do not have as detailed docstrings as would be preferred.
899907

900908
def display_multiple_point_cloud_mobjects(
901-
self, pmobjects: list, pixel_array: PixelArray
909+
self, pmobjects: Iterable[PMobject], pixel_array: PixelArray
902910
) -> None:
903911
"""Displays multiple PMobjects by modifying the passed pixel array.
904912
@@ -921,8 +929,8 @@ def display_multiple_point_cloud_mobjects(
921929
def display_point_cloud(
922930
self,
923931
pmobject: PMobject,
924-
points: list,
925-
rgbas: np.ndarray,
932+
points: Point3D_Array,
933+
rgbas: FloatRGBA_Array,
926934
thickness: float,
927935
pixel_array: PixelArray,
928936
) -> None:
@@ -972,7 +980,9 @@ def display_point_cloud(
972980
pixel_array[:, :] = new_pa.reshape((ph, pw, rgba_len))
973981

974982
def display_multiple_image_mobjects(
975-
self, image_mobjects: list, pixel_array: np.ndarray
983+
self,
984+
image_mobjects: Iterable[AbstractImageMobject],
985+
pixel_array: PixelArray,
976986
) -> None:
977987
"""Displays multiple image mobjects by modifying the passed pixel_array.
978988
@@ -1121,8 +1131,8 @@ def transform_points_pre_display(
11211131
def points_to_pixel_coords(
11221132
self,
11231133
mobject: Mobject,
1124-
points: np.ndarray,
1125-
) -> np.ndarray: # TODO: Write more detailed docstrings for this method.
1134+
points: Point3D_Array,
1135+
) -> npt.NDArray[ManimInt]: # TODO: Write more detailed docstrings for this method.
11261136
points = self.transform_points_pre_display(mobject, points)
11271137
shifted_points = points - self.frame_center
11281138

manim/camera/three_d_camera.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from manim.mobject.types.vectorized_mobject import VMobject
2121
from manim.mobject.value_tracker import ValueTracker
2222
from manim.typing import (
23+
FloatRGBA_Array,
2324
MatrixMN,
2425
Point3D,
2526
Point3D_Array,
@@ -109,7 +110,9 @@ def get_value_trackers(self) -> list[ValueTracker]:
109110
self.zoom_tracker,
110111
]
111112

112-
def modified_rgbas(self, vmobject: VMobject, rgbas: MatrixMN) -> MatrixMN:
113+
def modified_rgbas(
114+
self, vmobject: VMobject, rgbas: FloatRGBA_Array
115+
) -> FloatRGBA_Array:
113116
if not self.should_apply_shading:
114117
return rgbas
115118
if vmobject.shade_in_3d and (vmobject.get_num_points() > 0):
@@ -137,12 +140,12 @@ def get_stroke_rgbas(
137140
self,
138141
vmobject: VMobject,
139142
background: bool = False,
140-
) -> MatrixMN: # NOTE : DocStrings From parent
143+
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
141144
return self.modified_rgbas(vmobject, vmobject.get_stroke_rgbas(background))
142145

143146
def get_fill_rgbas(
144147
self, vmobject: VMobject
145-
) -> MatrixMN: # NOTE : DocStrings From parent
148+
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
146149
return self.modified_rgbas(vmobject, vmobject.get_fill_rgbas())
147150

148151
def get_mobjects_to_display(

manim/mobject/mobject.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2357,7 +2357,7 @@ def __getitem__(self, value):
23572357
def __iter__(self):
23582358
return iter(self.split())
23592359

2360-
def __len__(self):
2360+
def __len__(self) -> int:
23612361
return len(self.split())
23622362

23632363
def get_group_class(self) -> type[Group]:

manim/mobject/opengl/opengl_mobject.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353

5454
from manim.renderer.shader_wrapper import ShaderWrapper
5555
from manim.typing import (
56+
FloatRGB_Array,
57+
FloatRGBA_Array,
5658
ManimFloat,
5759
MappingFunction,
5860
MatrixMN,
@@ -328,9 +330,9 @@ def init_data(self) -> None:
328330
"""Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data.
329331
Subclasses can inherit and overwrite this method to extend `self.data`.
330332
"""
331-
self.points = np.zeros((0, 3))
332-
self.bounding_box = np.zeros((3, 3))
333-
self.rgbas = np.zeros((1, 4))
333+
self.points: Point3D_Array = np.zeros((0, 3))
334+
self.bounding_box: Point3D_Array = np.zeros((3, 3))
335+
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))
334336

335337
def init_colors(self) -> object:
336338
"""Initializes the colors.
@@ -2082,7 +2084,7 @@ def set_rgba_array(
20822084
recurse: bool = True,
20832085
) -> Self:
20842086
if color is not None:
2085-
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
2087+
rgbs: FloatRGB_Array = np.array([color_to_rgb(c) for c in listify(color)])
20862088
if opacity is not None:
20872089
opacities = listify(opacity)
20882090

@@ -2105,14 +2107,16 @@ def set_rgba_array(
21052107

21062108
# Color and opacity
21072109
if color is not None and opacity is not None:
2108-
rgbas = np.array([[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))])
2110+
rgbas: FloatRGBA_Array = np.array(
2111+
[[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))]
2112+
)
21092113
for mob in self.get_family(recurse):
21102114
mob.data[name] = rgbas.copy()
21112115
return self
21122116

21132117
def set_rgba_array_direct(
21142118
self,
2115-
rgbas: npt.NDArray[RGBA_Array_Float],
2119+
rgbas: FloatRGBA_Array,
21162120
name: str = "rgbas",
21172121
recurse: bool = True,
21182122
) -> Self:
@@ -2794,6 +2798,7 @@ def set_color_by_xyz_func(
27942798
# of the shader code
27952799
for char in "xyz":
27962800
glsl_snippet = glsl_snippet.replace(char, "point." + char)
2801+
# TODO: get_colormap_list does not exist
27972802
rgb_list = get_colormap_list(colormap)
27982803
self.set_color_by_code(
27992804
f"color.rgb = float_to_color({glsl_snippet}, {float(min_value)}, {float(max_value)}, {get_colormap_code(rgb_list)});",

manim/mobject/opengl/opengl_point_cloud_mobject.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]
44

5+
from typing import TYPE_CHECKING
6+
57
import moderngl
68
import numpy as np
7-
from typing_extensions import Self
89

910
from manim.constants import *
1011
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
@@ -20,6 +21,16 @@
2021
from manim.utils.config_ops import _Uniforms
2122
from manim.utils.iterables import resize_with_interpolation
2223

24+
if TYPE_CHECKING:
25+
from typing_extensions import Self
26+
27+
from manim.typing import (
28+
FloatRGBA_Array,
29+
FloatRGBALike_Array,
30+
Point3D_Array,
31+
Point3DLike_Array,
32+
)
33+
2334
__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]
2435

2536

@@ -48,14 +59,20 @@ def __init__(
4859
)
4960

5061
def reset_points(self) -> Self:
51-
self.rgbas = np.zeros((1, 4))
52-
self.points = np.zeros((0, 3))
62+
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))
63+
self.points: Point3D_Array = np.zeros((0, 3))
5364
return self
5465

5566
def get_array_attrs(self):
5667
return ["points", "rgbas"]
5768

58-
def add_points(self, points, rgbas=None, color=None, opacity=None):
69+
def add_points(
70+
self,
71+
points: Point3DLike_Array,
72+
rgbas: FloatRGBALike_Array | None = None,
73+
color: ParsableManimColor | None = None,
74+
opacity: float | None = None,
75+
) -> Self:
5976
"""Add points.
6077
6178
Points must be a Nx3 numpy array.

manim/mobject/opengl/opengl_vectorized_mobject.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class OpenGLVMobject(OpenGLMobject):
8484
stroke_shader_folder = "quadratic_bezier_stroke"
8585
fill_shader_folder = "quadratic_bezier_fill"
8686

87+
# TODO: although these are called "rgba" in singular, they are used as
88+
# FloatRGBA_Arrays and should be called instead "rgbas" in plural for consistency.
89+
# The same should probably apply for "stroke_width" and "unit_normal".
8790
fill_rgba = _Data()
8891
stroke_rgba = _Data()
8992
stroke_width = _Data()

manim/mobject/svg/svg_mobject.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import svgelements as se
1212

1313
from manim import config, logger
14-
from manim.utils.color import ParsableManimColor
14+
from manim.utils.color import ManimColor, ParsableManimColor
1515

1616
from ...constants import RIGHT
1717
from ...utils.bezier import get_quadratic_approximation_of_cubic
@@ -120,7 +120,7 @@ def __init__(
120120
self.should_center = should_center
121121
self.svg_height = height
122122
self.svg_width = width
123-
self.color = color
123+
self.color = ManimColor(color)
124124
self.opacity = opacity
125125
self.fill_color = fill_color
126126
self.fill_opacity = fill_opacity # type: ignore[assignment]

manim/mobject/text/code_mobject.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from manim.mobject.geometry.arc import Dot
2020
from manim.mobject.geometry.shape_matchers import SurroundingRectangle
2121
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
22-
from manim.mobject.text.text_mobject import Paragraph
2322
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
2423
from manim.typing import StrPath
2524
from manim.utils.color import WHITE, ManimColor
@@ -119,6 +118,7 @@ def construct(self):
119118
"line_spacing": 0.5,
120119
"disable_ligatures": True,
121120
}
121+
code: VMobject
122122

123123
def __init__(
124124
self,
@@ -200,6 +200,8 @@ def __init__(
200200
base_paragraph_config = self.default_paragraph_config.copy()
201201
base_paragraph_config.update(paragraph_config)
202202

203+
from manim.mobject.text.text_mobject import Paragraph
204+
203205
self.code_lines = Paragraph(
204206
*code_lines,
205207
**base_paragraph_config,
@@ -224,6 +226,8 @@ def __init__(
224226
)
225227
self.add(self.line_numbers)
226228

229+
for line in self.code_lines:
230+
line.submobjects = [c for c in line if not isinstance(c, Dot)]
227231
self.add(self.code_lines)
228232

229233
if background_config is None:

manim/mobject/text/numbers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from manim.mobject.value_tracker import ValueTracker
1919
from manim.typing import Vector3DLike
2020

21-
string_to_mob_map: dict[str, VMobject] = {}
21+
string_to_mob_map: dict[str, SingleStringMathTex] = {}
2222

2323

2424
class DecimalNumber(VMobject, metaclass=ConvertToOpenGL):
@@ -227,7 +227,7 @@ def _string_to_mob(
227227
if string not in string_to_mob_map:
228228
string_to_mob_map[string] = mob_class(string, **kwargs)
229229
mob = string_to_mob_map[string].copy()
230-
mob.font_size = self._font_size # type: ignore[attr-defined]
230+
mob.font_size = self._font_size
231231
return mob
232232

233233
def _get_formatter(self, **kwargs: Any) -> str:

0 commit comments

Comments
 (0)