Skip to content

Commit fbd6f57

Browse files
authored
Merge pull request #373 from eric-wieser/fix-collections
Add support for transforms and offsets in custom collections
2 parents 1b0bb9f + 6b8484d commit fbd6f57

File tree

4 files changed

+254
-19
lines changed

4 files changed

+254
-19
lines changed

test/test_custom_collection.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Custom collection test
2+
3+
This tests plots a subclass of Collection, which contains enough information
4+
as a base class to be rendered.
5+
"""
6+
import matplotlib
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
from helpers import assert_equality
11+
12+
13+
class TransformedEllipseCollection(matplotlib.collections.Collection):
14+
"""
15+
A gutted version of matplotlib.collections.EllipseCollection that lets us
16+
pass the transformation matrix directly.
17+
18+
This is useful for plotting cholesky factors of covariance matrices.
19+
"""
20+
21+
def __init__(self, matrices, **kwargs):
22+
super().__init__(**kwargs)
23+
self.set_transform(matplotlib.transforms.IdentityTransform())
24+
self._transforms = np.zeros(matrices.shape[:-2] + (3, 3))
25+
self._transforms[..., :2, :2] = matrices
26+
self._transforms[..., 2, 2] = 1
27+
self._paths = [matplotlib.path.Path.unit_circle()]
28+
29+
def _set_transforms(self):
30+
"""Calculate transforms immediately before drawing."""
31+
m = self.axes.transData.get_affine().get_matrix().copy()
32+
m[:2, 2:] = 0
33+
self.set_transform(matplotlib.transforms.Affine2D(m))
34+
35+
@matplotlib.artist.allow_rasterization
36+
def draw(self, renderer):
37+
self._set_transforms()
38+
super().draw(renderer)
39+
40+
41+
def rot(theta):
42+
""" Get a stack of rotation matrices """
43+
return np.stack(
44+
[
45+
np.stack([np.cos(theta), -np.sin(theta)], axis=-1),
46+
np.stack([np.sin(theta), np.cos(theta)], axis=-1),
47+
],
48+
axis=-2,
49+
)
50+
51+
52+
def plot():
53+
# plot data
54+
fig = plt.figure()
55+
ax = fig.add_subplot(111)
56+
57+
theta = np.linspace(0, 2 * np.pi, 12, endpoint=False)
58+
mats = rot(theta) @ np.diag([0.1, 0.2])
59+
x = np.cos(theta)
60+
y = np.sin(theta)
61+
62+
c = TransformedEllipseCollection(
63+
mats,
64+
offsets=np.stack((x, y), axis=-1),
65+
edgecolor="tab:red",
66+
alpha=0.5,
67+
facecolor="tab:blue",
68+
transOffset=ax.transData,
69+
)
70+
ax.add_collection(c)
71+
ax.set(xlim=[-1.5, 1.5], ylim=[-1.5, 1.5])
72+
73+
return fig
74+
75+
76+
def test():
77+
assert_equality(plot, "test_custom_collection_reference.tex")
78+
return
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
\begin{tikzpicture}
2+
3+
\definecolor{color0}{rgb}{0.83921569,0.15294118,0.15686275}
4+
\definecolor{color1}{rgb}{0.12156863,0.46666667,0.70588235}
5+
6+
\begin{axis}[
7+
tick align=outside,
8+
tick pos=left,
9+
x grid style={white!69.019608!black},
10+
xmin=-1.5, xmax=1.5,
11+
xtick style={color=black},
12+
y grid style={white!69.019608!black},
13+
ymin=-1.5, ymax=1.5,
14+
ytick style={color=black}
15+
]
16+
\path [draw=color0, fill=color1, opacity=0.5]
17+
(axis cs:1,-0.2)
18+
.. controls (axis cs:1.0265203,-0.2) and (axis cs:1.051958,-0.17892674) .. (axis cs:1.0707107,-0.14142136)
19+
.. controls (axis cs:1.0894634,-0.10391597) and (axis cs:1.1,-0.05304062) .. (axis cs:1.1,0)
20+
.. controls (axis cs:1.1,0.05304062) and (axis cs:1.0894634,0.10391597) .. (axis cs:1.0707107,0.14142136)
21+
.. controls (axis cs:1.051958,0.17892674) and (axis cs:1.0265203,0.2) .. (axis cs:1,0.2)
22+
.. controls (axis cs:0.97347969,0.2) and (axis cs:0.94804201,0.17892674) .. (axis cs:0.92928932,0.14142136)
23+
.. controls (axis cs:0.91053663,0.10391597) and (axis cs:0.9,0.05304062) .. (axis cs:0.9,0)
24+
.. controls (axis cs:0.9,-0.05304062) and (axis cs:0.91053663,-0.10391597) .. (axis cs:0.92928932,-0.14142136)
25+
.. controls (axis cs:0.94804201,-0.17892674) and (axis cs:0.97347969,-0.2) .. (axis cs:1,-0.2)
26+
--cycle;
27+
\path [draw=color0, fill=color1, opacity=0.5]
28+
(axis cs:0.9660254,0.32679492)
29+
.. controls (axis cs:0.98899267,0.34005507) and (axis cs:1.0004857,0.37102389) .. (axis cs:0.99797333,0.41288085)
30+
.. controls (axis cs:0.99546094,0.45473781) and (axis cs:0.97914825,0.50406548) .. (axis cs:0.95262794,0.55)
31+
.. controls (axis cs:0.92610763,0.59593452) and (axis cs:0.89154497,0.63472556) .. (axis cs:0.85655197,0.65782983)
32+
.. controls (axis cs:0.82155897,0.68093409) and (axis cs:0.78899267,0.68646524) .. (axis cs:0.7660254,0.67320508)
33+
.. controls (axis cs:0.74305814,0.65994493) and (axis cs:0.7315651,0.62897611) .. (axis cs:0.73407748,0.58711915)
34+
.. controls (axis cs:0.73658987,0.54526219) and (axis cs:0.75290255,0.49593452) .. (axis cs:0.77942286,0.45)
35+
.. controls (axis cs:0.80594317,0.40406548) and (axis cs:0.84050584,0.36527444) .. (axis cs:0.87549884,0.34217017)
36+
.. controls (axis cs:0.91049184,0.31906591) and (axis cs:0.94305814,0.31353476) .. (axis cs:0.9660254,0.32679492)
37+
--cycle;
38+
\path [draw=color0, fill=color1, opacity=0.5]
39+
(axis cs:0.67320508,0.7660254)
40+
.. controls (axis cs:0.68646524,0.78899267) and (axis cs:0.68093409,0.82155897) .. (axis cs:0.65782983,0.85655197)
41+
.. controls (axis cs:0.63472556,0.89154497) and (axis cs:0.59593452,0.92610763) .. (axis cs:0.55,0.95262794)
42+
.. controls (axis cs:0.50406548,0.97914825) and (axis cs:0.45473781,0.99546094) .. (axis cs:0.41288085,0.99797333)
43+
.. controls (axis cs:0.37102389,1.0004857) and (axis cs:0.34005507,0.98899267) .. (axis cs:0.32679492,0.9660254)
44+
.. controls (axis cs:0.31353476,0.94305814) and (axis cs:0.31906591,0.91049184) .. (axis cs:0.34217017,0.87549884)
45+
.. controls (axis cs:0.36527444,0.84050584) and (axis cs:0.40406548,0.80594317) .. (axis cs:0.45,0.77942286)
46+
.. controls (axis cs:0.49593452,0.75290255) and (axis cs:0.54526219,0.73658987) .. (axis cs:0.58711915,0.73407748)
47+
.. controls (axis cs:0.62897611,0.7315651) and (axis cs:0.65994493,0.74305814) .. (axis cs:0.67320508,0.7660254)
48+
--cycle;
49+
\path [draw=color0, fill=color1, opacity=0.5]
50+
(axis cs:0.2,1)
51+
.. controls (axis cs:0.2,1.0265203) and (axis cs:0.17892674,1.051958) .. (axis cs:0.14142136,1.0707107)
52+
.. controls (axis cs:0.10391597,1.0894634) and (axis cs:0.05304062,1.1) .. (axis cs:6.7355574e-17,1.1)
53+
.. controls (axis cs:-0.05304062,1.1) and (axis cs:-0.10391597,1.0894634) .. (axis cs:-0.14142136,1.0707107)
54+
.. controls (axis cs:-0.17892674,1.051958) and (axis cs:-0.2,1.0265203) .. (axis cs:-0.2,1)
55+
.. controls (axis cs:-0.2,0.97347969) and (axis cs:-0.17892674,0.94804201) .. (axis cs:-0.14142136,0.92928932)
56+
.. controls (axis cs:-0.10391597,0.91053663) and (axis cs:-0.05304062,0.9) .. (axis cs:5.5109106e-17,0.9)
57+
.. controls (axis cs:0.05304062,0.9) and (axis cs:0.10391597,0.91053663) .. (axis cs:0.14142136,0.92928932)
58+
.. controls (axis cs:0.17892674,0.94804201) and (axis cs:0.2,0.97347969) .. (axis cs:0.2,1)
59+
--cycle;
60+
\path [draw=color0, fill=color1, opacity=0.5]
61+
(axis cs:-0.32679492,0.9660254)
62+
.. controls (axis cs:-0.34005507,0.98899267) and (axis cs:-0.37102389,1.0004857) .. (axis cs:-0.41288085,0.99797333)
63+
.. controls (axis cs:-0.45473781,0.99546094) and (axis cs:-0.50406548,0.97914825) .. (axis cs:-0.55,0.95262794)
64+
.. controls (axis cs:-0.59593452,0.92610763) and (axis cs:-0.63472556,0.89154497) .. (axis cs:-0.65782983,0.85655197)
65+
.. controls (axis cs:-0.68093409,0.82155897) and (axis cs:-0.68646524,0.78899267) .. (axis cs:-0.67320508,0.7660254)
66+
.. controls (axis cs:-0.65994493,0.74305814) and (axis cs:-0.62897611,0.7315651) .. (axis cs:-0.58711915,0.73407748)
67+
.. controls (axis cs:-0.54526219,0.73658987) and (axis cs:-0.49593452,0.75290255) .. (axis cs:-0.45,0.77942286)
68+
.. controls (axis cs:-0.40406548,0.80594317) and (axis cs:-0.36527444,0.84050584) .. (axis cs:-0.34217017,0.87549884)
69+
.. controls (axis cs:-0.31906591,0.91049184) and (axis cs:-0.31353476,0.94305814) .. (axis cs:-0.32679492,0.9660254)
70+
--cycle;
71+
\path [draw=color0, fill=color1, opacity=0.5]
72+
(axis cs:-0.7660254,0.67320508)
73+
.. controls (axis cs:-0.78899267,0.68646524) and (axis cs:-0.82155897,0.68093409) .. (axis cs:-0.85655197,0.65782983)
74+
.. controls (axis cs:-0.89154497,0.63472556) and (axis cs:-0.92610763,0.59593452) .. (axis cs:-0.95262794,0.55)
75+
.. controls (axis cs:-0.97914825,0.50406548) and (axis cs:-0.99546094,0.45473781) .. (axis cs:-0.99797333,0.41288085)
76+
.. controls (axis cs:-1.0004857,0.37102389) and (axis cs:-0.98899267,0.34005507) .. (axis cs:-0.9660254,0.32679492)
77+
.. controls (axis cs:-0.94305814,0.31353476) and (axis cs:-0.91049184,0.31906591) .. (axis cs:-0.87549884,0.34217017)
78+
.. controls (axis cs:-0.84050584,0.36527444) and (axis cs:-0.80594317,0.40406548) .. (axis cs:-0.77942286,0.45)
79+
.. controls (axis cs:-0.75290255,0.49593452) and (axis cs:-0.73658987,0.54526219) .. (axis cs:-0.73407748,0.58711915)
80+
.. controls (axis cs:-0.7315651,0.62897611) and (axis cs:-0.74305814,0.65994493) .. (axis cs:-0.7660254,0.67320508)
81+
--cycle;
82+
\path [draw=color0, fill=color1, opacity=0.5]
83+
(axis cs:-1,0.2)
84+
.. controls (axis cs:-1.0265203,0.2) and (axis cs:-1.051958,0.17892674) .. (axis cs:-1.0707107,0.14142136)
85+
.. controls (axis cs:-1.0894634,0.10391597) and (axis cs:-1.1,0.05304062) .. (axis cs:-1.1,1.3471115e-16)
86+
.. controls (axis cs:-1.1,-0.05304062) and (axis cs:-1.0894634,-0.10391597) .. (axis cs:-1.0707107,-0.14142136)
87+
.. controls (axis cs:-1.051958,-0.17892674) and (axis cs:-1.0265203,-0.2) .. (axis cs:-1,-0.2)
88+
.. controls (axis cs:-0.97347969,-0.2) and (axis cs:-0.94804201,-0.17892674) .. (axis cs:-0.92928932,-0.14142136)
89+
.. controls (axis cs:-0.91053663,-0.10391597) and (axis cs:-0.9,-0.05304062) .. (axis cs:-0.9,1.1021821e-16)
90+
.. controls (axis cs:-0.9,0.05304062) and (axis cs:-0.91053663,0.10391597) .. (axis cs:-0.92928932,0.14142136)
91+
.. controls (axis cs:-0.94804201,0.17892674) and (axis cs:-0.97347969,0.2) .. (axis cs:-1,0.2)
92+
--cycle;
93+
\path [draw=color0, fill=color1, opacity=0.5]
94+
(axis cs:-0.9660254,-0.32679492)
95+
.. controls (axis cs:-0.98899267,-0.34005507) and (axis cs:-1.0004857,-0.37102389) .. (axis cs:-0.99797333,-0.41288085)
96+
.. controls (axis cs:-0.99546094,-0.45473781) and (axis cs:-0.97914825,-0.50406548) .. (axis cs:-0.95262794,-0.55)
97+
.. controls (axis cs:-0.92610763,-0.59593452) and (axis cs:-0.89154497,-0.63472556) .. (axis cs:-0.85655197,-0.65782983)
98+
.. controls (axis cs:-0.82155897,-0.68093409) and (axis cs:-0.78899267,-0.68646524) .. (axis cs:-0.7660254,-0.67320508)
99+
.. controls (axis cs:-0.74305814,-0.65994493) and (axis cs:-0.7315651,-0.62897611) .. (axis cs:-0.73407748,-0.58711915)
100+
.. controls (axis cs:-0.73658987,-0.54526219) and (axis cs:-0.75290255,-0.49593452) .. (axis cs:-0.77942286,-0.45)
101+
.. controls (axis cs:-0.80594317,-0.40406548) and (axis cs:-0.84050584,-0.36527444) .. (axis cs:-0.87549884,-0.34217017)
102+
.. controls (axis cs:-0.91049184,-0.31906591) and (axis cs:-0.94305814,-0.31353476) .. (axis cs:-0.9660254,-0.32679492)
103+
--cycle;
104+
\path [draw=color0, fill=color1, opacity=0.5]
105+
(axis cs:-0.67320508,-0.7660254)
106+
.. controls (axis cs:-0.68646524,-0.78899267) and (axis cs:-0.68093409,-0.82155897) .. (axis cs:-0.65782983,-0.85655197)
107+
.. controls (axis cs:-0.63472556,-0.89154497) and (axis cs:-0.59593452,-0.92610763) .. (axis cs:-0.55,-0.95262794)
108+
.. controls (axis cs:-0.50406548,-0.97914825) and (axis cs:-0.45473781,-0.99546094) .. (axis cs:-0.41288085,-0.99797333)
109+
.. controls (axis cs:-0.37102389,-1.0004857) and (axis cs:-0.34005507,-0.98899267) .. (axis cs:-0.32679492,-0.9660254)
110+
.. controls (axis cs:-0.31353476,-0.94305814) and (axis cs:-0.31906591,-0.91049184) .. (axis cs:-0.34217017,-0.87549884)
111+
.. controls (axis cs:-0.36527444,-0.84050584) and (axis cs:-0.40406548,-0.80594317) .. (axis cs:-0.45,-0.77942286)
112+
.. controls (axis cs:-0.49593452,-0.75290255) and (axis cs:-0.54526219,-0.73658987) .. (axis cs:-0.58711915,-0.73407748)
113+
.. controls (axis cs:-0.62897611,-0.7315651) and (axis cs:-0.65994493,-0.74305814) .. (axis cs:-0.67320508,-0.7660254)
114+
--cycle;
115+
\path [draw=color0, fill=color1, opacity=0.5]
116+
(axis cs:-0.2,-1)
117+
.. controls (axis cs:-0.2,-1.0265203) and (axis cs:-0.17892674,-1.051958) .. (axis cs:-0.14142136,-1.0707107)
118+
.. controls (axis cs:-0.10391597,-1.0894634) and (axis cs:-0.05304062,-1.1) .. (axis cs:-2.0206672e-16,-1.1)
119+
.. controls (axis cs:0.05304062,-1.1) and (axis cs:0.10391597,-1.0894634) .. (axis cs:0.14142136,-1.0707107)
120+
.. controls (axis cs:0.17892674,-1.051958) and (axis cs:0.2,-1.0265203) .. (axis cs:0.2,-1)
121+
.. controls (axis cs:0.2,-0.97347969) and (axis cs:0.17892674,-0.94804201) .. (axis cs:0.14142136,-0.92928932)
122+
.. controls (axis cs:0.10391597,-0.91053663) and (axis cs:0.05304062,-0.9) .. (axis cs:-1.6532732e-16,-0.9)
123+
.. controls (axis cs:-0.05304062,-0.9) and (axis cs:-0.10391597,-0.91053663) .. (axis cs:-0.14142136,-0.92928932)
124+
.. controls (axis cs:-0.17892674,-0.94804201) and (axis cs:-0.2,-0.97347969) .. (axis cs:-0.2,-1)
125+
--cycle;
126+
\path [draw=color0, fill=color1, opacity=0.5]
127+
(axis cs:0.32679492,-0.9660254)
128+
.. controls (axis cs:0.34005507,-0.98899267) and (axis cs:0.37102389,-1.0004857) .. (axis cs:0.41288085,-0.99797333)
129+
.. controls (axis cs:0.45473781,-0.99546094) and (axis cs:0.50406548,-0.97914825) .. (axis cs:0.55,-0.95262794)
130+
.. controls (axis cs:0.59593452,-0.92610763) and (axis cs:0.63472556,-0.89154497) .. (axis cs:0.65782983,-0.85655197)
131+
.. controls (axis cs:0.68093409,-0.82155897) and (axis cs:0.68646524,-0.78899267) .. (axis cs:0.67320508,-0.7660254)
132+
.. controls (axis cs:0.65994493,-0.74305814) and (axis cs:0.62897611,-0.7315651) .. (axis cs:0.58711915,-0.73407748)
133+
.. controls (axis cs:0.54526219,-0.73658987) and (axis cs:0.49593452,-0.75290255) .. (axis cs:0.45,-0.77942286)
134+
.. controls (axis cs:0.40406548,-0.80594317) and (axis cs:0.36527444,-0.84050584) .. (axis cs:0.34217017,-0.87549884)
135+
.. controls (axis cs:0.31906591,-0.91049184) and (axis cs:0.31353476,-0.94305814) .. (axis cs:0.32679492,-0.9660254)
136+
--cycle;
137+
\path [draw=color0, fill=color1, opacity=0.5]
138+
(axis cs:0.7660254,-0.67320508)
139+
.. controls (axis cs:0.78899267,-0.68646524) and (axis cs:0.82155897,-0.68093409) .. (axis cs:0.85655197,-0.65782983)
140+
.. controls (axis cs:0.89154497,-0.63472556) and (axis cs:0.92610763,-0.59593452) .. (axis cs:0.95262794,-0.55)
141+
.. controls (axis cs:0.97914825,-0.50406548) and (axis cs:0.99546094,-0.45473781) .. (axis cs:0.99797333,-0.41288085)
142+
.. controls (axis cs:1.0004857,-0.37102389) and (axis cs:0.98899267,-0.34005507) .. (axis cs:0.9660254,-0.32679492)
143+
.. controls (axis cs:0.94305814,-0.31353476) and (axis cs:0.91049184,-0.31906591) .. (axis cs:0.87549884,-0.34217017)
144+
.. controls (axis cs:0.84050584,-0.36527444) and (axis cs:0.80594317,-0.40406548) .. (axis cs:0.77942286,-0.45)
145+
.. controls (axis cs:0.75290255,-0.49593452) and (axis cs:0.73658987,-0.54526219) .. (axis cs:0.73407748,-0.58711915)
146+
.. controls (axis cs:0.7315651,-0.62897611) and (axis cs:0.74305814,-0.65994493) .. (axis cs:0.7660254,-0.67320508)
147+
--cycle;
148+
149+
\end{axis}
150+
151+
\end{tikzpicture}

tikzplotlib/_patch.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def _patch_legend(obj, draw_options, legend_type):
6363
return legend
6464

6565

66+
def zip_modulo(*seqs):
67+
n = max(len(seq) for seq in seqs)
68+
for i in range(n):
69+
yield tuple(seq[i % len(seq)] for seq in seqs)
70+
71+
6672
def draw_patchcollection(data, obj):
6773
"""Returns PGFPlots code for a number of patch objects.
6874
"""
@@ -78,15 +84,15 @@ def ensure_list(x):
7884
fcs = ensure_list(obj.get_facecolor())
7985
lss = ensure_list(obj.get_linestyle())
8086
ws = ensure_list(obj.get_linewidth())
87+
ts = ensure_list(obj.get_transforms())
88+
offs = obj.get_offsets()
8189

8290
paths = obj.get_paths()
83-
for i, path in enumerate(paths):
84-
# Gather the draw options.
85-
ec = ecs[i % len(ecs)]
86-
fc = fcs[i % len(fcs)]
87-
ls = lss[i % len(lss)]
88-
w = ws[i % len(ws)]
91+
for path, ec, fc, ls, w, t, off in zip_modulo(paths, ecs, fcs, lss, ws, ts, offs):
92+
if t is None:
93+
t = mpl.transforms.IdentityTransform()
8994

95+
path = path.transformed(mpl.transforms.Affine2D(t).translate(*off))
9096
data, draw_options = mypath.get_draw_options(data, obj, ec, fc, ls, w)
9197
data, cont, draw_options, is_area = mypath.draw_path(
9298
data, path, draw_options=draw_options

tikzplotlib/_save.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,17 @@ def flatten(self):
316316
return content_out
317317

318318

319+
def _draw_collection(data, child):
320+
if isinstance(child, mpl.collections.PathCollection):
321+
return _path.draw_pathcollection(data, child)
322+
elif isinstance(child, mpl.collections.LineCollection):
323+
return _line2d.draw_linecollection(data, child)
324+
elif isinstance(child, mpl.collections.QuadMesh):
325+
return qmsh.draw_quadmesh(data, child)
326+
else:
327+
return _patch.draw_patchcollection(data, child)
328+
329+
319330
def _recurse(data, obj):
320331
"""Iterates over all children of the current object, gathers the contents
321332
contributing to the resulting PGFPlots file, and returns those.
@@ -365,19 +376,8 @@ def _recurse(data, obj):
365376
elif isinstance(child, mpl.patches.Patch):
366377
data, cont = _patch.draw_patch(data, child)
367378
content.extend(cont, child.get_zorder())
368-
elif isinstance(
369-
child, (mpl.collections.PatchCollection, mpl.collections.PolyCollection)
370-
):
371-
data, cont = _patch.draw_patchcollection(data, child)
372-
content.extend(cont, child.get_zorder())
373-
elif isinstance(child, mpl.collections.PathCollection):
374-
data, cont = _path.draw_pathcollection(data, child)
375-
content.extend(cont, child.get_zorder())
376-
elif isinstance(child, mpl.collections.LineCollection):
377-
data, cont = _line2d.draw_linecollection(data, child)
378-
content.extend(cont, child.get_zorder())
379-
elif isinstance(child, mpl.collections.QuadMesh):
380-
data, cont = qmsh.draw_quadmesh(data, child)
379+
elif isinstance(child, mpl.collections.Collection):
380+
data, cont = _draw_collection(data, child)
381381
content.extend(cont, child.get_zorder())
382382
elif isinstance(child, mpl.legend.Legend):
383383
data = _legend.draw_legend(data, child)

0 commit comments

Comments
 (0)