Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 85 additions & 35 deletions manim/mobject/svg/svg_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..geometry.line import Line
from ..geometry.polygram import Polygon, Rectangle, RoundedRectangle
from ..opengl.opengl_compatibility import ConvertToOpenGL
from ..types.vectorized_mobject import VMobject
from ..types.vectorized_mobject import VGroup, VMobject

__all__ = ["SVGMobject", "VMobjectFromSVGPath"]

Expand Down Expand Up @@ -127,6 +127,7 @@ def __init__(
self.stroke_color = stroke_color
self.stroke_opacity = stroke_opacity # type: ignore[assignment]
self.stroke_width = stroke_width # type: ignore[assignment]
self.id_to_vgroup_dict: dict[str, VGroup] = {}
if self.stroke_width is None:
self.stroke_width = 0

Expand Down Expand Up @@ -203,8 +204,11 @@ def generate_mobject(self) -> None:
svg = se.SVG.parse(modified_file_path)
modified_file_path.unlink()

mobjects = self.get_mobjects_from(svg)
self.add(*mobjects)
mobjects_dict = self.get_mobjects_from(svg)
for key, value in mobjects_dict.items():
self.id_to_vgroup_dict[key] = value
self.add(value)

self.flip(RIGHT) # Flip y

def get_file_path(self) -> Path:
Expand Down Expand Up @@ -258,45 +262,91 @@ def generate_config_style_dict(self) -> dict[str, str]:
result[svg_key] = str(svg_default_dict[style_key])
return result

def get_mobjects_from(self, svg: se.SVG) -> list[VMobject]:
def get_mobjects_from(self, svg: se.SVG) -> dict[str, VGroup]:
"""Convert the elements of the SVG to a list of mobjects.

Parameters
----------
svg
The parsed SVG file.
"""
result: list[VMobject] = []
for shape in svg.elements():
# can we combine the two continue cases into one?
if isinstance(shape, se.Group): # noqa: SIM114
continue
elif isinstance(shape, se.Path):
mob: VMobject = self.path_to_mobject(shape)
elif isinstance(shape, se.SimpleLine):
mob = self.line_to_mobject(shape)
elif isinstance(shape, se.Rect):
mob = self.rect_to_mobject(shape)
elif isinstance(shape, (se.Circle, se.Ellipse)):
mob = self.ellipse_to_mobject(shape)
elif isinstance(shape, se.Polygon):
mob = self.polygon_to_mobject(shape)
elif isinstance(shape, se.Polyline):
mob = self.polyline_to_mobject(shape)
elif isinstance(shape, se.Text):
mob = self.text_to_mobject(shape)
elif isinstance(shape, se.Use) or type(shape) is se.SVGElement:
continue
else:
logger.warning(f"Unsupported element type: {type(shape)}")
continue
if mob is None or not mob.has_points():
continue
self.apply_style_to_mobject(mob, shape)
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
result.append(mob)
return result
stack: list[tuple[se.SVGElement, int]] = []
stack.append((svg, 1))
group_id_number = 0
vgroup_stack: list[str] = ["root"]
vgroup_names: list[str] = ["root"]
vgroups: dict[str, VGroup] = {"root": VGroup()}
while len(stack) > 0:
element, depth = stack.pop()
# Reduce stack heights
vgroup_stack = vgroup_stack[0:(depth)]
try:
group_name = str(element.values["id"])
except Exception:
group_name = f"numbered_group_{group_id_number}"
group_id_number += 1
if isinstance(element, se.Group):
vg = VGroup()
vgroups[group_name] = vg
vgroup_names.append(group_name)
vgroup_stack.append(group_name)
parent_name = vgroup_stack[depth - 1]
vgroups[parent_name].add(vgroups[group_name])

if isinstance(element, (se.Group, se.Use)):
for subelement in element[::-1]:
stack.append((subelement, depth + 1))
# Add element to the parent vgroup
try:
parent_name = vgroup_stack[depth - 2]
if isinstance(
element,
(
se.Path,
se.SimpleLine,
se.Rect,
se.Circle,
se.Ellipse,
se.Polygon,
se.Polyline,
se.Text,
),
):
mob = self.get_mob_from_shape_element(element)
vgroups[parent_name].add(mob)
except Exception as e:
print(e)

return vgroups

def get_mob_from_shape_element(self, shape: se.SVGElement) -> VMobject:
if isinstance(shape, se.Group): # noqa: SIM114
raise Exception("Should never get here")
elif isinstance(shape, se.Path):
mob: VMobject = self.path_to_mobject(shape)
elif isinstance(shape, se.SimpleLine):
mob = self.line_to_mobject(shape)
elif isinstance(shape, se.Rect):
mob = self.rect_to_mobject(shape)
elif isinstance(shape, (se.Circle, se.Ellipse)):
mob = self.ellipse_to_mobject(shape)
elif isinstance(shape, se.Polygon):
mob = self.polygon_to_mobject(shape)
elif isinstance(shape, se.Polyline):
mob = self.polyline_to_mobject(shape)
elif isinstance(shape, se.Text):
mob = self.text_to_mobject(shape)
elif isinstance(shape, se.Use) or type(shape) is se.SVGElement:
raise Exception("Should never get here - se.Use or se.SVGElement")
else:
logger.warning(f"Unsupported element type: {type(shape)}")
raise Exception(f"Unsupported element type: {type(shape)}")
if mob is None or not mob.has_points():
raise Exception("mob is empty or have no points")
self.apply_style_to_mobject(mob, shape)
if isinstance(shape, se.Transformable) and shape.apply:
self.handle_transform(mob, shape.transform)
return mob

@staticmethod
def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
Expand Down
Loading