Skip to content

Commit 584ade2

Browse files
committed
Extract a draw_collection function
1 parent 766c870 commit 584ade2

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

tikzplotlib/_save.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,24 @@ def flatten(self):
316316
return content_out
317317

318318

319+
def _draw_collection(data, child):
320+
if isinstance(
321+
child, (mpl.collections.PatchCollection, mpl.collections.PolyCollection)
322+
):
323+
return _patch.draw_patchcollection(data, child)
324+
elif isinstance(child, mpl.collections.PathCollection):
325+
return _path.draw_pathcollection(data, child)
326+
elif isinstance(child, mpl.collections.LineCollection):
327+
return _line2d.draw_linecollection(data, child)
328+
elif isinstance(child, mpl.collections.QuadMesh):
329+
return qmsh.draw_quadmesh(data, child)
330+
else:
331+
warnings.warn(
332+
"tikzplotlib: Don't know how to handle object {}.".format(type(child))
333+
)
334+
return data, []
335+
336+
319337
def _recurse(data, obj):
320338
"""Iterates over all children of the current object, gathers the contents
321339
contributing to the resulting PGFPlots file, and returns those.
@@ -365,19 +383,8 @@ def _recurse(data, obj):
365383
elif isinstance(child, mpl.patches.Patch):
366384
data, cont = _patch.draw_patch(data, child)
367385
content.extend(cont, child.get_zorder())
368-
elif isinstance(
369-
child, (mpl.collections.PatchCollection, mpl.collections.PolyCollection)
370-
):
371-
data, cont = _patch.draw_patchcollection(data, child)
372-
content.extend(cont, child.get_zorder())
373-
elif isinstance(child, mpl.collections.PathCollection):
374-
data, cont = _path.draw_pathcollection(data, child)
375-
content.extend(cont, child.get_zorder())
376-
elif isinstance(child, mpl.collections.LineCollection):
377-
data, cont = _line2d.draw_linecollection(data, child)
378-
content.extend(cont, child.get_zorder())
379-
elif isinstance(child, mpl.collections.QuadMesh):
380-
data, cont = qmsh.draw_quadmesh(data, child)
386+
elif isinstance(child, mpl.collections.Collection):
387+
data, cont = _draw_collection(data, child)
381388
content.extend(cont, child.get_zorder())
382389
elif isinstance(child, mpl.legend.Legend):
383390
data = _legend.draw_legend(data, child)

0 commit comments

Comments
 (0)