1212from typing import TYPE_CHECKING , Callable
1313
1414if TYPE_CHECKING :
15+ from typing import Any
16+
17+ from typing_extensions import Self
18+
1519 from manim .mobject .geometry .tips import ArrowTip
1620 from manim .typing import Point3DLike
1721
@@ -164,8 +168,8 @@ def __init__(
164168 decimal_number_config : dict | None = None ,
165169 numbers_to_exclude : Iterable [float ] | None = None ,
166170 numbers_to_include : Iterable [float ] | None = None ,
167- ** kwargs ,
168- ):
171+ ** kwargs : Any ,
172+ ) -> None :
169173 # avoid mutable arguments in defaults
170174 if numbers_to_exclude is None :
171175 numbers_to_exclude = []
@@ -189,6 +193,9 @@ def __init__(
189193
190194 # turn into a NumPy array to scale by just applying the function
191195 self .x_range = np .array (x_range , dtype = float )
196+ self .x_min : float
197+ self .x_max : float
198+ self .x_step : float
192199 self .x_min , self .x_max , self .x_step = scaling .function (self .x_range )
193200 self .length = length
194201 self .unit_size = unit_size
@@ -250,7 +257,9 @@ def __init__(
250257 dict (
251258 zip (
252259 tick_range ,
253- self .scaling .get_custom_labels (
260+ # TODO:
261+ # Argument 2 to "zip" has incompatible type "Iterable[Mobject]"; expected "Iterable[str | float | VMobject]" [arg-type]
262+ self .scaling .get_custom_labels ( # type: ignore[arg-type]
254263 tick_range ,
255264 unit_decimal_places = decimal_number_config [
256265 "num_decimal_places"
@@ -267,21 +276,25 @@ def __init__(
267276 font_size = self .font_size ,
268277 )
269278
270- def rotate_about_zero (self , angle : float , axis : Sequence [float ] = OUT , ** kwargs ):
279+ def rotate_about_zero (
280+ self , angle : float , axis : Sequence [float ] = OUT , ** kwargs : Any
281+ ) -> Self :
271282 return self .rotate_about_number (0 , angle , axis , ** kwargs )
272283
273284 def rotate_about_number (
274- self , number : float , angle : float , axis : Sequence [float ] = OUT , ** kwargs
275- ):
285+ self , number : float , angle : float , axis : Sequence [float ] = OUT , ** kwargs : Any
286+ ) -> Self :
276287 return self .rotate (angle , axis , about_point = self .n2p (number ), ** kwargs )
277288
278- def add_ticks (self ):
289+ def add_ticks (self ) -> None :
279290 """Adds ticks to the number line. Ticks can be accessed after creation
280291 via ``self.ticks``.
281292 """
282293 ticks = VGroup ()
283294 elongated_tick_size = self .tick_size * self .longer_tick_multiple
284- elongated_tick_offsets = self .numbers_with_elongated_ticks - self .x_min
295+ elongated_tick_offsets = (
296+ np .array (self .numbers_with_elongated_ticks ) - self .x_min
297+ )
285298 for x in self .get_tick_range ():
286299 size = self .tick_size
287300 if np .any (np .isclose (x - self .x_min , elongated_tick_offsets )):
@@ -413,19 +426,22 @@ def point_to_number(self, point: Sequence[float]) -> float:
413426 point = np .asarray (point )
414427 start , end = self .get_start_and_end ()
415428 unit_vect = normalize (end - start )
416- proportion = np .dot (point - start , unit_vect ) / np .dot (end - start , unit_vect )
429+ proportion : float = np .dot (point - start , unit_vect ) / np .dot (
430+ end - start , unit_vect
431+ )
417432 return interpolate (self .x_min , self .x_max , proportion )
418433
419434 def n2p (self , number : float | np .ndarray ) -> np .ndarray :
420435 """Abbreviation for :meth:`~.NumberLine.number_to_point`."""
421436 return self .number_to_point (number )
422437
423- def p2n (self , point : Sequence [ float ] ) -> float :
438+ def p2n (self , point : Point3DLike ) -> float :
424439 """Abbreviation for :meth:`~.NumberLine.point_to_number`."""
425440 return self .point_to_number (point )
426441
427442 def get_unit_size (self ) -> float :
428- return self .get_length () / (self .x_range [1 ] - self .x_range [0 ])
443+ val : float = self .get_length () / (self .x_range [1 ] - self .x_range [0 ])
444+ return val
429445
430446 def get_unit_vector (self ) -> np .ndarray :
431447 return super ().get_unit_vector () * self .unit_size
@@ -436,8 +452,8 @@ def get_number_mobject(
436452 direction : Sequence [float ] | None = None ,
437453 buff : float | None = None ,
438454 font_size : float | None = None ,
439- label_constructor : VMobject | None = None ,
440- ** number_config ,
455+ label_constructor : type [ MathTex ] | None = None ,
456+ ** number_config : dict [ str , Any ] ,
441457 ) -> VMobject :
442458 """Generates a positioned :class:`~.DecimalNumber` mobject
443459 generated according to ``label_constructor``.
@@ -476,7 +492,12 @@ def get_number_mobject(
476492 label_constructor = self .label_constructor
477493
478494 num_mob = DecimalNumber (
479- x , font_size = font_size , mob_class = label_constructor , ** number_config
495+ # TODO:
496+ # error: Argument 4 to "DecimalNumber" has incompatible type "**dict[str, dict[str, Any]]"; expected "int" [arg-type]
497+ x ,
498+ font_size = font_size ,
499+ mob_class = label_constructor ,
500+ ** number_config , # type: ignore[arg-type]
480501 )
481502
482503 num_mob .next_to (self .number_to_point (x ), direction = direction , buff = buff )
@@ -485,7 +506,7 @@ def get_number_mobject(
485506 num_mob .shift (num_mob [0 ].width * LEFT / 2 )
486507 return num_mob
487508
488- def get_number_mobjects (self , * numbers , ** kwargs ) -> VGroup :
509+ def get_number_mobjects (self , * numbers : float , ** kwargs : Any ) -> VGroup :
489510 if len (numbers ) == 0 :
490511 numbers = self .default_numbers_to_display ()
491512 return VGroup ([self .get_number_mobject (number , ** kwargs ) for number in numbers ])
@@ -498,9 +519,9 @@ def add_numbers(
498519 x_values : Iterable [float ] | None = None ,
499520 excluding : Iterable [float ] | None = None ,
500521 font_size : float | None = None ,
501- label_constructor : VMobject | None = None ,
502- ** kwargs ,
503- ):
522+ label_constructor : type [ MathTex ] | None = None ,
523+ ** kwargs : Any ,
524+ ) -> Self :
504525 """Adds :class:`~.DecimalNumber` mobjects representing their position
505526 at each tick of the number line. The numbers can be accessed after creation
506527 via ``self.numbers``.
@@ -551,11 +572,11 @@ def add_numbers(
551572 def add_labels (
552573 self ,
553574 dict_values : dict [float , str | float | VMobject ],
554- direction : Sequence [float ] = None ,
575+ direction : Sequence [float ] | None = None ,
555576 buff : float | None = None ,
556577 font_size : float | None = None ,
557- label_constructor : VMobject | None = None ,
558- ):
578+ label_constructor : type [ MathTex ] | None = None ,
579+ ) -> Self :
559580 """Adds specifically positioned labels to the :class:`~.NumberLine` using a ``dict``.
560581 The labels can be accessed after creation via ``self.labels``.
561582
@@ -598,6 +619,7 @@ def add_labels(
598619 label = self ._create_label_tex (label , label_constructor )
599620
600621 if hasattr (label , "font_size" ):
622+ assert isinstance (label , MathTex )
601623 label .font_size = font_size
602624 else :
603625 raise AttributeError (f"{ label } is not compatible with add_labels." )
@@ -612,7 +634,7 @@ def _create_label_tex(
612634 self ,
613635 label_tex : str | float | VMobject ,
614636 label_constructor : Callable | None = None ,
615- ** kwargs ,
637+ ** kwargs : Any ,
616638 ) -> VMobject :
617639 """Checks if the label is a :class:`~.VMobject`, otherwise, creates a
618640 label by passing ``label_tex`` to ``label_constructor``.
@@ -633,24 +655,25 @@ def _create_label_tex(
633655 :class:`~.VMobject`
634656 The label.
635657 """
636- if label_constructor is None :
637- label_constructor = self .label_constructor
638658 if isinstance (label_tex , (VMobject , OpenGLVMobject )):
639659 return label_tex
640- else :
660+ if label_constructor is None :
661+ label_constructor = self .label_constructor
662+ if isinstance (label_tex , str ):
641663 return label_constructor (label_tex , ** kwargs )
664+ return label_constructor (str (label_tex ), ** kwargs )
642665
643666 @staticmethod
644- def _decimal_places_from_step (step ) -> int :
645- step = str (step )
646- if "." not in step :
667+ def _decimal_places_from_step (step : float ) -> int :
668+ step_str = str (step )
669+ if "." not in step_str :
647670 return 0
648- return len (step .split ("." )[- 1 ])
671+ return len (step_str .split ("." )[- 1 ])
649672
650- def __matmul__ (self , other : float ):
673+ def __matmul__ (self , other : float ) -> np . ndarray :
651674 return self .n2p (other )
652675
653- def __rmatmul__ (self , other : Point3DLike | Mobject ):
676+ def __rmatmul__ (self , other : Point3DLike | Mobject ) -> float :
654677 if isinstance (other , Mobject ):
655678 other = other .get_center ()
656679 return self .p2n (other )
@@ -659,11 +682,11 @@ def __rmatmul__(self, other: Point3DLike | Mobject):
659682class UnitInterval (NumberLine ):
660683 def __init__ (
661684 self ,
662- unit_size = 10 ,
663- numbers_with_elongated_ticks = None ,
664- decimal_number_config = None ,
665- ** kwargs ,
666- ):
685+ unit_size : float = 10 ,
686+ numbers_with_elongated_ticks : list [ float ] | None = None ,
687+ decimal_number_config : dict [ str , Any ] | None = None ,
688+ ** kwargs : Any ,
689+ ) -> None :
667690 numbers_with_elongated_ticks = (
668691 [0 , 1 ]
669692 if numbers_with_elongated_ticks is None
0 commit comments