Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions manim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from .utils.images import *
from .utils.iterables import *
from .utils.paths import *
from .utils.position import *
from .utils.rate_functions import *
from .utils.simple_functions import *
from .utils.sounds import *
Expand Down
9 changes: 6 additions & 3 deletions manim/mobject/coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ def __init__(
self.y_length = y_length
self.num_sampled_graph_points_per_tick = 10

def coords_to_point(self, *coords):
def coords_to_point(self, *coords: Union[float, int, Sequence[float]]):
raise NotImplementedError()

def point_to_coords(self, point):
raise NotImplementedError()

def c2p(self, *coords):
def c2p(self, *coords: Union[float, int, Sequence[float]]):
"""Abbreviation for coords_to_point"""
return self.coords_to_point(*coords)

Expand Down Expand Up @@ -1080,7 +1080,9 @@ def create_axis(
axis.shift(-axis.number_to_point(self.origin_shift(range_terms)))
return axis

def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray:
def coords_to_point(
self, *coords: Union[float, int, Sequence[float]]
) -> np.ndarray:
"""Transforms the vector formed from ``coords`` formed by the :class:`Axes`
into the corresponding vector with respect to the default basis.

Expand All @@ -1090,6 +1092,7 @@ def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray:
A point that results from a change of basis from the coordinate system
defined by the :class:`Axes` to that of ``manim``'s default coordinate system
"""
coords = np.array(coords).flatten()
origin = self.x_axis.number_to_point(self.origin_shift(self.x_range))
result = np.array(origin)
for axis, coord in zip(self.get_axes(), coords):
Expand Down
6 changes: 4 additions & 2 deletions manim/mobject/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def construct(self):
from ..mobject.types.vectorized_mobject import DashedVMobject, VGroup, VMobject
from ..utils.color import *
from ..utils.iterables import adjacent_n_tuples, adjacent_pairs
from ..utils.position import Position
from ..utils.simple_functions import fdiv
from ..utils.space_ops import (
angle_between_vectors,
Expand Down Expand Up @@ -283,7 +284,7 @@ def __init__(
start_angle=0,
angle=TAU / 4,
num_components=9,
arc_center=ORIGIN,
arc_center=Position(),
**kwargs,
):
if radius is None: # apparently None is passed by ArcBetweenPoints
Expand Down Expand Up @@ -614,13 +615,14 @@ def construct(self):

def __init__(
self,
point=ORIGIN,
point=Position(ORIGIN),
radius: float = DEFAULT_DOT_RADIUS,
stroke_width=0,
fill_opacity=1.0,
color=WHITE,
**kwargs,
):
point = Position(point)
super().__init__(
arc_center=point,
radius=radius,
Expand Down
9 changes: 6 additions & 3 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ..utils.exceptions import MultiAnimationOverrideException
from ..utils.iterables import list_update, remove_list_redundancies
from ..utils.paths import straight_path
from ..utils.position import Position
from ..utils.simple_functions import get_parameters
from ..utils.space_ops import (
angle_between_vectors,
Expand Down Expand Up @@ -1089,7 +1090,7 @@ def apply_to_family(self, func: Callable[["Mobject"], None]) -> "Mobject":
for mob in self.family_members_with_points():
func(mob)

def shift(self, *vectors: np.ndarray) -> "Mobject":
def shift(self, *vectors: Position) -> "Mobject":
"""Shift by the given vectors.

Parameters
Expand All @@ -1106,7 +1107,9 @@ def shift(self, *vectors: np.ndarray) -> "Mobject":
--------
:meth:`move_to`
"""

vectors = [
Position(pos) for pos in vectors
] # todo remove - only temp for testing
if config.renderer == "opengl":
self.apply_points_function(
lambda points: points + vectors[0],
Expand All @@ -1118,7 +1121,7 @@ def shift(self, *vectors: np.ndarray) -> "Mobject":
total_vector = reduce(op.add, vectors)
for mob in self.family_members_with_points():
mob.points = mob.points.astype("float")
mob.points += total_vector
mob.points += total_vector()
if hasattr(mob, "data") and "points" in mob.data:
mob.data["points"] += total_vector
return self
Expand Down
76 changes: 76 additions & 0 deletions manim/utils/position.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Sequence, Union

import numpy as np

__all__ = [
"Position",
]


class Position:
MAX_DIMS = 3

def __init__(
self,
*args: Union[float, int, Sequence[float], "Position"],
default_val=0,
dtype=np.float64
):
if args is None or len(args) == 0:
self.pos = np.array([default_val for _ in range(Position.MAX_DIMS)])
elif isinstance(args[0], Position):
self.pos = args[0]()
else:
self.pos: np.ndarray = np.array(args, dtype=dtype).flatten()[
: Position.MAX_DIMS
]
self.default_val = default_val
self.add_padding()

def add_padding(self, padding=3):
zeros_to_pad = padding - len(self.pos)
if zeros_to_pad <= 0:
return
self.pos = np.append(self.pos, [self.default_val] * zeros_to_pad)

@property
def x(self) -> float:
return self.pos[0]

@property
def y(self) -> float:
return self.pos[1]

@property
def z(self) -> float:
return self.pos[2]

def as_numpy(self):
return self.pos

def as_list(self) -> list:
return [i.item() for i in self.pos]

def as_tuple(self) -> tuple:
return tuple(i.item() for i in self.pos)

def __call__(self, *args, **kwargs) -> np.ndarray:
return self.as_numpy()

def __add__(self, other):
return Position(self.pos + other.pos)

def __sub__(self, other):
return Position(self.pos - other.pos)

def __mul__(self, other):
return Position(self.pos * other.pos)

def __truediv__(self, other):
return Position(self.pos / other.pos)

def __floordiv__(self, other):
return Position(self.pos // other.pos)

def __mod__(self, other):
return Position(self.pos % other.pos)
10 changes: 9 additions & 1 deletion tests/test_vectorized_mobject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from manim import Circle, Line, Mobject, Square, VDict, VGroup, VMobject
from manim import Circle, Dot, Line, Mobject, Square, VDict, VGroup, VMobject


def test_vmobject_point_from_propotion():
Expand Down Expand Up @@ -204,3 +204,11 @@ def test_vgroup_item_assignment_only_allows_vmobjects():
vgroup = VGroup(VMobject())
with pytest.raises(TypeError, match="All submobjects must be of type VMobject"):
vgroup[0] = "invalid object"


def test_dot_allows_position_omission():
dot = Dot([1])
assert len(dot.arc_center()) == 3

dot = Dot([1, 2])
assert len(dot.arc_center()) == 3