diff --git a/manim/__init__.py b/manim/__init__.py index d9be9319f7..cf134afdce 100644 --- a/manim/__init__.py +++ b/manim/__init__.py @@ -100,6 +100,7 @@ from .utils.images import * from .utils.iterables import * from .utils.paths import * +from .utils.position import * from .utils.rate_functions import * from .utils.simple_functions import * from .utils.sounds import * diff --git a/manim/mobject/coordinate_systems.py b/manim/mobject/coordinate_systems.py index d8c74b865e..0656ac78b9 100644 --- a/manim/mobject/coordinate_systems.py +++ b/manim/mobject/coordinate_systems.py @@ -96,13 +96,13 @@ def __init__( self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 - def coords_to_point(self, *coords): + def coords_to_point(self, *coords: Union[float, int, Sequence[float]]): raise NotImplementedError() def point_to_coords(self, point): raise NotImplementedError() - def c2p(self, *coords): + def c2p(self, *coords: Union[float, int, Sequence[float]]): """Abbreviation for coords_to_point""" return self.coords_to_point(*coords) @@ -1080,7 +1080,9 @@ def create_axis( axis.shift(-axis.number_to_point(self.origin_shift(range_terms))) return axis - def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray: + def coords_to_point( + self, *coords: Union[float, int, Sequence[float]] + ) -> np.ndarray: """Transforms the vector formed from ``coords`` formed by the :class:`Axes` into the corresponding vector with respect to the default basis. @@ -1090,6 +1092,7 @@ def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray: A point that results from a change of basis from the coordinate system defined by the :class:`Axes` to that of ``manim``'s default coordinate system """ + coords = np.array(coords).flatten() origin = self.x_axis.number_to_point(self.origin_shift(self.x_range)) result = np.array(origin) for axis, coord in zip(self.get_axes(), coords): diff --git a/manim/mobject/geometry.py b/manim/mobject/geometry.py index 878025cbb9..b048613925 100644 --- a/manim/mobject/geometry.py +++ b/manim/mobject/geometry.py @@ -76,6 +76,7 @@ def construct(self): from ..mobject.types.vectorized_mobject import DashedVMobject, VGroup, VMobject from ..utils.color import * from ..utils.iterables import adjacent_n_tuples, adjacent_pairs +from ..utils.position import Position from ..utils.simple_functions import fdiv from ..utils.space_ops import ( angle_between_vectors, @@ -283,7 +284,7 @@ def __init__( start_angle=0, angle=TAU / 4, num_components=9, - arc_center=ORIGIN, + arc_center=Position(), **kwargs, ): if radius is None: # apparently None is passed by ArcBetweenPoints @@ -614,13 +615,14 @@ def construct(self): def __init__( self, - point=ORIGIN, + point=Position(ORIGIN), radius: float = DEFAULT_DOT_RADIUS, stroke_width=0, fill_opacity=1.0, color=WHITE, **kwargs, ): + point = Position(point) super().__init__( arc_center=point, radius=radius, diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 3092ef9f9f..5c6e4d7d45 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -44,6 +44,7 @@ from ..utils.exceptions import MultiAnimationOverrideException from ..utils.iterables import list_update, remove_list_redundancies from ..utils.paths import straight_path +from ..utils.position import Position from ..utils.simple_functions import get_parameters from ..utils.space_ops import ( angle_between_vectors, @@ -1089,7 +1090,7 @@ def apply_to_family(self, func: Callable[["Mobject"], None]) -> "Mobject": for mob in self.family_members_with_points(): func(mob) - def shift(self, *vectors: np.ndarray) -> "Mobject": + def shift(self, *vectors: Position) -> "Mobject": """Shift by the given vectors. Parameters @@ -1106,7 +1107,9 @@ def shift(self, *vectors: np.ndarray) -> "Mobject": -------- :meth:`move_to` """ - + vectors = [ + Position(pos) for pos in vectors + ] # todo remove - only temp for testing if config.renderer == "opengl": self.apply_points_function( lambda points: points + vectors[0], @@ -1118,7 +1121,7 @@ def shift(self, *vectors: np.ndarray) -> "Mobject": total_vector = reduce(op.add, vectors) for mob in self.family_members_with_points(): mob.points = mob.points.astype("float") - mob.points += total_vector + mob.points += total_vector() if hasattr(mob, "data") and "points" in mob.data: mob.data["points"] += total_vector return self diff --git a/manim/utils/position.py b/manim/utils/position.py new file mode 100644 index 0000000000..c58fd53869 --- /dev/null +++ b/manim/utils/position.py @@ -0,0 +1,76 @@ +from typing import Sequence, Union + +import numpy as np + +__all__ = [ + "Position", +] + + +class Position: + MAX_DIMS = 3 + + def __init__( + self, + *args: Union[float, int, Sequence[float], "Position"], + default_val=0, + dtype=np.float64 + ): + if args is None or len(args) == 0: + self.pos = np.array([default_val for _ in range(Position.MAX_DIMS)]) + elif isinstance(args[0], Position): + self.pos = args[0]() + else: + self.pos: np.ndarray = np.array(args, dtype=dtype).flatten()[ + : Position.MAX_DIMS + ] + self.default_val = default_val + self.add_padding() + + def add_padding(self, padding=3): + zeros_to_pad = padding - len(self.pos) + if zeros_to_pad <= 0: + return + self.pos = np.append(self.pos, [self.default_val] * zeros_to_pad) + + @property + def x(self) -> float: + return self.pos[0] + + @property + def y(self) -> float: + return self.pos[1] + + @property + def z(self) -> float: + return self.pos[2] + + def as_numpy(self): + return self.pos + + def as_list(self) -> list: + return [i.item() for i in self.pos] + + def as_tuple(self) -> tuple: + return tuple(i.item() for i in self.pos) + + def __call__(self, *args, **kwargs) -> np.ndarray: + return self.as_numpy() + + def __add__(self, other): + return Position(self.pos + other.pos) + + def __sub__(self, other): + return Position(self.pos - other.pos) + + def __mul__(self, other): + return Position(self.pos * other.pos) + + def __truediv__(self, other): + return Position(self.pos / other.pos) + + def __floordiv__(self, other): + return Position(self.pos // other.pos) + + def __mod__(self, other): + return Position(self.pos % other.pos) diff --git a/tests/test_vectorized_mobject.py b/tests/test_vectorized_mobject.py index 346e8fe769..1820de48e7 100644 --- a/tests/test_vectorized_mobject.py +++ b/tests/test_vectorized_mobject.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from manim import Circle, Line, Mobject, Square, VDict, VGroup, VMobject +from manim import Circle, Dot, Line, Mobject, Square, VDict, VGroup, VMobject def test_vmobject_point_from_propotion(): @@ -204,3 +204,11 @@ def test_vgroup_item_assignment_only_allows_vmobjects(): vgroup = VGroup(VMobject()) with pytest.raises(TypeError, match="All submobjects must be of type VMobject"): vgroup[0] = "invalid object" + + +def test_dot_allows_position_omission(): + dot = Dot([1]) + assert len(dot.arc_center()) == 3 + + dot = Dot([1, 2]) + assert len(dot.arc_center()) == 3