Skip to content

Commit 330873f

Browse files
committed
allow rgb spec for individual points in scatter
1 parent e4a4e6b commit 330873f

File tree

1 file changed

+62
-9
lines changed

1 file changed

+62
-9
lines changed

tikzplotlib/_path.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,13 @@ def draw_pathcollection(data, obj):
115115
content = []
116116
# gather data
117117
assert obj.get_offsets() is not None
118-
labels = ["x" + 21 * " ", "y" + 21 * " "]
118+
labels = ["x" + 3 * " ", "y" + 3 * " "]
119119
dd = obj.get_offsets()
120120

121-
draw_options = ["only marks"]
121+
fmt = "{:" + data["float format"] + "}"
122+
dd_strings = np.array([[fmt.format(val) for val in row] for row in dd])
123+
124+
draw_options = ["scatter", "only marks"]
122125
table_options = []
123126

124127
if obj.get_array() is not None:
@@ -133,21 +136,64 @@ def draw_pathcollection(data, obj):
133136
marker0 = None
134137
else:
135138
# gather the draw options
139+
add_individual_color_code = False
140+
136141
try:
137-
ec = obj.get_edgecolors()[0]
142+
ec = obj.get_edgecolors()
138143
except (TypeError, IndexError):
139144
ec = None
145+
else:
146+
if len(ec) == 1:
147+
ec = ec[0]
148+
else:
149+
assert len(ec) == len(dd)
150+
labels.append("draw" + 3 * " ")
151+
ec_strings = [
152+
",".join(fmt.format(item) for item in row)
153+
for row in ec[:, :3] * 255
154+
]
155+
dd_strings = np.column_stack([dd_strings, ec_strings])
156+
add_individual_color_code = True
140157

141158
try:
142-
fc = obj.get_facecolors()[0]
159+
fc = obj.get_facecolors()
143160
except (TypeError, IndexError):
144161
fc = None
162+
else:
163+
if len(fc) == 1:
164+
fc = fc[0]
165+
else:
166+
assert len(fc) == len(dd)
167+
labels.append("fill" + 3 * " ")
168+
fc_strings = [
169+
",".join(fmt.format(item) for item in row)
170+
for row in fc[:, :3] * 255
171+
]
172+
dd_strings = np.column_stack([dd_strings, fc_strings])
173+
add_individual_color_code = True
145174

146175
try:
147176
ls = obj.get_linestyle()[0]
148177
except (TypeError, IndexError):
149178
ls = None
150179

180+
if add_individual_color_code:
181+
draw_options.extend(
182+
[
183+
"scatter",
184+
"visualization depends on={value \\thisrow{draw} \\as \\drawcolor}",
185+
"visualization depends on={value \\thisrow{fill} \\as \\fillcolor}",
186+
"scatter/@pre marker code/.code={%\n"
187+
" \\expanded{%\n"
188+
" \\noexpand\\definecolor{thispointdrawcolor}{RGB}{\\drawcolor}%\n"
189+
" \\noexpand\\definecolor{thispointfillcolor}{RGB}{\\fillcolor}%\n"
190+
" }%\n"
191+
" \\scope[draw=thispointdrawcolor, fill=thispointfillcolor]%\n"
192+
"}",
193+
"scatter/@post marker code/.code={%\n" " \\endscope\n" "}",
194+
]
195+
)
196+
151197
# "solution" from
152198
# <https://github.com/matplotlib/matplotlib/issues/4672#issuecomment-378702670>
153199
p = obj.get_paths()[0]
@@ -179,7 +225,14 @@ def draw_pathcollection(data, obj):
179225
draw_options += ["mark options={{{}}}".format(",".join(marker_options))]
180226

181227
# `only mark` plots don't need linewidth
182-
data, extra_draw_options = get_draw_options(data, obj, ec, fc, ls, None)
228+
data, extra_draw_options = get_draw_options(
229+
data,
230+
obj,
231+
None if ec is None or len(ec) > 1 else ec,
232+
None if fc is None or len(fc) > 1 else fc,
233+
ls,
234+
None,
235+
)
183236
draw_options += extra_draw_options
184237

185238
if obj.get_cmap():
@@ -196,7 +249,7 @@ def draw_pathcollection(data, obj):
196249

197250
if len(obj.get_sizes()) == len(dd):
198251
# See Pgfplots manual, chapter 4.25.
199-
# In Pgfplots, \mark size specifies raddi, in matplotlib circle areas.
252+
# In Pgfplots, \mark size specifies radii, in matplotlib circle areas.
200253
radii = np.sqrt(obj.get_sizes() / np.pi)
201254
dd = np.column_stack([dd, radii])
202255
labels.append("sizedata" + 14 * " ")
@@ -217,9 +270,9 @@ def draw_pathcollection(data, obj):
217270

218271
content.append((" ".join(labels)).strip() + "\n")
219272
ff = data["float format"]
220-
fmt = (" ".join(dd.shape[1] * ["{:" + ff + "}"])) + "\n"
221-
for d in dd:
222-
content.append(fmt.format(*tuple(d)))
273+
274+
for row in dd_strings:
275+
content.append(" ".join(row) + "\n")
223276
content.append("};\n")
224277

225278
if legend_text is not None:

0 commit comments

Comments
 (0)