Skip to content

Commit a56c06c

Browse files
chopan050behackl
andauthored
Rename types like RGBA_Array_Float to FloatRGBA and add types like FloatRGBA_Array (#4386)
Co-authored-by: Benjamin Hackl <[email protected]>
1 parent 03a9414 commit a56c06c

File tree

13 files changed

+350
-273
lines changed

13 files changed

+350
-273
lines changed

manim/camera/camera.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,19 @@
1414

1515
import cairo
1616
import numpy as np
17+
import numpy.typing as npt
1718
from PIL import Image
1819
from scipy.spatial.distance import pdist
1920
from typing_extensions import Self
2021

21-
from manim.typing import MatrixMN, PixelArray, Point3D, Point3D_Array
22+
from manim.typing import (
23+
FloatRGBA_Array,
24+
FloatRGBALike_Array,
25+
ManimInt,
26+
PixelArray,
27+
Point3D,
28+
Point3D_Array,
29+
)
2230

2331
from .. import config, logger
2432
from ..constants import *
@@ -211,8 +219,8 @@ def type_or_raise(
211219
type[Mobject], Callable[[list[Mobject], PixelArray], Any]
212220
] = {
213221
VMobject: self.display_multiple_vectorized_mobjects, # type: ignore[dict-item]
214-
PMobject: self.display_multiple_point_cloud_mobjects,
215-
AbstractImageMobject: self.display_multiple_image_mobjects,
222+
PMobject: self.display_multiple_point_cloud_mobjects, # type: ignore[dict-item]
223+
AbstractImageMobject: self.display_multiple_image_mobjects, # type: ignore[dict-item]
216224
Mobject: lambda batch, pa: batch, # Do nothing
217225
}
218226
# We have to check each type in turn because we are dealing with
@@ -723,7 +731,7 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self
723731
return self
724732

725733
def set_cairo_context_color(
726-
self, ctx: cairo.Context, rgbas: MatrixMN, vmobject: VMobject
734+
self, ctx: cairo.Context, rgbas: FloatRGBALike_Array, vmobject: VMobject
727735
) -> Self:
728736
"""Sets the color of the cairo context
729737
@@ -818,7 +826,7 @@ def apply_stroke(
818826

819827
def get_stroke_rgbas(
820828
self, vmobject: VMobject, background: bool = False
821-
) -> PixelArray:
829+
) -> FloatRGBA_Array:
822830
"""Gets the RGBA array for the stroke of the passed
823831
VMobject.
824832
@@ -837,7 +845,7 @@ def get_stroke_rgbas(
837845
"""
838846
return vmobject.get_stroke_rgbas(background)
839847

840-
def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray:
848+
def get_fill_rgbas(self, vmobject: VMobject) -> FloatRGBA_Array:
841849
"""Returns the RGBA array of the fill of the passed VMobject
842850
843851
Parameters
@@ -898,7 +906,7 @@ def display_multiple_background_colored_vmobjects(
898906
# As a result, the other methods do not have as detailed docstrings as would be preferred.
899907

900908
def display_multiple_point_cloud_mobjects(
901-
self, pmobjects: list, pixel_array: PixelArray
909+
self, pmobjects: Iterable[PMobject], pixel_array: PixelArray
902910
) -> None:
903911
"""Displays multiple PMobjects by modifying the passed pixel array.
904912
@@ -921,8 +929,8 @@ def display_multiple_point_cloud_mobjects(
921929
def display_point_cloud(
922930
self,
923931
pmobject: PMobject,
924-
points: list,
925-
rgbas: np.ndarray,
932+
points: Point3D_Array,
933+
rgbas: FloatRGBA_Array,
926934
thickness: float,
927935
pixel_array: PixelArray,
928936
) -> None:
@@ -972,7 +980,9 @@ def display_point_cloud(
972980
pixel_array[:, :] = new_pa.reshape((ph, pw, rgba_len))
973981

974982
def display_multiple_image_mobjects(
975-
self, image_mobjects: list, pixel_array: np.ndarray
983+
self,
984+
image_mobjects: Iterable[AbstractImageMobject],
985+
pixel_array: PixelArray,
976986
) -> None:
977987
"""Displays multiple image mobjects by modifying the passed pixel_array.
978988
@@ -1121,8 +1131,8 @@ def transform_points_pre_display(
11211131
def points_to_pixel_coords(
11221132
self,
11231133
mobject: Mobject,
1124-
points: np.ndarray,
1125-
) -> np.ndarray: # TODO: Write more detailed docstrings for this method.
1134+
points: Point3D_Array,
1135+
) -> npt.NDArray[ManimInt]: # TODO: Write more detailed docstrings for this method.
11261136
points = self.transform_points_pre_display(mobject, points)
11271137
shifted_points = points - self.frame_center
11281138

manim/camera/three_d_camera.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from manim.mobject.types.vectorized_mobject import VMobject
2121
from manim.mobject.value_tracker import ValueTracker
2222
from manim.typing import (
23+
FloatRGBA_Array,
2324
MatrixMN,
2425
Point3D,
2526
Point3D_Array,
@@ -109,7 +110,9 @@ def get_value_trackers(self) -> list[ValueTracker]:
109110
self.zoom_tracker,
110111
]
111112

112-
def modified_rgbas(self, vmobject: VMobject, rgbas: MatrixMN) -> MatrixMN:
113+
def modified_rgbas(
114+
self, vmobject: VMobject, rgbas: FloatRGBA_Array
115+
) -> FloatRGBA_Array:
113116
if not self.should_apply_shading:
114117
return rgbas
115118
if vmobject.shade_in_3d and (vmobject.get_num_points() > 0):
@@ -137,12 +140,12 @@ def get_stroke_rgbas(
137140
self,
138141
vmobject: VMobject,
139142
background: bool = False,
140-
) -> MatrixMN: # NOTE : DocStrings From parent
143+
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
141144
return self.modified_rgbas(vmobject, vmobject.get_stroke_rgbas(background))
142145

143146
def get_fill_rgbas(
144147
self, vmobject: VMobject
145-
) -> MatrixMN: # NOTE : DocStrings From parent
148+
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
146149
return self.modified_rgbas(vmobject, vmobject.get_fill_rgbas())
147150

148151
def get_mobjects_to_display(

manim/mobject/opengl/opengl_mobject.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353

5454
from manim.renderer.shader_wrapper import ShaderWrapper
5555
from manim.typing import (
56+
FloatRGB_Array,
57+
FloatRGBA_Array,
5658
ManimFloat,
5759
MappingFunction,
5860
MatrixMN,
@@ -328,9 +330,9 @@ def init_data(self) -> None:
328330
"""Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data.
329331
Subclasses can inherit and overwrite this method to extend `self.data`.
330332
"""
331-
self.points = np.zeros((0, 3))
332-
self.bounding_box = np.zeros((3, 3))
333-
self.rgbas = np.zeros((1, 4))
333+
self.points: Point3D_Array = np.zeros((0, 3))
334+
self.bounding_box: Point3D_Array = np.zeros((3, 3))
335+
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))
334336

335337
def init_colors(self) -> object:
336338
"""Initializes the colors.
@@ -2082,7 +2084,7 @@ def set_rgba_array(
20822084
recurse: bool = True,
20832085
) -> Self:
20842086
if color is not None:
2085-
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
2087+
rgbs: FloatRGB_Array = np.array([color_to_rgb(c) for c in listify(color)])
20862088
if opacity is not None:
20872089
opacities = listify(opacity)
20882090

@@ -2105,14 +2107,16 @@ def set_rgba_array(
21052107

21062108
# Color and opacity
21072109
if color is not None and opacity is not None:
2108-
rgbas = np.array([[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))])
2110+
rgbas: FloatRGBA_Array = np.array(
2111+
[[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))]
2112+
)
21092113
for mob in self.get_family(recurse):
21102114
mob.data[name] = rgbas.copy()
21112115
return self
21122116

21132117
def set_rgba_array_direct(
21142118
self,
2115-
rgbas: npt.NDArray[RGBA_Array_Float],
2119+
rgbas: FloatRGBA_Array,
21162120
name: str = "rgbas",
21172121
recurse: bool = True,
21182122
) -> Self:
@@ -2794,6 +2798,7 @@ def set_color_by_xyz_func(
27942798
# of the shader code
27952799
for char in "xyz":
27962800
glsl_snippet = glsl_snippet.replace(char, "point." + char)
2801+
# TODO: get_colormap_list does not exist
27972802
rgb_list = get_colormap_list(colormap)
27982803
self.set_color_by_code(
27992804
f"color.rgb = float_to_color({glsl_snippet}, {float(min_value)}, {float(max_value)}, {get_colormap_code(rgb_list)});",

manim/mobject/opengl/opengl_point_cloud_mobject.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]
44

5+
from typing import TYPE_CHECKING
6+
57
import moderngl
68
import numpy as np
7-
from typing_extensions import Self
89

910
from manim.constants import *
1011
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
@@ -20,6 +21,16 @@
2021
from manim.utils.config_ops import _Uniforms
2122
from manim.utils.iterables import resize_with_interpolation
2223

24+
if TYPE_CHECKING:
25+
from typing_extensions import Self
26+
27+
from manim.typing import (
28+
FloatRGBA_Array,
29+
FloatRGBALike_Array,
30+
Point3D_Array,
31+
Point3DLike_Array,
32+
)
33+
2334
__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]
2435

2536

@@ -48,14 +59,20 @@ def __init__(
4859
)
4960

5061
def reset_points(self) -> Self:
51-
self.rgbas = np.zeros((1, 4))
52-
self.points = np.zeros((0, 3))
62+
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))
63+
self.points: Point3D_Array = np.zeros((0, 3))
5364
return self
5465

5566
def get_array_attrs(self):
5667
return ["points", "rgbas"]
5768

58-
def add_points(self, points, rgbas=None, color=None, opacity=None):
69+
def add_points(
70+
self,
71+
points: Point3DLike_Array,
72+
rgbas: FloatRGBALike_Array | None = None,
73+
color: ParsableManimColor | None = None,
74+
opacity: float | None = None,
75+
) -> Self:
5976
"""Add points.
6077
6178
Points must be a Nx3 numpy array.

manim/mobject/opengl/opengl_vectorized_mobject.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class OpenGLVMobject(OpenGLMobject):
8484
stroke_shader_folder = "quadratic_bezier_stroke"
8585
fill_shader_folder = "quadratic_bezier_fill"
8686

87+
# TODO: although these are called "rgba" in singular, they are used as
88+
# FloatRGBA_Arrays and should be called instead "rgbas" in plural for consistency.
89+
# The same should probably apply for "stroke_width" and "unit_normal".
8790
fill_rgba = _Data()
8891
stroke_rgba = _Data()
8992
stroke_width = _Data()

manim/mobject/types/point_cloud_mobject.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@
3333
import numpy.typing as npt
3434
from typing_extensions import Self
3535

36-
from manim.typing import ManimFloat, Point3DLike
36+
from manim.typing import (
37+
FloatRGBA_Array,
38+
FloatRGBALike_Array,
39+
ManimFloat,
40+
Point3D_Array,
41+
Point3DLike,
42+
Point3DLike_Array,
43+
)
3744

3845

3946
class PMobject(Mobject, metaclass=ConvertToOpenGL):
@@ -70,19 +77,19 @@ def __init__(self, stroke_width: int = DEFAULT_STROKE_WIDTH, **kwargs: Any) -> N
7077
super().__init__(**kwargs)
7178

7279
def reset_points(self) -> Self:
73-
self.rgbas = np.zeros((0, 4))
74-
self.points = np.zeros((0, 3))
80+
self.rgbas: FloatRGBA_Array = np.zeros((0, 4))
81+
self.points: Point3D_Array = np.zeros((0, 3))
7582
return self
7683

7784
def get_array_attrs(self) -> list[str]:
7885
return super().get_array_attrs() + ["rgbas"]
7986

8087
def add_points(
8188
self,
82-
points: npt.NDArray,
83-
rgbas: npt.NDArray | None = None,
89+
points: Point3DLike_Array,
90+
rgbas: FloatRGBALike_Array | None = None,
8491
color: ParsableManimColor | None = None,
85-
alpha: float = 1,
92+
alpha: float = 1.0,
8693
) -> Self:
8794
"""Add points.
8895

manim/mobject/types/vectorized_mobject.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@
5454
CubicBezierPath,
5555
CubicBezierPointsLike,
5656
CubicSpline,
57+
FloatRGBA,
58+
FloatRGBA_Array,
5759
ManimFloat,
5860
MappingFunction,
5961
Point2DLike,
6062
Point3D,
6163
Point3D_Array,
6264
Point3DLike,
6365
Point3DLike_Array,
64-
RGBA_Array_Float,
6566
Vector3D,
6667
Vector3DLike,
67-
Zeros,
6868
)
6969

7070
# TODO
@@ -215,8 +215,10 @@ def init_colors(self, propagate_colors: bool = True) -> Self:
215215
return self
216216

217217
def generate_rgbas_array(
218-
self, color: ManimColor | list[ManimColor], opacity: float | Iterable[float]
219-
) -> RGBA_Array_Float:
218+
self,
219+
color: ParsableManimColor | Iterable[ManimColor] | None,
220+
opacity: float | Iterable[float],
221+
) -> FloatRGBA:
220222
"""
221223
First arg can be either a color, or a tuple/list of colors.
222224
Likewise, opacity can either be a float, or a tuple of floats.
@@ -230,7 +232,7 @@ def generate_rgbas_array(
230232
opacities: list[float] = [
231233
o if (o is not None) else 0.0 for o in tuplify(opacity)
232234
]
233-
rgbas: npt.NDArray[RGBA_Array_Float] = np.array(
235+
rgbas: FloatRGBA_Array = np.array(
234236
[c.to_rgba_with_alpha(o) for c, o in zip(*make_even(colors, opacities))],
235237
)
236238

@@ -245,7 +247,7 @@ def generate_rgbas_array(
245247
def update_rgbas_array(
246248
self,
247249
array_name: str,
248-
color: ManimColor | None = None,
250+
color: ParsableManimColor | Iterable[ManimColor] | None = None,
249251
opacity: float | None = None,
250252
) -> Self:
251253
rgbas = self.generate_rgbas_array(color, opacity)
@@ -313,7 +315,7 @@ def construct(self):
313315
for submobject in self.submobjects:
314316
submobject.set_fill(color, opacity, family)
315317
self.update_rgbas_array("fill_rgbas", color, opacity)
316-
self.fill_rgbas: RGBA_Array_Float
318+
self.fill_rgbas: FloatRGBA_Array
317319
if opacity is not None:
318320
self.fill_opacity = opacity
319321
return self
@@ -539,7 +541,7 @@ def fade(self, darkness: float = 0.5, family: bool = True) -> Self:
539541
super().fade(darkness, family)
540542
return self
541543

542-
def get_fill_rgbas(self) -> RGBA_Array_Float | Zeros:
544+
def get_fill_rgbas(self) -> FloatRGBA_Array:
543545
try:
544546
return self.fill_rgbas
545547
except AttributeError:
@@ -572,13 +574,13 @@ def get_fill_colors(self) -> list[ManimColor | None]:
572574
def get_fill_opacities(self) -> npt.NDArray[ManimFloat]:
573575
return self.get_fill_rgbas()[:, 3]
574576

575-
def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_float | Zeros:
577+
def get_stroke_rgbas(self, background: bool = False) -> FloatRGBA_Array:
576578
try:
577579
if background:
578-
self.background_stroke_rgbas: RGBA_Array_Float
580+
self.background_stroke_rgbas: FloatRGBA_Array
579581
rgbas = self.background_stroke_rgbas
580582
else:
581-
self.stroke_rgbas: RGBA_Array_Float
583+
self.stroke_rgbas: FloatRGBA_Array
582584
rgbas = self.stroke_rgbas
583585
return rgbas
584586
except AttributeError:

0 commit comments

Comments
 (0)