Skip to content

Commit df36f4f

Browse files
Add type annotations to tex_mobject.py (#4355)
* Starting to work on type annotations for tex_mobject.py * More work * Finished. * Code cleanup. * ... * Removed the ignore errors line in mypy for tex_mobject * Fix typing of colors --------- Co-authored-by: Francisco Manríquez Novoa <[email protected]>
1 parent 7ea765a commit df36f4f

File tree

4 files changed

+69
-56
lines changed

4 files changed

+69
-56
lines changed

manim/mobject/svg/svg_mobject.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import svgelements as se
1212

1313
from manim import config, logger
14+
from manim.utils.color import ParsableManimColor
1415

1516
from ...constants import RIGHT
1617
from ...utils.bezier import get_quadratic_approximation_of_cubic
@@ -99,11 +100,11 @@ def __init__(
99100
should_center: bool = True,
100101
height: float | None = 2,
101102
width: float | None = None,
102-
color: str | None = None,
103+
color: ParsableManimColor | None = None,
103104
opacity: float | None = None,
104-
fill_color: str | None = None,
105+
fill_color: ParsableManimColor | None = None,
105106
fill_opacity: float | None = None,
106-
stroke_color: str | None = None,
107+
stroke_color: ParsableManimColor | None = None,
107108
stroke_opacity: float | None = None,
108109
stroke_width: float | None = None,
109110
svg_default: dict | None = None,

manim/mobject/text/tex_mobject.py

Lines changed: 64 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
import itertools as it
2727
import operator as op
2828
import re
29-
from collections.abc import Iterable
29+
from collections.abc import Iterable, Sequence
3030
from functools import reduce
3131
from textwrap import dedent
3232
from typing import Any
3333

34+
from typing_extensions import Self
35+
3436
from manim import config, logger
3537
from manim.constants import *
3638
from manim.mobject.geometry.line import Line
@@ -39,8 +41,6 @@
3941
from manim.utils.tex import TexTemplate
4042
from manim.utils.tex_file_writing import tex_to_svg_file
4143

42-
tex_string_to_mob_map = {}
43-
4444

4545
class SingleStringMathTex(SVGMobject):
4646
"""Elementary building block for rendering text with LaTeX.
@@ -74,9 +74,8 @@ def __init__(
7474
self.tex_environment = tex_environment
7575
if tex_template is None:
7676
tex_template = config["tex_template"]
77-
self.tex_template = tex_template
77+
self.tex_template: TexTemplate = tex_template
7878

79-
assert isinstance(tex_string, str)
8079
self.tex_string = tex_string
8180
file_name = tex_to_svg_file(
8281
self._get_modified_expression(tex_string),
@@ -106,7 +105,7 @@ def __init__(
106105
if self.organize_left_to_right:
107106
self._organize_submobjects_left_to_right()
108107

109-
def __repr__(self):
108+
def __repr__(self) -> str:
110109
return f"{type(self).__name__}({repr(self.tex_string)})"
111110

112111
@property
@@ -115,7 +114,7 @@ def font_size(self) -> float:
115114
return self.height / self.initial_height / SCALE_FACTOR_PER_FONT_POINT
116115

117116
@font_size.setter
118-
def font_size(self, font_val: float):
117+
def font_size(self, font_val: float) -> None:
119118
if font_val <= 0:
120119
raise ValueError("font_size must be greater than 0.")
121120
elif self.height > 0:
@@ -126,13 +125,13 @@ def font_size(self, font_val: float):
126125
# font_size does not depend on current size.
127126
self.scale(font_val / self.font_size)
128127

129-
def _get_modified_expression(self, tex_string):
128+
def _get_modified_expression(self, tex_string: str) -> str:
130129
result = tex_string
131130
result = result.strip()
132131
result = self._modify_special_strings(result)
133132
return result
134133

135-
def _modify_special_strings(self, tex):
134+
def _modify_special_strings(self, tex: str) -> str:
136135
tex = tex.strip()
137136
should_add_filler = reduce(
138137
op.or_,
@@ -185,7 +184,7 @@ def _modify_special_strings(self, tex):
185184
tex = ""
186185
return tex
187186

188-
def _remove_stray_braces(self, tex):
187+
def _remove_stray_braces(self, tex: str) -> str:
189188
r"""
190189
Makes :class:`~.MathTex` resilient to unmatched braces.
191190
@@ -203,14 +202,14 @@ def _remove_stray_braces(self, tex):
203202
num_rights += 1
204203
return tex
205204

206-
def _organize_submobjects_left_to_right(self):
205+
def _organize_submobjects_left_to_right(self) -> Self:
207206
self.sort(lambda p: p[0])
208207
return self
209208

210-
def get_tex_string(self):
209+
def get_tex_string(self) -> str:
211210
return self.tex_string
212211

213-
def init_colors(self, propagate_colors=True):
212+
def init_colors(self, propagate_colors: bool = True) -> Self:
214213
for submobject in self.submobjects:
215214
# needed to preserve original (non-black)
216215
# TeX colors of individual submobjects
@@ -221,6 +220,7 @@ def init_colors(self, propagate_colors=True):
221220
submobject.init_colors()
222221
elif config.renderer == RendererType.CAIRO:
223222
submobject.init_colors(propagate_colors=propagate_colors)
223+
return self
224224

225225

226226
class MathTex(SingleStringMathTex):
@@ -256,21 +256,22 @@ def construct(self):
256256

257257
def __init__(
258258
self,
259-
*tex_strings,
259+
*tex_strings: str,
260260
arg_separator: str = " ",
261261
substrings_to_isolate: Iterable[str] | None = None,
262-
tex_to_color_map: dict[str, ManimColor] = None,
262+
tex_to_color_map: dict[str, ManimColor] | None = None,
263263
tex_environment: str = "align*",
264-
**kwargs,
264+
**kwargs: Any,
265265
):
266266
self.tex_template = kwargs.pop("tex_template", config["tex_template"])
267267
self.arg_separator = arg_separator
268268
self.substrings_to_isolate = (
269269
[] if substrings_to_isolate is None else substrings_to_isolate
270270
)
271-
self.tex_to_color_map = tex_to_color_map
272-
if self.tex_to_color_map is None:
273-
self.tex_to_color_map = {}
271+
if tex_to_color_map is None:
272+
self.tex_to_color_map: dict[str, ManimColor] = {}
273+
else:
274+
self.tex_to_color_map = tex_to_color_map
274275
self.tex_environment = tex_environment
275276
self.brace_notation_split_occurred = False
276277
self.tex_strings = self._break_up_tex_strings(tex_strings)
@@ -302,12 +303,14 @@ def __init__(
302303
if self.organize_left_to_right:
303304
self._organize_submobjects_left_to_right()
304305

305-
def _break_up_tex_strings(self, tex_strings):
306+
def _break_up_tex_strings(self, tex_strings: Sequence[str]) -> list[str]:
306307
# Separate out anything surrounded in double braces
307308
pre_split_length = len(tex_strings)
308-
tex_strings = [re.split("{{(.*?)}}", str(t)) for t in tex_strings]
309-
tex_strings = sum(tex_strings, [])
310-
if len(tex_strings) > pre_split_length:
309+
tex_strings_brace_splitted = [
310+
re.split("{{(.*?)}}", str(t)) for t in tex_strings
311+
]
312+
tex_strings_combined = sum(tex_strings_brace_splitted, [])
313+
if len(tex_strings_combined) > pre_split_length:
311314
self.brace_notation_split_occurred = True
312315

313316
# Separate out any strings specified in the isolate
@@ -325,19 +328,19 @@ def _break_up_tex_strings(self, tex_strings):
325328
pattern = "|".join(patterns)
326329
if pattern:
327330
pieces = []
328-
for s in tex_strings:
331+
for s in tex_strings_combined:
329332
pieces.extend(re.split(pattern, s))
330333
else:
331-
pieces = tex_strings
334+
pieces = tex_strings_combined
332335
return [p for p in pieces if p]
333336

334-
def _break_up_by_substrings(self):
337+
def _break_up_by_substrings(self) -> Self:
335338
"""
336339
Reorganize existing submobjects one layer
337340
deeper based on the structure of tex_strings (as a list
338341
of tex_strings)
339342
"""
340-
new_submobjects = []
343+
new_submobjects: list[VMobject] = []
341344
curr_index = 0
342345
for tex_string in self.tex_strings:
343346
sub_tex_mob = SingleStringMathTex(
@@ -359,8 +362,10 @@ def _break_up_by_substrings(self):
359362
self.submobjects = new_submobjects
360363
return self
361364

362-
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):
363-
def test(tex1, tex2):
365+
def get_parts_by_tex(
366+
self, tex: str, substring: bool = True, case_sensitive: bool = True
367+
) -> VGroup:
368+
def test(tex1: str, tex2: str) -> bool:
364369
if not case_sensitive:
365370
tex1 = tex1.lower()
366371
tex2 = tex2.lower()
@@ -371,19 +376,25 @@ def test(tex1, tex2):
371376

372377
return VGroup(*(m for m in self.submobjects if test(tex, m.get_tex_string())))
373378

374-
def get_part_by_tex(self, tex, **kwargs):
379+
def get_part_by_tex(self, tex: str, **kwargs: Any) -> MathTex | None:
375380
all_parts = self.get_parts_by_tex(tex, **kwargs)
376381
return all_parts[0] if all_parts else None
377382

378-
def set_color_by_tex(self, tex, color, **kwargs):
383+
def set_color_by_tex(
384+
self, tex: str, color: ParsableManimColor, **kwargs: Any
385+
) -> Self:
379386
parts_to_color = self.get_parts_by_tex(tex, **kwargs)
380387
for part in parts_to_color:
381388
part.set_color(color)
382389
return self
383390

384391
def set_opacity_by_tex(
385-
self, tex: str, opacity: float = 0.5, remaining_opacity: float = None, **kwargs
386-
):
392+
self,
393+
tex: str,
394+
opacity: float = 0.5,
395+
remaining_opacity: float | None = None,
396+
**kwargs: Any,
397+
) -> Self:
387398
"""
388399
Sets the opacity of the tex specified. If 'remaining_opacity' is specified,
389400
then the remaining tex will be set to that opacity.
@@ -404,7 +415,9 @@ def set_opacity_by_tex(
404415
part.set_opacity(opacity)
405416
return self
406417

407-
def set_color_by_tex_to_color_map(self, texs_to_color_map, **kwargs):
418+
def set_color_by_tex_to_color_map(
419+
self, texs_to_color_map: dict[str, ManimColor], **kwargs: Any
420+
) -> Self:
408421
for texs, color in list(texs_to_color_map.items()):
409422
try:
410423
# If the given key behaves like tex_strings
@@ -416,17 +429,19 @@ def set_color_by_tex_to_color_map(self, texs_to_color_map, **kwargs):
416429
self.set_color_by_tex(tex, color, **kwargs)
417430
return self
418431

419-
def index_of_part(self, part):
432+
def index_of_part(self, part: MathTex) -> int:
420433
split_self = self.split()
421434
if part not in split_self:
422435
raise ValueError("Trying to get index of part not in MathTex")
423436
return split_self.index(part)
424437

425-
def index_of_part_by_tex(self, tex, **kwargs):
438+
def index_of_part_by_tex(self, tex: str, **kwargs: Any) -> int:
426439
part = self.get_part_by_tex(tex, **kwargs)
440+
if part is None:
441+
return -1
427442
return self.index_of_part(part)
428443

429-
def sort_alphabetically(self):
444+
def sort_alphabetically(self) -> None:
430445
self.submobjects.sort(key=lambda m: m.get_tex_string())
431446

432447

@@ -482,11 +497,11 @@ def construct(self):
482497

483498
def __init__(
484499
self,
485-
*items,
486-
buff=MED_LARGE_BUFF,
487-
dot_scale_factor=2,
488-
tex_environment=None,
489-
**kwargs,
500+
*items: str,
501+
buff: float = MED_LARGE_BUFF,
502+
dot_scale_factor: float = 2,
503+
tex_environment: str = "",
504+
**kwargs: Any,
490505
):
491506
self.buff = buff
492507
self.dot_scale_factor = dot_scale_factor
@@ -501,12 +516,12 @@ def __init__(
501516
part.add_to_back(dot)
502517
self.arrange(DOWN, aligned_edge=LEFT, buff=self.buff)
503518

504-
def fade_all_but(self, index_or_string, opacity=0.5):
519+
def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> None:
505520
arg = index_or_string
506521
if isinstance(arg, str):
507522
part = self.get_part_by_tex(arg)
508523
elif isinstance(arg, int):
509-
part = self.submobjects[arg]
524+
part = self.submobjects[arg] # type: ignore[assignment]
510525
else:
511526
raise TypeError(f"Expected int or string, got {arg}")
512527
for other_part in self.submobjects:
@@ -536,11 +551,11 @@ def construct(self):
536551

537552
def __init__(
538553
self,
539-
*text_parts,
540-
include_underline=True,
541-
match_underline_width_to_text=False,
542-
underline_buff=MED_SMALL_BUFF,
543-
**kwargs,
554+
*text_parts: str,
555+
include_underline: bool = True,
556+
match_underline_width_to_text: bool = False,
557+
underline_buff: float = MED_SMALL_BUFF,
558+
**kwargs: Any,
544559
):
545560
self.include_underline = include_underline
546561
self.match_underline_width_to_text = match_underline_width_to_text

manim/scene/vector_space_scene.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def add_transformable_label(
962962
label_mob.target_text = new_label # type: ignore[attr-defined]
963963
else:
964964
label_mob.target_text = ( # type: ignore[attr-defined]
965-
f"{transformation_name}({label_mob.get_tex_string()})" # type: ignore[no-untyped-call]
965+
f"{transformation_name}({label_mob.get_tex_string()})"
966966
)
967967
label_mob.vector = vector # type: ignore[attr-defined]
968968
label_mob.kwargs = kwargs # type: ignore[attr-defined]

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ ignore_errors = True
150150
[mypy-manim.mobject.table]
151151
ignore_errors = True
152152

153-
[mypy-manim.mobject.text.tex_mobject]
154-
ignore_errors = True
155-
156153
[mypy-manim.mobject.text.text_mobject]
157154
ignore_errors = True
158155

0 commit comments

Comments
 (0)