@@ -77,7 +77,7 @@ def __init__(
7777 fill_opacity = 1.0 ,
7878 ** kwargs ,
7979 ):
80- self .def_id_to_mobject = {}
80+ self .def_map = {}
8181 self .file_name = file_name or self .file_name
8282 self .ensure_valid_file ()
8383 self .should_center = should_center
@@ -125,17 +125,7 @@ def generate_points(self):
125125 """
126126 doc = minidom_parse (self .file_path )
127127 for svg in doc .getElementsByTagName ("svg" ):
128- mobjects = self .get_mobjects_from (
129- svg ,
130- # these are the default styling specifications for SVG images,
131- # according to https://www.w3.org/TR/SVG/painting.html, ctrl-F for "initial"
132- {
133- "fill" : "black" ,
134- "fill-opacity" : "1" ,
135- "stroke" : "none" ,
136- "stroke-opacity" : "1" ,
137- },
138- )
128+ mobjects = self .get_mobjects_from (svg , {})
139129 if self .unpack_groups :
140130 self .add (* mobjects )
141131 else :
@@ -182,7 +172,9 @@ def get_mobjects_from(
182172 elif element .tagName in ["g" , "svg" , "symbol" , "defs" ]:
183173 result += it .chain (
184174 * [
185- self .get_mobjects_from (child , style , within_defs or is_defs )
175+ self .get_mobjects_from (
176+ child , style , within_defs = within_defs or is_defs
177+ )
186178 for child in element .childNodes
187179 ]
188180 )
@@ -191,8 +183,8 @@ def get_mobjects_from(
191183 if temp != "" :
192184 result .append (self .path_string_to_mobject (temp , style ))
193185 elif element .tagName == "use" :
194- # note, style is not passed down to " use" elements
195- result += self .use_to_mobjects (element )
186+ # note, style is calcuated in a different way for ` use` elements.
187+ result += self .use_to_mobjects (element , style )
196188 elif element .tagName == "rect" :
197189 result .append (self .rect_to_mobject (element , style ))
198190 elif element .tagName == "circle" :
@@ -210,7 +202,9 @@ def get_mobjects_from(
210202 result = [VGroup (* result )]
211203
212204 if within_defs and element .hasAttribute ("id" ):
213- self .def_id_to_mobject [element .getAttribute ("id" )] = result
205+ # it seems wasteful to throw away the actual element,
206+ # but I'd like the parsing to be as similar as possible
207+ self .def_map [element .getAttribute ("id" )] = (style , element )
214208 if is_defs :
215209 # defs shouldn't be part of the result tree, only the id dictionary.
216210 return []
@@ -235,7 +229,9 @@ def path_string_to_mobject(self, path_string: str, style: dict):
235229 """
236230 return SVGPathMobject (path_string , ** parse_style (style ))
237231
238- def use_to_mobjects (self , use_element : MinidomElement ) -> List [VMobject ]:
232+ def use_to_mobjects (
233+ self , use_element : MinidomElement , local_style : Dict
234+ ) -> List [VMobject ]:
239235 """Converts a SVG <use> element to a collection of VMobjects.
240236
241237 Parameters
@@ -244,22 +240,34 @@ def use_to_mobjects(self, use_element: MinidomElement) -> List[VMobject]:
244240 An SVG <use> element which represents nodes that should be
245241 duplicated elsewhere.
246242
243+ local_style : :class:`Dict`
244+ The styling using SVG property names at the point the element is `<use>`d.
245+ Not all values are applied; styles defined when the element is specified in
246+ the `<def>` tag cannot be overriden here.
247+
247248 Returns
248249 -------
249250 List[VMobject]
250- A collection of VMobjects that are copies of the defined objects
251+ A collection of VMobjects that are a copy of the defined object
251252 """
252253
253254 # Remove initial "#" character
254255 ref = use_element .getAttribute ("xlink:href" )[1 :]
255256
256257 try :
257- return [ i . copy () for i in self .def_id_to_mobject [ref ] ]
258+ def_style , def_element = self .def_map [ref ]
258259 except KeyError :
259260 warning_text = f"{ self .file_name } contains a reference to id #{ ref } , which is not recognized"
260261 warnings .warn (warning_text )
261262 return []
262263
264+ # In short, the def-ed style overrides the new style,
265+ # in cases when the def-ed styled is defined.
266+ style = local_style .copy ()
267+ style .update (def_style )
268+
269+ return self .get_mobjects_from (def_element , style )
270+
263271 def attribute_to_float (self , attr ):
264272 """A helper method which converts the attribute to float.
265273
@@ -385,20 +393,21 @@ def rect_to_mobject(self, rect_element: MinidomElement, style: dict):
385393
386394 corner_radius = float (corner_radius )
387395
396+ parsed_style = parse_style (style )
397+ parsed_style ["stroke_width" ] = stroke_width
398+
388399 if corner_radius == 0 :
389400 mob = Rectangle (
390401 width = self .attribute_to_float (rect_element .getAttribute ("width" )),
391402 height = self .attribute_to_float (rect_element .getAttribute ("height" )),
392- stroke_width = stroke_width ,
393- ** parse_style (style ),
403+ ** parsed_style ,
394404 )
395405 else :
396406 mob = RoundedRectangle (
397407 width = self .attribute_to_float (rect_element .getAttribute ("width" )),
398408 height = self .attribute_to_float (rect_element .getAttribute ("height" )),
399- stroke_width = stroke_width ,
400409 corner_radius = corner_radius ,
401- ** parse_style ( style ) ,
410+ ** parsed_style ,
402411 )
403412
404413 mob .shift (mob .get_center () - mob .get_corner (UP + LEFT ))
0 commit comments