Skip to content

Commit 0a62448

Browse files
author
ryan
committed
Coords to point accept sequence and allow omitted positions
1 parent 0561beb commit 0a62448

File tree

6 files changed

+102
-9
lines changed

6 files changed

+102
-9
lines changed

manim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
from .utils.images import *
102102
from .utils.iterables import *
103103
from .utils.paths import *
104+
from .utils.position import *
104105
from .utils.rate_functions import *
105106
from .utils.simple_functions import *
106107
from .utils.sounds import *

manim/mobject/coordinate_systems.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def __init__(
9696
self.y_length = y_length
9797
self.num_sampled_graph_points_per_tick = 10
9898

99-
def coords_to_point(self, *coords):
99+
def coords_to_point(self, *coords: Union[float, int, Sequence[float]]):
100100
raise NotImplementedError()
101101

102102
def point_to_coords(self, point):
103103
raise NotImplementedError()
104104

105-
def c2p(self, *coords):
105+
def c2p(self, *coords: Union[float, int, Sequence[float]]):
106106
"""Abbreviation for coords_to_point"""
107107
return self.coords_to_point(*coords)
108108

@@ -1080,7 +1080,9 @@ def create_axis(
10801080
axis.shift(-axis.number_to_point(self.origin_shift(range_terms)))
10811081
return axis
10821082

1083-
def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray:
1083+
def coords_to_point(
1084+
self, *coords: Union[float, int, Sequence[float]]
1085+
) -> np.ndarray:
10841086
"""Transforms the vector formed from ``coords`` formed by the :class:`Axes`
10851087
into the corresponding vector with respect to the default basis.
10861088
@@ -1090,6 +1092,7 @@ def coords_to_point(self, *coords: Sequence[float]) -> np.ndarray:
10901092
A point that results from a change of basis from the coordinate system
10911093
defined by the :class:`Axes` to that of ``manim``'s default coordinate system
10921094
"""
1095+
coords = np.array(coords).flatten()
10931096
origin = self.x_axis.number_to_point(self.origin_shift(self.x_range))
10941097
result = np.array(origin)
10951098
for axis, coord in zip(self.get_axes(), coords):

manim/mobject/geometry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def construct(self):
7676
from ..mobject.types.vectorized_mobject import DashedVMobject, VGroup, VMobject
7777
from ..utils.color import *
7878
from ..utils.iterables import adjacent_n_tuples, adjacent_pairs
79+
from ..utils.position import Position
7980
from ..utils.simple_functions import fdiv
8081
from ..utils.space_ops import (
8182
angle_between_vectors,
@@ -283,7 +284,7 @@ def __init__(
283284
start_angle=0,
284285
angle=TAU / 4,
285286
num_components=9,
286-
arc_center=ORIGIN,
287+
arc_center=Position(),
287288
**kwargs,
288289
):
289290
if radius is None: # apparently None is passed by ArcBetweenPoints
@@ -614,13 +615,14 @@ def construct(self):
614615

615616
def __init__(
616617
self,
617-
point=ORIGIN,
618+
point=Position(ORIGIN),
618619
radius: float = DEFAULT_DOT_RADIUS,
619620
stroke_width=0,
620621
fill_opacity=1.0,
621622
color=WHITE,
622623
**kwargs,
623624
):
625+
point = Position(point)
624626
super().__init__(
625627
arc_center=point,
626628
radius=radius,

manim/mobject/mobject.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..utils.exceptions import MultiAnimationOverrideException
4646
from ..utils.iterables import list_update, remove_list_redundancies
4747
from ..utils.paths import straight_path
48+
from ..utils.position import Position
4849
from ..utils.simple_functions import get_parameters
4950
from ..utils.space_ops import (
5051
angle_between_vectors,
@@ -1092,7 +1093,7 @@ def apply_to_family(self, func: Callable[["Mobject"], None]) -> "Mobject":
10921093
for mob in self.family_members_with_points():
10931094
func(mob)
10941095

1095-
def shift(self, *vectors: np.ndarray) -> "Mobject":
1096+
def shift(self, *vectors: Position) -> "Mobject":
10961097
"""Shift by the given vectors.
10971098
10981099
Parameters
@@ -1109,7 +1110,9 @@ def shift(self, *vectors: np.ndarray) -> "Mobject":
11091110
--------
11101111
:meth:`move_to`
11111112
"""
1112-
1113+
vectors = [
1114+
Position(pos) for pos in vectors
1115+
] # todo remove - only temp for testing
11131116
if config.renderer == "opengl":
11141117
self.apply_points_function(
11151118
lambda points: points + vectors[0],
@@ -1121,7 +1124,7 @@ def shift(self, *vectors: np.ndarray) -> "Mobject":
11211124
total_vector = reduce(op.add, vectors)
11221125
for mob in self.family_members_with_points():
11231126
mob.points = mob.points.astype("float")
1124-
mob.points += total_vector
1127+
mob.points += total_vector()
11251128
if hasattr(mob, "data") and "points" in mob.data:
11261129
mob.data["points"] += total_vector
11271130
return self

manim/utils/position.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Sequence, Union
2+
3+
import numpy as np
4+
5+
__all__ = [
6+
"Position",
7+
]
8+
9+
10+
class Position:
11+
MAX_DIMS = 3
12+
13+
def __init__(
14+
self,
15+
*args: Union[float, int, Sequence[float], "Position"],
16+
default_val=0,
17+
dtype=np.float64
18+
):
19+
if args is None or len(args) == 0:
20+
self.pos = np.array([default_val for _ in range(Position.MAX_DIMS)])
21+
elif isinstance(args[0], Position):
22+
self.pos = args[0]()
23+
else:
24+
self.pos: np.ndarray = np.array(args, dtype=dtype).flatten()[
25+
: Position.MAX_DIMS
26+
]
27+
self.default_val = default_val
28+
self.add_padding()
29+
30+
def add_padding(self, padding=3):
31+
zeros_to_pad = padding - len(self.pos)
32+
if zeros_to_pad <= 0:
33+
return
34+
self.pos = np.append(self.pos, [self.default_val] * zeros_to_pad)
35+
36+
@property
37+
def x(self) -> float:
38+
return self.pos[0]
39+
40+
@property
41+
def y(self) -> float:
42+
return self.pos[1]
43+
44+
@property
45+
def z(self) -> float:
46+
return self.pos[2]
47+
48+
def as_numpy(self):
49+
return self.pos
50+
51+
def as_list(self) -> list:
52+
return [i.item() for i in self.pos]
53+
54+
def as_tuple(self) -> tuple:
55+
return tuple(i.item() for i in self.pos)
56+
57+
def __call__(self, *args, **kwargs) -> np.ndarray:
58+
return self.as_numpy()
59+
60+
def __add__(self, other):
61+
return Position(self.pos + other.pos)
62+
63+
def __sub__(self, other):
64+
return Position(self.pos - other.pos)
65+
66+
def __mul__(self, other):
67+
return Position(self.pos * other.pos)
68+
69+
def __truediv__(self, other):
70+
return Position(self.pos / other.pos)
71+
72+
def __floordiv__(self, other):
73+
return Position(self.pos // other.pos)
74+
75+
def __mod__(self, other):
76+
return Position(self.pos % other.pos)

tests/test_vectorized_mobject.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from manim import Circle, Line, Mobject, Square, VDict, VGroup, VMobject
4+
from manim import Circle, Dot, Line, Mobject, Square, VDict, VGroup, VMobject
55

66

77
def test_vmobject_point_from_propotion():
@@ -204,3 +204,11 @@ def test_vgroup_item_assignment_only_allows_vmobjects():
204204
vgroup = VGroup(VMobject())
205205
with pytest.raises(TypeError, match="All submobjects must be of type VMobject"):
206206
vgroup[0] = "invalid object"
207+
208+
209+
def test_dot_allows_position_omission():
210+
dot = Dot([1])
211+
assert len(dot.arc_center()) == 3
212+
213+
dot = Dot([1, 2])
214+
assert len(dot.arc_center()) == 3

0 commit comments

Comments
 (0)