Skip to content

Commit 1edcd99

Browse files
Speed up width/height/depth by reducing copying (#3180)
* Speed up width/height/depth by reducing copying * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test * fix example and improve tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo * np.max/min is 2x slower than max/min --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1e8b349 commit 1edcd99

File tree

2 files changed

+145
-15
lines changed

2 files changed

+145
-15
lines changed

manim/mobject/mobject.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,16 +1842,29 @@ def restore(self):
18421842
self.become(self.saved_state)
18431843
return self
18441844

1845-
##
1846-
1847-
def reduce_across_dimension(self, points_func, reduce_func, dim):
1848-
points = self.get_all_points()
1849-
if points is None or len(points) == 0:
1850-
# Note, this default means things like empty VGroups
1851-
# will appear to have a center at [0, 0, 0]
1845+
def reduce_across_dimension(self, reduce_func, dim: int) -> float:
1846+
"""Find the min or max value from a dimension across all points in this and submobjects."""
1847+
assert dim >= 0 and dim <= 2
1848+
if len(self.submobjects) == 0 and len(self.points) == 0:
1849+
# If we have no points and no submobjects, return 0 (e.g. center)
18521850
return 0
1853-
values = points_func(points[:, dim])
1854-
return reduce_func(values)
1851+
1852+
# If we do not have points (but do have submobjects)
1853+
# use only the points from those.
1854+
if len(self.points) == 0:
1855+
rv = None
1856+
else:
1857+
# Otherwise, be sure to include our own points
1858+
rv = reduce_func(self.points[:, dim])
1859+
# Recursively ask submobjects (if any) for the biggest/
1860+
# smallest dimension they have and compare it to the return value.
1861+
for mobj in self.submobjects:
1862+
value = mobj.reduce_across_dimension(reduce_func, dim)
1863+
if rv is None:
1864+
rv = value
1865+
else:
1866+
rv = reduce_func([value, rv])
1867+
return rv
18551868

18561869
def nonempty_submobjects(self):
18571870
return [
@@ -1860,13 +1873,23 @@ def nonempty_submobjects(self):
18601873
if len(submob.submobjects) != 0 or len(submob.points) != 0
18611874
]
18621875

1863-
def get_merged_array(self, array_attr):
1876+
def get_merged_array(self, array_attr) -> np.ndarray:
1877+
"""Return all of a given attribute from this mobject and all submobjects.
1878+
1879+
May contain duplicates; the order is in a depth-first (pre-order)
1880+
traversal of the submobjects.
1881+
"""
18641882
result = getattr(self, array_attr)
18651883
for submob in self.submobjects:
18661884
result = np.append(result, submob.get_merged_array(array_attr), axis=0)
18671885
return result
18681886

1869-
def get_all_points(self):
1887+
def get_all_points(self) -> np.ndarray:
1888+
"""Return all points from this mobject and all submobjects.
1889+
1890+
May contain duplicates; the order is in a depth-first (pre-order)
1891+
traversal of the submobjects.
1892+
"""
18701893
return self.get_merged_array("points")
18711894

18721895
# Getters
@@ -1987,10 +2010,9 @@ def get_nadir(self) -> np.ndarray:
19872010
def length_over_dim(self, dim):
19882011
"""Measure the length of an :class:`~.Mobject` in a certain direction."""
19892012
return self.reduce_across_dimension(
1990-
np.max,
1991-
np.max,
2013+
max,
19922014
dim,
1993-
) - self.reduce_across_dimension(np.min, np.min, dim)
2015+
) - self.reduce_across_dimension(min, dim)
19942016

19952017
def get_coord(self, dim, direction=ORIGIN):
19962018
"""Meant to generalize ``get_x``, ``get_y`` and ``get_z``"""

tests/module/mobject/mobject/test_mobject.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
import numpy as np
34
import pytest
45

5-
from manim import Mobject
6+
from manim import DL, UR, Circle, Mobject, Rectangle, Square, VGroup
67

78

89
def test_mobject_add():
@@ -49,3 +50,110 @@ def test_mobject_remove():
4950
assert len(obj.submobjects) == 10
5051

