Skip to content

Commit 3cd2a27

Browse files
authored
Allow non-numerical values to be added to :class:~.NumberLine/:class:~.Axes (#1780)
* add add_labels for number_line * fix bug in add_numbers * change to self * drop add_coordinate_labels * add somed docs * improve docs * forgot :class: * fix typing * remove redundancies * typing error * dict --> Dict * introduce scaling in `add_labels` quite annoying to manually adjust size (especially without `font_size`) * added tests * don't use Text in testing
1 parent 3ac72fe commit 3cd2a27

File tree

5 files changed

+96
-5
lines changed

5 files changed

+96
-5
lines changed

manim/mobject/coordinate_systems.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import fractions as fr
1414
import numbers
15-
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
15+
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
1616

1717
import numpy as np
1818
from colour import Color
@@ -276,7 +276,11 @@ def get_axis_labels(
276276
return self.axis_labels
277277

278278
def add_coordinates(
279-
self, *axes_numbers: Optional[Iterable[float]], **kwargs
279+
self,
280+
*axes_numbers: Union[
281+
Optional[Iterable[float]], Union[Dict[float, Union[str, float, "Mobject"]]]
282+
],
283+
**kwargs,
280284
) -> VGroup:
281285
"""Adds labels to the axes.
282286
@@ -294,6 +298,17 @@ def add_coordinates(
294298
ax.add_coordinates(x_labels, None, z_labels) # default y labels, custom x & z labels
295299
ax.add_coordinates(x_labels) # only x labels
296300
301+
..code-block:: python
302+
303+
# specifically control the position and value of the labels using a dict
304+
ax = Axes(x_range=[0, 7])
305+
x_pos = [x for x in range(1, 8)]
306+
307+
# strings are automatically converted into a `Tex` mobject.
308+
x_vals = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
309+
x_dict = dict(zip(x_pos, x_vals))
310+
ax.add_coordinates(x_dict)
311+
297312
Returns
298313
-------
299314
VGroup
@@ -306,10 +321,13 @@ def add_coordinates(
306321
axes_numbers = [None for _ in range(self.dimension)]
307322

308323
for axis, values in zip(self.axes, axes_numbers):
309-
labels = axis.add_numbers(values, **kwargs)
324+
if isinstance(values, dict):
325+
labels = axis.add_labels(values, **kwargs)
326+
else:
327+
labels = axis.add_numbers(values, **kwargs)
310328
self.coordinate_labels.add(labels)
311329

312-
return self.coordinate_labels
330+
return self
313331

314332
def get_line_from_axis_to_point(
315333
self,

manim/mobject/number_line.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
__all__ = ["NumberLine", "UnitInterval", "NumberLineOld"]
44

55
import operator as op
6+
from typing import TYPE_CHECKING, Dict, Union
67

78
import numpy as np
89

10+
from manim.mobject.svg.tex_mobject import MathTex, Tex
11+
912
from .. import config
1013
from ..constants import *
1114
from ..mobject.geometry import Line
@@ -18,6 +21,9 @@
1821
from ..utils.simple_functions import fdiv
1922
from ..utils.space_ops import normalize
2023

24+
if TYPE_CHECKING:
25+
from manim.mobject.mobject import Mobject
26+
2127

2228
class NumberLine(Line):
2329
"""Creates a number line with tick marks. Number ranges that include both negative and
@@ -340,7 +346,51 @@ def add_numbers(self, x_values=None, excluding=None, **kwargs):
340346
numbers.add(self.get_number_mobject(x, **kwargs))
341347
self.add(numbers)
342348
self.numbers = numbers
343-
return numbers
349+
return self
350+
351+
def add_labels(
352+
self,
353+
dict_values: Dict[float, Union[str, float, "Mobject"]],
354+
direction=None,
355+
buff=None,
356+
):
357+
"""Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``."""
358+
if direction is None:
359+
direction = self.label_direction
360+
if buff is None:
361+
buff = self.line_to_number_buff
362+
363+
labels = VGroup()
364+
for x, label in dict_values.items():
365+
366+
label = self.create_label_tex(label)
367+
label.scale(self.number_scale_value)
368+
label.next_to(self.number_to_point(x), direction=direction, buff=buff)
369+
labels.add(label)
370+
371+
self.labels = labels
372+
self.add(labels)
373+
return self
374+
375+
@staticmethod
376+
def create_label_tex(label_tex) -> "Mobject":
377+
"""Checks if the label is a ``float``, ``int`` or a ``str`` and creates a :class:`~.MathTex`/:class:`~.Tex` label accordingly.
378+
379+
Parameters
380+
----------
381+
label_tex : The label to be compared against the above types.
382+
383+
Returns
384+
-------
385+
:class:`~.Mobject`
386+
The label.
387+
"""
388+
389+
if isinstance(label_tex, float) or isinstance(label_tex, int):
390+
label_tex = MathTex(label_tex)
391+
elif isinstance(label_tex, str):
392+
label_tex = Tex(label_tex)
393+
return label_tex
344394

345395
def decimal_places_from_step(self):
346396
step_as_str = str(self.x_step)
Binary file not shown.

tests/test_graphical_units/test_axes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,13 @@ def test_axes(scene):
1616
)
1717
labels = graph.get_axis_labels()
1818
scene.add(graph, labels)
19+
20+
21+
@frames_comparison
22+
def test_custom_coordinates(scene):
23+
ax = Axes(x_range=[0, 10])
24+
25+
ax.add_coordinates(
26+
dict(zip([x for x in range(1, 10)], [Tex("str") for _ in range(1, 10)]))
27+
)
28+
scene.add(ax)

tests/test_number_line.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from manim import NumberLine
4+
from manim.mobject.numbers import Integer
45

56

67
def test_unit_vector():
@@ -50,3 +51,15 @@ def test_whole_numbers_step_size_default_to_0_decimal_places():
5051
assert actual_decimal_places == expected_decimal_places, (
5152
"Expected 1 decimal place but got " + actual_decimal_places
5253
)
54+
55+
56+
def test_add_labels():
57+
expected_label_length = 6
58+
num_line = NumberLine(x_range=[-4, 4])
59+
num_line.add_labels(
60+
dict(zip([x for x in range(-3, 3)], [Integer(m) for m in range(-1, 5)]))
61+
)
62+
actual_label_length = len(num_line.labels)
63+
assert (
64+
actual_label_length == expected_label_length
65+
), f"Expected a VGroup with {expected_label_length} integers but got {actual_label_length}."

0 commit comments

Comments
 (0)