Skip to content

Commit 90ae6ad

Browse files
JasonGrace2282pre-commit-ci[bot]Viicoschopan050
authored
Add @ shorthand for CoordinateSystem methods coords_to_point (c2p) and point_to_coords (p2c) (#3754)
* Add shorthand for axes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add spacing Co-authored-by: Victorien <[email protected]> * Convert CoordinateSystem example, and add to NumberLine * Add doctest for NumberLine * Add test * Fix typehint for c2p Co-authored-by: Victorien <[email protected]> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Victorien <[email protected]> Co-authored-by: Francisco Manríquez Novoa <[email protected]>
1 parent 938b8fc commit 90ae6ad

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

manim/mobject/graphing/coordinate_systems.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from manim.mobject.graphing.functions import ImplicitFunction, ParametricFunction
2828
from manim.mobject.graphing.number_line import NumberLine
2929
from manim.mobject.graphing.scale import LinearBase
30+
from manim.mobject.mobject import Mobject
3031
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
3132
from manim.mobject.opengl.opengl_surface import OpenGLSurface
3233
from manim.mobject.text.tex_mobject import MathTex
@@ -96,10 +97,10 @@ def construct(self):
9697
)
9798
9899
# Extra lines and labels for point (1,1)
99-
graphs += grid.get_horizontal_line(grid.c2p(1, 1, 0), color=BLUE)
100-
graphs += grid.get_vertical_line(grid.c2p(1, 1, 0), color=BLUE)
101-
graphs += Dot(point=grid.c2p(1, 1, 0), color=YELLOW)
102-
graphs += Tex("(1,1)").scale(0.75).next_to(grid.c2p(1, 1, 0))
100+
graphs += grid.get_horizontal_line(grid @ (1, 1, 0), color=BLUE)
101+
graphs += grid.get_vertical_line(grid @ (1, 1, 0), color=BLUE)
102+
graphs += Dot(point=grid @ (1, 1, 0), color=YELLOW)
103+
graphs += Tex("(1,1)").scale(0.75).next_to(grid @ (1, 1, 0))
103104
title = Title(
104105
# spaces between braces to prevent SyntaxError
105106
r"Graphs of $y=x^{ {1}\over{n} }$ and $y=x^n (n=1,2,3,...,20)$",
@@ -145,7 +146,7 @@ def __init__(
145146
self.y_length = y_length
146147
self.num_sampled_graph_points_per_tick = 10
147148

148-
def coords_to_point(self, *coords: Sequence[ManimFloat]):
149+
def coords_to_point(self, *coords: ManimFloat):
149150
raise NotImplementedError()
150151

151152
def point_to_coords(self, point: Point3D):
@@ -570,7 +571,7 @@ def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line:
570571
class GetHorizontalLineExample(Scene):
571572
def construct(self):
572573
ax = Axes().add_coordinates()
573-
point = ax.c2p(-4, 1.5)
574+
point = ax @ (-4, 1.5)
574575
575576
dot = Dot(point)
576577
line = ax.get_horizontal_line(point, line_func=Line)
@@ -1790,6 +1791,14 @@ def construct(self):
17901791

17911792
return T_label_group
17921793

1794+
def __matmul__(self, coord: Point3D | Mobject):
1795+
if isinstance(coord, Mobject):
1796+
coord = coord.get_center()
1797+
return self.coords_to_point(*coord)
1798+
1799+
def __rmatmul__(self, point: Point3D):
1800+
return self.point_to_coords(point)
1801+
17931802

17941803
class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL):
17951804
"""Creates a set of axes.
@@ -1990,6 +1999,7 @@ def coords_to_point(
19901999
self, *coords: float | Sequence[float] | Sequence[Sequence[float]] | np.ndarray
19912000
) -> np.ndarray:
19922001
"""Accepts coordinates from the axes and returns a point with respect to the scene.
2002+
Equivalent to `ax @ (coord1)`
19932003
19942004
Parameters
19952005
----------
@@ -2018,6 +2028,8 @@ def coords_to_point(
20182028
>>> ax = Axes()
20192029
>>> np.around(ax.coords_to_point(1, 0, 0), 2)
20202030
array([0.86, 0. , 0. ])
2031+
>>> np.around(ax @ (1, 0, 0), 2)
2032+
array([0.86, 0. , 0. ])
20212033
>>> np.around(ax.coords_to_point([[0, 1], [1, 1], [1, 0]]), 2)
20222034
array([[0. , 0.75, 0. ],
20232035
[0.86, 0.75, 0. ],

manim/mobject/graphing/number_line.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from manim.mobject.mobject import Mobject
56
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
67

78
__all__ = ["NumberLine", "UnitInterval"]
@@ -12,6 +13,7 @@
1213

1314
if TYPE_CHECKING:
1415
from manim.mobject.geometry.tips import ArrowTip
16+
from manim.typing import Point3D
1517

1618
import numpy as np
1719

@@ -344,6 +346,7 @@ def get_tick_range(self) -> np.ndarray:
344346
def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
345347
"""Accepts a value along the number line and returns a point with
346348
respect to the scene.
349+
Equivalent to `NumberLine @ number`
347350
348351
Parameters
349352
----------
@@ -364,6 +367,8 @@ def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
364367
array([0., 0., 0.])
365368
>>> number_line.number_to_point(1)
366369
array([1., 0., 0.])
370+
>>> number_line @ 1
371+
array([1., 0., 0.])
367372
>>> number_line.number_to_point([1, 2, 3])
368373
array([[1., 0., 0.],
369374
[2., 0., 0.],
@@ -642,6 +647,14 @@ def _decimal_places_from_step(step) -> int:
642647
return 0
643648
return len(step.split(".")[-1])
644649

650+
def __matmul__(self, other: float):
651+
return self.n2p(other)
652+
653+
def __rmatmul__(self, other: Point3D | Mobject):
654+
if isinstance(other, Mobject):
655+
other = other.get_center()
656+
return self.p2n(other)
657+
645658

646659
class UnitInterval(NumberLine):
647660
def __init__(

tests/module/mobject/graphing/test_coordinate_system.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane
99
from manim import CoordinateSystem as CS
10-
from manim import NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig
10+
from manim import Dot, NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig
1111

1212

1313
def test_initial_config():
@@ -119,7 +119,15 @@ def test_coords_to_point():
119119

120120
# a point with respect to the axes
121121
c2p_coord = np.around(ax.coords_to_point(2, 2), decimals=4)
122-
np.testing.assert_array_equal(c2p_coord, (1.7143, 1.5, 0))
122+
c2p_coord_matmul = np.around(ax @ (2, 2), decimals=4)
123+
124+
expected = (1.7143, 1.5, 0)
125+
126+
np.testing.assert_array_equal(c2p_coord, expected)
127+
np.testing.assert_array_equal(c2p_coord_matmul, c2p_coord)
128+
129+
mob = Dot().move_to((2, 2, 0))
130+
np.testing.assert_array_equal(np.around(ax @ mob, decimals=4), expected)
123131

124132

125133
def test_coords_to_point_vectorized():

0 commit comments

Comments
 (0)