Skip to content

Commit 398cd26

Browse files
committed
Add type annotations to mobject/graphing/number_line.py
1 parent 8f274f4 commit 398cd26

File tree

3 files changed

+70
-37
lines changed

3 files changed

+70
-37
lines changed

manim/mobject/graphing/number_line.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from typing import TYPE_CHECKING, Callable
1313

1414
if TYPE_CHECKING:
15+
from typing import Any
16+
17+
from typing_extensions import Self
18+
1519
from manim.mobject.geometry.tips import ArrowTip
1620
from manim.typing import Point3DLike
1721

@@ -164,8 +168,8 @@ def __init__(
164168
decimal_number_config: dict | None = None,
165169
numbers_to_exclude: Iterable[float] | None = None,
166170
numbers_to_include: Iterable[float] | None = None,
167-
**kwargs,
168-
):
171+
**kwargs: Any,
172+
) -> None:
169173
# avoid mutable arguments in defaults
170174
if numbers_to_exclude is None:
171175
numbers_to_exclude = []
@@ -189,6 +193,9 @@ def __init__(
189193

190194
# turn into a NumPy array to scale by just applying the function
191195
self.x_range = np.array(x_range, dtype=float)
196+
self.x_min: float
197+
self.x_max: float
198+
self.x_step: float
192199
self.x_min, self.x_max, self.x_step = scaling.function(self.x_range)
193200
self.length = length
194201
self.unit_size = unit_size
@@ -250,7 +257,9 @@ def __init__(
250257
dict(
251258
zip(
252259
tick_range,
253-
self.scaling.get_custom_labels(
260+
# TODO:
261+
# Argument 2 to "zip" has incompatible type "Iterable[Mobject]"; expected "Iterable[str | float | VMobject]" [arg-type]
262+
self.scaling.get_custom_labels( # type: ignore[arg-type]
254263
tick_range,
255264
unit_decimal_places=decimal_number_config[
256265
"num_decimal_places"
@@ -267,21 +276,25 @@ def __init__(
267276
font_size=self.font_size,
268277
)
269278

270-
def rotate_about_zero(self, angle: float, axis: Sequence[float] = OUT, **kwargs):
279+
def rotate_about_zero(
280+
self, angle: float, axis: Sequence[float] = OUT, **kwargs: Any
281+
) -> Self:
271282
return self.rotate_about_number(0, angle, axis, **kwargs)
272283

273284
def rotate_about_number(
274-
self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs
275-
):
285+
self, number: float, angle: float, axis: Sequence[float] = OUT, **kwargs: Any
286+
) -> Self:
276287
return self.rotate(angle, axis, about_point=self.n2p(number), **kwargs)
277288

278-
def add_ticks(self):
289+
def add_ticks(self) -> None:
279290
"""Adds ticks to the number line. Ticks can be accessed after creation
280291
via ``self.ticks``.
281292
"""
282293
ticks = VGroup()
283294
elongated_tick_size = self.tick_size * self.longer_tick_multiple
284-
elongated_tick_offsets = self.numbers_with_elongated_ticks - self.x_min
295+
elongated_tick_offsets = (
296+
np.array(self.numbers_with_elongated_ticks) - self.x_min
297+
)
285298
for x in self.get_tick_range():
286299
size = self.tick_size
287300
if np.any(np.isclose(x - self.x_min, elongated_tick_offsets)):
@@ -413,19 +426,22 @@ def point_to_number(self, point: Sequence[float]) -> float:
413426
point = np.asarray(point)
414427
start, end = self.get_start_and_end()
415428
unit_vect = normalize(end - start)
416-
proportion = np.dot(point - start, unit_vect) / np.dot(end - start, unit_vect)
429+
proportion: float = np.dot(point - start, unit_vect) / np.dot(
430+
end - start, unit_vect
431+
)
417432
return interpolate(self.x_min, self.x_max, proportion)
418433

419434
def n2p(self, number: float | np.ndarray) -> np.ndarray:
420435
"""Abbreviation for :meth:`~.NumberLine.number_to_point`."""
421436
return self.number_to_point(number)
422437

423-
def p2n(self, point: Sequence[float]) -> float:
438+
def p2n(self, point: Point3DLike) -> float:
424439
"""Abbreviation for :meth:`~.NumberLine.point_to_number`."""
425440
return self.point_to_number(point)
426441

427442
def get_unit_size(self) -> float:
428-
return self.get_length() / (self.x_range[1] - self.x_range[0])
443+
val: float = self.get_length() / (self.x_range[1] - self.x_range[0])
444+
return val
429445

430446
def get_unit_vector(self) -> np.ndarray:
431447
return super().get_unit_vector() * self.unit_size
@@ -436,8 +452,8 @@ def get_number_mobject(
436452
direction: Sequence[float] | None = None,
437453
buff: float | None = None,
438454
font_size: float | None = None,
439-
label_constructor: VMobject | None = None,
440-
**number_config,
455+
label_constructor: type[MathTex] | None = None,
456+
**number_config: dict[str, Any],
441457
) -> VMobject:
442458
"""Generates a positioned :class:`~.DecimalNumber` mobject
443459
generated according to ``label_constructor``.
@@ -476,7 +492,12 @@ def get_number_mobject(
476492
label_constructor = self.label_constructor
477493

478494
num_mob = DecimalNumber(
479-
x, font_size=font_size, mob_class=label_constructor, **number_config
495+
# TODO:
496+
# error: Argument 4 to "DecimalNumber" has incompatible type "**dict[str, dict[str, Any]]"; expected "int" [arg-type]
497+
x,
498+
font_size=font_size,
499+
mob_class=label_constructor,
500+
**number_config, # type: ignore[arg-type]
480501
)
481502

482503
num_mob.next_to(self.number_to_point(x), direction=direction, buff=buff)
@@ -485,7 +506,7 @@ def get_number_mobject(
485506
num_mob.shift(num_mob[0].width * LEFT / 2)
486507
return num_mob
487508

488-
def get_number_mobjects(self, *numbers, **kwargs) -> VGroup:
509+
def get_number_mobjects(self, *numbers: float, **kwargs: Any) -> VGroup:
489510
if len(numbers) == 0:
490511
numbers = self.default_numbers_to_display()
491512
return VGroup([self.get_number_mobject(number, **kwargs) for number in numbers])
@@ -498,9 +519,9 @@ def add_numbers(
498519
x_values: Iterable[float] | None = None,
499520
excluding: Iterable[float] | None = None,
500521
font_size: float | None = None,
501-
label_constructor: VMobject | None = None,
502-
**kwargs,
503-
):
522+
label_constructor: type[MathTex] | None = None,
523+
**kwargs: Any,
524+
) -> Self:
504525
"""Adds :class:`~.DecimalNumber` mobjects representing their position
505526
at each tick of the number line. The numbers can be accessed after creation
506527
via ``self.numbers``.
@@ -551,11 +572,11 @@ def add_numbers(
551572
def add_labels(
552573
self,
553574
dict_values: dict[float, str | float | VMobject],
554-
direction: Sequence[float] = None,
575+
direction: Sequence[float] | None = None,
555576
buff: float | None = None,
556577
font_size: float | None = None,
557-
label_constructor: VMobject | None = None,
558-
):
578+
label_constructor: type[MathTex] | None = None,
579+
) -> Self:
559580
"""Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``.
560581
The labels can be accessed after creation via ``self.labels``.
561582
@@ -598,6 +619,7 @@ def add_labels(
598619
label = self._create_label_tex(label, label_constructor)
599620

600621
if hasattr(label, "font_size"):
622+
assert isinstance(label, MathTex)
601623
label.font_size = font_size
602624
else:
603625
raise AttributeError(f"{label} is not compatible with add_labels.")
@@ -612,7 +634,7 @@ def _create_label_tex(
612634
self,
613635
label_tex: str | float | VMobject,
614636
label_constructor: Callable | None = None,
615-
**kwargs,
637+
**kwargs: Any,
616638
) -> VMobject:
617639
"""Checks if the label is a :class:`~.VMobject`, otherwise, creates a
618640
label by passing ``label_tex`` to ``label_constructor``.
@@ -633,24 +655,25 @@ def _create_label_tex(
633655
:class:`~.VMobject`
634656
The label.
635657
"""
636-
if label_constructor is None:
637-
label_constructor = self.label_constructor
638658
if isinstance(label_tex, (VMobject, OpenGLVMobject)):
639659
return label_tex
640-
else:
660+
if label_constructor is None:
661+
label_constructor = self.label_constructor
662+
if isinstance(label_tex, str):
641663
return label_constructor(label_tex, **kwargs)
664+
return label_constructor(str(label_tex), **kwargs)
642665

643666
@staticmethod
644-
def _decimal_places_from_step(step) -> int:
645-
step = str(step)
646-
if "." not in step:
667+
def _decimal_places_from_step(step: float) -> int:
668+
step_str = str(step)
669+
if "." not in step_str:
647670
return 0
648-
return len(step.split(".")[-1])
671+
return len(step_str.split(".")[-1])
649672

650-
def __matmul__(self, other: float):
673+
def __matmul__(self, other: float) -> np.ndarray:
651674
return self.n2p(other)
652675

653-
def __rmatmul__(self, other: Point3DLike | Mobject):
676+
def __rmatmul__(self, other: Point3DLike | Mobject) -> float:
654677
if isinstance(other, Mobject):
655678
other = other.get_center()
656679
return self.p2n(other)
@@ -659,11 +682,11 @@ def __rmatmul__(self, other: Point3DLike | Mobject):
659682
class UnitInterval(NumberLine):
660683
def __init__(
661684
self,
662-
unit_size=10,
663-
numbers_with_elongated_ticks=None,
664-
decimal_number_config=None,
665-
**kwargs,
666-
):
685+
unit_size: float = 10,
686+
numbers_with_elongated_ticks: list[float] | None = None,
687+
decimal_number_config: dict[str, Any] | None = None,
688+
**kwargs: Any,
689+
) -> None:
667690
numbers_with_elongated_ticks = (
668691
[0, 1]
669692
if numbers_with_elongated_ticks is None

manim/mobject/graphing/scale.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from manim.mobject.text.numbers import Integer
1212

1313
if TYPE_CHECKING:
14-
from typing import Callable
14+
from typing import Callable, overload
1515

1616
from manim.mobject.mobject import Mobject
1717

@@ -28,6 +28,12 @@ class _ScaleBase:
2828
def __init__(self, custom_labels: bool = False):
2929
self.custom_labels = custom_labels
3030

31+
@overload
32+
def function(self, value: float) -> float: ...
33+
34+
@overload
35+
def function(self, value: np.array) -> np.array: ...
36+
3137
def function(self, value: float) -> float:
3238
"""The function that will be used to scale the values.
3339
@@ -61,6 +67,7 @@ def inverse_function(self, value: float) -> float:
6167
def get_custom_labels(
6268
self,
6369
val_range: Iterable[float],
70+
**kw_args: Any,
6471
) -> Iterable[Mobject]:
6572
"""Custom instructions for generating labels along an axis.
6673

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ ignore_errors = False
7373
[mypy-manim.mobject.graphing.scale.*]
7474
ignore_errors = False
7575

76+
[mypy-manim.mobject.graphing.number_line.*]
77+
ignore_errors = False
78+
7679
[mypy-manim.mobject.graphing.probability.*]
7780
ignore_errors = False
7881

0 commit comments

Comments
 (0)