Skip to content

Rename types like RGBA_Array_Float to FloatRGBA and add types like FloatRGBA_Array #4386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions manim/camera/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@

import cairo
import numpy as np
import numpy.typing as npt
from PIL import Image
from scipy.spatial.distance import pdist
from typing_extensions import Self

from manim.typing import MatrixMN, PixelArray, Point3D, Point3D_Array
from manim.typing import (
FloatRGBA_Array,
FloatRGBALike_Array,
ManimInt,
PixelArray,
Point3D,
Point3D_Array,
)

from .. import config, logger
from ..constants import *
Expand Down Expand Up @@ -211,8 +219,8 @@ def type_or_raise(
type[Mobject], Callable[[list[Mobject], PixelArray], Any]
] = {
VMobject: self.display_multiple_vectorized_mobjects, # type: ignore[dict-item]
PMobject: self.display_multiple_point_cloud_mobjects,
AbstractImageMobject: self.display_multiple_image_mobjects,
PMobject: self.display_multiple_point_cloud_mobjects, # type: ignore[dict-item]
AbstractImageMobject: self.display_multiple_image_mobjects, # type: ignore[dict-item]
Mobject: lambda batch, pa: batch, # Do nothing
}
# We have to check each type in turn because we are dealing with
Expand Down Expand Up @@ -723,7 +731,7 @@ def set_cairo_context_path(self, ctx: cairo.Context, vmobject: VMobject) -> Self
return self

def set_cairo_context_color(
self, ctx: cairo.Context, rgbas: MatrixMN, vmobject: VMobject
self, ctx: cairo.Context, rgbas: FloatRGBALike_Array, vmobject: VMobject
) -> Self:
"""Sets the color of the cairo context

Expand Down Expand Up @@ -818,7 +826,7 @@ def apply_stroke(

def get_stroke_rgbas(
self, vmobject: VMobject, background: bool = False
) -> PixelArray:
) -> FloatRGBA_Array:
"""Gets the RGBA array for the stroke of the passed
VMobject.