5152
assert obj.remove(Mobject()) is obj
53+
54+
55+
def test_mobject_dimensions_single_mobject():
56+
# A Mobject with no points and no submobjects has no dimensions
57+
empty = Mobject()
58+
assert empty.width == 0
59+
assert empty.height == 0
60+
assert empty.depth == 0
61+
62+
has_points = Mobject()
63+
has_points.points = np.array([[-1, -2, -3], [1, 3, 5]])
64+
assert has_points.width == 2
65+
assert has_points.height == 5
66+
assert has_points.depth == 8
67+
68+
rect = Rectangle(width=3, height=5)
69+
70+
assert rect.width == 3
71+
assert rect.height == 5
72+
assert rect.depth == 0
73+
74+
# Dimensions should be recalculated after scaling
75+
rect.scale(2.0)
76+
assert rect.width == 6
77+
assert rect.height == 10
78+
assert rect.depth == 0
79+
80+
# Dimensions should not be dependent on location
81+
rect.move_to([-3, -4, -5])
82+
assert rect.width == 6
83+
assert rect.height == 10
84+
assert rect.depth == 0
85+
86+
circ = Circle(radius=2)
87+
88+
assert circ.width == 4
89+
assert circ.height == 4
90+
assert circ.depth == 0
91+
92+
93+
def is_close(x, y):
94+
return abs(x - y) < 0.00001
95+
96+
97+
def test_mobject_dimensions_nested_mobjects():
98+
vg = VGroup()
99+
100+
for x in range(-5, 8, 1):
101+
row = VGroup()
102+
vg += row
103+
for y in range(-17, 2, 1):
104+
for z in range(0, 10, 1):
105+
s = Square().move_to([x, y, z / 10])
106+
row += s
107+
108+
assert vg.width == 14.0, vg.width
109+
assert vg.height == 20.0, vg.height
110+
assert is_close(vg.depth, 0.9), vg.depth
111+
112+
# Dimensions should be recalculated after scaling
113+
vg.scale(0.5)
114+
assert vg.width == 7.0, vg.width
115+
assert vg.height == 10.0, vg.height
116+
assert is_close(vg.depth, 0.45), vg.depth
117+
118+
# Adding a mobject changes the bounds/dimensions
119+
rect = Rectangle(width=3, height=5)
120+
rect.move_to([9, 3, 1])
121+
vg += rect
122+
assert vg.width == 13.0, vg.width
123+
assert is_close(vg.height, 18.5), vg.height
124+
assert is_close(vg.depth, 0.775), vg.depth
125+
126+
127+
def test_mobject_dimensions_mobjects_with_no_points_are_at_origin():
128+
rect = Rectangle(width=2, height=3)
129+
rect.move_to([-4, -5, 0])
130+
outer_group = VGroup(rect)
131+
132+
# This is as one would expect
133+
assert outer_group.width == 2
134+
assert outer_group.height == 3
135+
136+
# Adding a mobject with no points has a quirk of adding a "point"
137+
# to [0, 0, 0] (the origin). This changes the size of the outer
138+
# group because now the bottom left corner is at [-5, -6.5, 0]
139+
# but the upper right corner is [0, 0, 0] instead of [-3, -3.5, 0]
140+
outer_group.add(VGroup())
141+
assert outer_group.width == 5
142+
assert outer_group.height == 6.5
143+
144+
145+
def test_mobject_dimensions_has_points_and_children():
146+
outer_rect = Rectangle(width=3, height=6)
147+
inner_rect = Rectangle(width=2, height=1)
148+
inner_rect.align_to(outer_rect.get_corner(UR), DL)
149+
outer_rect.add(inner_rect)
150+
151+
# The width of a mobject should depend both on its points and
152+
# the points of all children mobjects.
153+
assert outer_rect.width == 5 # 3 from outer_rect, 2 from inner_rect
154+
assert outer_rect.height == 7 # 6 from outer_rect, 1 from inner_rect
155+
assert outer_rect.depth == 0
156+
157+
assert inner_rect.width == 2
158+
assert inner_rect.height == 1
159+
assert inner_rect.depth == 0

0 commit comments

Comments
 (0)