Skip to content

Commit 4402ef9

Browse files
friedkeenanbehackljsonvillanueva
authored
Fix :meth:~.VMobject.point_from_proportion to account for the length of curves. (#1274)
* Fix :meth:`~.VMobject.point_from_proportion` to account for the length of curves. * Add test for VMobject.point_from_proportion Co-authored-by: Benjamin Hackl <[email protected]> Co-authored-by: Jason Villanueva <[email protected]>
1 parent 3af4c86 commit 4402ef9

File tree

3 files changed

+175
-8
lines changed

3 files changed

+175
-8
lines changed

manim/mobject/types/opengl_vectorized_mobject.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,88 @@ def get_nth_curve_points(self, n):
549549
def get_nth_curve_function(self, n):
550550
return bezier(self.get_nth_curve_points(n))
551551

552+
def get_nth_curve_function_with_length(
553+
self, n: int, n_sample_points: int = 10
554+
) -> typing.Tuple[typing.Callable[[float], np.ndarray], float]:
555+
"""Returns the expression of the nth curve along with its (approximate) length.
556+
557+
Parameters
558+
----------
559+
n
560+
The index of the desired curve.
561+
n_sample_points
562+
The number of points to sample to find the length.
563+
564+
Returns
565+
-------
566+
curve : typing.Callable[[float], np.ndarray]
567+
The function for the nth curve.
568+
length : :class:`float`
569+
The length of the nth curve.
570+
"""
571+
572+
curve = self.get_nth_curve_function(n)
573+
574+
points = np.array([curve(a) for a in np.linspace(0, 1, n_sample_points)])
575+
diffs = points[1:] - points[:-1]
576+
norms = np.apply_along_axis(get_norm, 1, diffs)
577+
578+
length = np.sum(norms)
579+
580+
return curve, length
581+
552582
def get_num_curves(self):
553583
return self.get_num_points() // self.n_points_per_curve
554584

555-
def point_from_proportion(self, alpha):
585+
def get_curve_functions(
586+
self,
587+
) -> typing.Iterable[typing.Callable[[float], np.ndarray]]:
588+
"""Gets the functions for the curves of the mobject.
589+
590+
Returns
591+
-------
592+
typing.Iterable[typing.Callable[[float], np.ndarray]]
593+
The functions for the curves.
594+
"""
595+
556596
num_curves = self.get_num_curves()
557-
n, residue = integer_interpolate(0, num_curves, alpha)
558-
curve_func = self.get_nth_curve_function(n)
559-
return curve_func(residue)
597+
598+
for n in range(num_curves):
599+
yield self.get_nth_curve_function(n)
600+
601+
def get_curve_functions_with_lengths(
602+
self,
603+
) -> typing.Iterable[typing.Tuple[typing.Callable[[float], np.ndarray], float]]:
604+
"""Gets the functions and lengths of the curves for the mobject.
605+
606+
Returns
607+
-------
608+
typing.Iterable[typing.Tuple[typing.Callable[[float], np.ndarray], float]]
609+
The functions and lengths of the curves.
610+
"""
611+
612+
num_curves = self.get_num_curves()
613+
614+
for n in range(num_curves):
615+
yield self.get_nth_curve_function_with_length(n)
616+
617+
def point_from_proportion(self, alpha):
618+
curves_with_lengths = list(self.get_curve_functions_with_lengths())
619+
620+
total_length = np.sum(length for _, length in curves_with_lengths)
621+
target_length = alpha * total_length
622+
current_length = 0
623+
624+
for curve, length in curves_with_lengths:
625+
if current_length + length >= target_length:
626+
if length != 0:
627+
residue = (target_length - current_length) / length
628+
else:
629+
residue = 0
630+
631+
return curve(residue)
632+
633+
current_length += length
560634

561635
def get_anchors_and_handles(self):
562636
"""

manim/mobject/types/vectorized_mobject.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,36 @@ def get_nth_curve_function(self, n: int) -> typing.Callable[[float], np.ndarray]
895895
"""
896896
return bezier(self.get_nth_curve_points(n))
897897

898+
def get_nth_curve_function_with_length(
899+
self, n: int, n_sample_points: int = 10
900+
) -> typing.Tuple[typing.Callable[[float], np.ndarray], float]:
901+
"""Returns the expression of the nth curve along with its (approximate) length.
902+
903+
Parameters
904+
----------
905+
n
906+
The index of the desired curve.
907+
n_sample_points
908+
The number of points to sample to find the length.
909+
910+
Returns
911+
-------
912+
curve : typing.Callable[[float], np.ndarray]
913+
The function for the nth curve.
914+
length : :class:`float`
915+
The length of the nth curve.
916+
"""
917+
918+
curve = self.get_nth_curve_function(n)
919+
920+
points = np.array([curve(a) for a in np.linspace(0, 1, n_sample_points)])
921+
diffs = points[1:] - points[:-1]
922+
norms = np.apply_along_axis(get_norm, 1, diffs)
923+
924+
length = np.sum(norms)
925+
926+
return curve, length
927+
898928
def get_num_curves(self) -> int:
899929
"""Returns the number of curves of the vmobject.
900930
@@ -906,6 +936,38 @@ def get_num_curves(self) -> int:
906936
nppcc = self.n_points_per_cubic_curve
907937
return len(self.points) // nppcc
908938

939+
def get_curve_functions(
940+
self,
941+
) -> typing.Iterable[typing.Callable[[float], np.ndarray]]:
942+
"""Gets the functions for the curves of the mobject.
943+
944+
Returns
945+
-------
946+
typing.Iterable[typing.Callable[[float], np.ndarray]]
947+
The functions for the curves.
948+
"""
949+
950+
num_curves = self.get_num_curves()
951+
952+
for n in range(num_curves):
953+
yield self.get_nth_curve_function(n)
954+
955+
def get_curve_functions_with_lengths(
956+
self,
957+
) -> typing.Iterable[typing.Tuple[typing.Callable[[float], np.ndarray], float]]:
958+
"""Gets the functions and lengths of the curves for the mobject.
959+
960+
Returns
961+
-------
962+
typing.Iterable[typing.Tuple[typing.Callable[[float], np.ndarray], float]]
963+
The functions and lengths of the curves.
964+
"""
965+
966+
num_curves = self.get_num_curves()
967+
968+
for n in range(num_curves):
969+
yield self.get_nth_curve_function_with_length(n)
970+
909971
def point_from_proportion(self, alpha: float) -> np.ndarray:
910972
"""Get the bezier curve evaluated at a position P,
911973
where P is the point corresponding to the proportion defined by the given alpha.
@@ -920,10 +982,23 @@ def point_from_proportion(self, alpha: float) -> np.ndarray:
920982
np.ndarray
921983
Point evaluated.
922984
"""
923-
num_cubics = self.get_num_curves()
924-
n, residue = integer_interpolate(0, num_cubics, alpha)
925-
curve = self.get_nth_curve_function(n)
926-
return curve(residue)
985+
986+
curves_with_lengths = list(self.get_curve_functions_with_lengths())
987+
988+
total_length = np.sum(length for _, length in curves_with_lengths)
989+
target_length = alpha * total_length
990+
current_length = 0
991+
992+
for curve, length in curves_with_lengths:
993+
if current_length + length >= target_length:
994+
if length != 0:
995+
residue = (target_length - current_length) / length
996+
else:
997+
residue = 0
998+
999+
return curve(residue)
1000+
1001+
current_length += length
9271002

9281003
def get_anchors_and_handles(self) -> typing.Iterable[np.ndarray]:
9291004
"""Returns anchors1, handles1, handles2, anchors2,

tests/test_vectorized_mobject.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
1+
import numpy as np
12
import pytest
23

34
from manim import Line, Mobject, VDict, VGroup, VMobject
45

56

7+
def test_vmobject_point_from_propotion():
8+
obj = VMobject()
9+
10+
# One long line, one short line
11+
obj.set_points_as_corners(
12+
[
13+
np.array([0, 0, 0]),
14+
np.array([4, 0, 0]),
15+
np.array([4, 2, 0]),
16+
]
17+
)
18+
19+
# Total length of 6, so halfway along the object
20+
# would be at length 3, which lands in the first, long line.
21+
assert np.all(obj.point_from_proportion(0.5) == np.array([3, 0, 0]))
22+
23+
624
def test_vgroup_init():
725
"""Test the VGroup instantiation."""
826
VGroup()

0 commit comments

Comments
 (0)