Expand All @@ -837,7 +845,7 @@ def get_stroke_rgbas(
"""
return vmobject.get_stroke_rgbas(background)

def get_fill_rgbas(self, vmobject: VMobject) -> PixelArray:
def get_fill_rgbas(self, vmobject: VMobject) -> FloatRGBA_Array:
"""Returns the RGBA array of the fill of the passed VMobject

Parameters
Expand Down Expand Up @@ -898,7 +906,7 @@ def display_multiple_background_colored_vmobjects(
# As a result, the other methods do not have as detailed docstrings as would be preferred.

def display_multiple_point_cloud_mobjects(
self, pmobjects: list, pixel_array: PixelArray
self, pmobjects: Iterable[PMobject], pixel_array: PixelArray
) -> None:
"""Displays multiple PMobjects by modifying the passed pixel array.

Expand All @@ -921,8 +929,8 @@ def display_multiple_point_cloud_mobjects(
def display_point_cloud(
self,
pmobject: PMobject,
points: list,
rgbas: np.ndarray,
points: Point3D_Array,
rgbas: FloatRGBA_Array,
thickness: float,
pixel_array: PixelArray,
) -> None:
Expand Down Expand Up @@ -972,7 +980,9 @@ def display_point_cloud(
pixel_array[:, :] = new_pa.reshape((ph, pw, rgba_len))

def display_multiple_image_mobjects(
self, image_mobjects: list, pixel_array: np.ndarray
self,
image_mobjects: Iterable[AbstractImageMobject],
pixel_array: PixelArray,
) -> None:
"""Displays multiple image mobjects by modifying the passed pixel_array.

Expand Down Expand Up @@ -1121,8 +1131,8 @@ def transform_points_pre_display(
def points_to_pixel_coords(
self,
mobject: Mobject,
points: np.ndarray,
) -> np.ndarray: # TODO: Write more detailed docstrings for this method.
points: Point3D_Array,
) -> npt.NDArray[ManimInt]: # TODO: Write more detailed docstrings for this method.
points = self.transform_points_pre_display(mobject, points)
shifted_points = points - self.frame_center

Expand Down
9 changes: 6 additions & 3 deletions manim/camera/three_d_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from manim.mobject.types.vectorized_mobject import VMobject
from manim.mobject.value_tracker import ValueTracker
from manim.typing import (
FloatRGBA_Array,
MatrixMN,
Point3D,
Point3D_Array,
Expand Down Expand Up @@ -109,7 +110,9 @@ def get_value_trackers(self) -> list[ValueTracker]:
self.zoom_tracker,
]

def modified_rgbas(self, vmobject: VMobject, rgbas: MatrixMN) -> MatrixMN:
def modified_rgbas(
self, vmobject: VMobject, rgbas: FloatRGBA_Array
) -> FloatRGBA_Array:
if not self.should_apply_shading:
return rgbas
if vmobject.shade_in_3d and (vmobject.get_num_points() > 0):
Expand Down Expand Up @@ -137,12 +140,12 @@ def get_stroke_rgbas(
self,
vmobject: VMobject,
background: bool = False,
) -> MatrixMN: # NOTE : DocStrings From parent
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
return self.modified_rgbas(vmobject, vmobject.get_stroke_rgbas(background))

def get_fill_rgbas(
self, vmobject: VMobject
) -> MatrixMN: # NOTE : DocStrings From parent
) -> FloatRGBA_Array: # NOTE : DocStrings From parent
return self.modified_rgbas(vmobject, vmobject.get_fill_rgbas())

def get_mobjects_to_display(
Expand Down
17 changes: 11 additions & 6 deletions manim/mobject/opengl/opengl_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

from manim.renderer.shader_wrapper import ShaderWrapper
from manim.typing import (
FloatRGB_Array,
FloatRGBA_Array,
ManimFloat,
MappingFunction,
MatrixMN,
Expand Down Expand Up @@ -328,9 +330,9 @@ def init_data(self) -> None:
"""Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data.
Subclasses can inherit and overwrite this method to extend `self.data`.
"""
self.points = np.zeros((0, 3))
self.bounding_box = np.zeros((3, 3))
self.rgbas = np.zeros((1, 4))
self.points: Point3D_Array = np.zeros((0, 3))
self.bounding_box: Point3D_Array = np.zeros((3, 3))
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))

def init_colors(self) -> object:
"""Initializes the colors.
Expand Down Expand Up @@ -2082,7 +2084,7 @@ def set_rgba_array(
recurse: bool = True,
) -> Self:
if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
rgbs: FloatRGB_Array = np.array([color_to_rgb(c) for c in listify(color)])
if opacity is not None:
opacities = listify(opacity)

Expand All @@ -2105,14 +2107,16 @@ def set_rgba_array(

# Color and opacity
if color is not None and opacity is not None:
rgbas = np.array([[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))])
rgbas: FloatRGBA_Array = np.array(
[[*rgb, o] for rgb, o in zip(*make_even(rgbs, opacities))]
)
for mob in self.get_family(recurse):
mob.data[name] = rgbas.copy()
return self

def set_rgba_array_direct(
self,
rgbas: npt.NDArray[RGBA_Array_Float],
rgbas: FloatRGBA_Array,
name: str = "rgbas",
recurse: bool = True,
) -> Self:
Expand Down Expand Up @@ -2794,6 +2798,7 @@ def set_color_by_xyz_func(
# of the shader code
for char in "xyz":
glsl_snippet = glsl_snippet.replace(char, "point." + char)
# TODO: get_colormap_list does not exist
rgb_list = get_colormap_list(colormap)
self.set_color_by_code(
f"color.rgb = float_to_color({glsl_snippet}, {float(min_value)}, {float(max_value)}, {get_colormap_code(rgb_list)});",
Expand Down
25 changes: 21 additions & 4 deletions manim/mobject/opengl/opengl_point_cloud_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]

from typing import TYPE_CHECKING

import moderngl
import numpy as np
from typing_extensions import Self

from manim.constants import *
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
Expand All @@ -20,6 +21,16 @@
from manim.utils.config_ops import _Uniforms
from manim.utils.iterables import resize_with_interpolation

if TYPE_CHECKING:
from typing_extensions import Self

from manim.typing import (
FloatRGBA_Array,
FloatRGBALike_Array,
Point3D_Array,
Point3DLike_Array,
)

__all__ = ["OpenGLPMobject", "OpenGLPGroup", "OpenGLPMPoint"]


Expand Down Expand Up @@ -48,14 +59,20 @@ def __init__(
)

def reset_points(self) -> Self:
self.rgbas = np.zeros((1, 4))
self.points = np.zeros((0, 3))
self.rgbas: FloatRGBA_Array = np.zeros((1, 4))
self.points: Point3D_Array = np.zeros((0, 3))
return self

def get_array_attrs(self):
return ["points", "rgbas"]

def add_points(self, points, rgbas=None, color=None, opacity=None):
def add_points(
self,
points: Point3DLike_Array,
rgbas: FloatRGBALike_Array | None = None,
color: ParsableManimColor | None = None,
opacity: float | None = None,
) -> Self:
"""Add points.
Points must be a Nx3 numpy array.
Expand Down
3 changes: 3 additions & 0 deletions manim/mobject/opengl/opengl_vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class OpenGLVMobject(OpenGLMobject):
stroke_shader_folder = "quadratic_bezier_stroke"
fill_shader_folder = "quadratic_bezier_fill"

# TODO: although these are called "rgba" in singular, they are used as
# FloatRGBA_Arrays and should be called instead "rgbas" in plural for consistency.
# The same should probably apply for "stroke_width" and "unit_normal".
fill_rgba = _Data()
stroke_rgba = _Data()
stroke_width = _Data()
Expand Down
19 changes: 13 additions & 6 deletions manim/mobject/types/point_cloud_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@
import numpy.typing as npt
from typing_extensions import Self

from manim.typing import ManimFloat, Point3DLike
from manim.typing import (
FloatRGBA_Array,
FloatRGBALike_Array,
ManimFloat,
Point3D_Array,
Point3DLike,
Point3DLike_Array,
)


class PMobject(Mobject, metaclass=ConvertToOpenGL):
Expand Down Expand Up @@ -70,19 +77,19 @@ def __init__(self, stroke_width: int = DEFAULT_STROKE_WIDTH, **kwargs: Any) -> N
super().__init__(**kwargs)

def reset_points(self) -> Self:
self.rgbas = np.zeros((0, 4))
self.points = np.zeros((0, 3))
self.rgbas: FloatRGBA_Array = np.zeros((0, 4))
self.points: Point3D_Array = np.zeros((0, 3))
return self

def get_array_attrs(self) -> list[str]:
return super().get_array_attrs() + ["rgbas"]

def add_points(
self,
points: npt.NDArray,
rgbas: npt.NDArray | None = None,
points: Point3DLike_Array,
rgbas: FloatRGBALike_Array | None = None,
color: ParsableManimColor | None = None,
alpha: float = 1,
alpha: float = 1.0,
) -> Self:
"""Add points.
Expand Down
24 changes: 13 additions & 11 deletions manim/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@
CubicBezierPath,
CubicBezierPointsLike,
CubicSpline,
FloatRGBA,
FloatRGBA_Array,
ManimFloat,
MappingFunction,
Point2DLike,
Point3D,
Point3D_Array,
Point3DLike,
Point3DLike_Array,
RGBA_Array_Float,
Vector3D,
Vector3DLike,
Zeros,
)

# TODO
Expand Down Expand Up @@ -215,8 +215,10 @@ def init_colors(self, propagate_colors: bool = True) -> Self:
return self

def generate_rgbas_array(
self, color: ManimColor | list[ManimColor], opacity: float | Iterable[float]
) -> RGBA_Array_Float:
self,
color: ParsableManimColor | Iterable[ManimColor] | None,
opacity: float | Iterable[float],
) -> FloatRGBA:
"""
First arg can be either a color, or a tuple/list of colors.
Likewise, opacity can either be a float, or a tuple of floats.
Expand All @@ -230,7 +232,7 @@ def generate_rgbas_array(
opacities: list[float] = [
o if (o is not None) else 0.0 for o in tuplify(opacity)
]
rgbas: npt.NDArray[RGBA_Array_Float] = np.array(
rgbas: FloatRGBA_Array = np.array(
[c.to_rgba_with_alpha(o) for c, o in zip(*make_even(colors, opacities))],
)

Expand All @@ -245,7 +247,7 @@ def generate_rgbas_array(
def update_rgbas_array(
self,
array_name: str,
color: ManimColor | None = None,
color: ParsableManimColor | Iterable[ManimColor] | None = None,
opacity: float | None = None,
) -> Self:
rgbas = self.generate_rgbas_array(color, opacity)
Expand Down Expand Up @@ -313,7 +315,7 @@ def construct(self):
for submobject in self.submobjects:
submobject.set_fill(color, opacity, family)
self.update_rgbas_array("fill_rgbas", color, opacity)
self.fill_rgbas: RGBA_Array_Float
self.fill_rgbas: FloatRGBA_Array
if opacity is not None:
self.fill_opacity = opacity
return self
Expand Down Expand Up @@ -539,7 +541,7 @@ def fade(self, darkness: float = 0.5, family: bool = True) -> Self:
super().fade(darkness, family)
return self

def get_fill_rgbas(self) -> RGBA_Array_Float | Zeros:
def get_fill_rgbas(self) -> FloatRGBA_Array:
try:
return self.fill_rgbas
except AttributeError:
Expand Down Expand Up @@ -572,13 +574,13 @@ def get_fill_colors(self) -> list[ManimColor | None]:
def get_fill_opacities(self) -> npt.NDArray[ManimFloat]:
return self.get_fill_rgbas()[:, 3]

def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_float | Zeros:
def get_stroke_rgbas(self, background: bool = False) -> FloatRGBA_Array:
try:
if background:
self.background_stroke_rgbas: RGBA_Array_Float
self.background_stroke_rgbas: FloatRGBA_Array
rgbas = self.background_stroke_rgbas
else:
self.stroke_rgbas: RGBA_Array_Float
self.stroke_rgbas: FloatRGBA_Array
rgbas = self.stroke_rgbas
return rgbas
except AttributeError:
Expand Down
Loading
Loading