Skip to content

Commit 233c206

Browse files
committed
add configuration options for marker [no ci]
1 parent a95ec8d commit 233c206

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

bayesflow/diagnostics/plots/pairs_quantity.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def pairs_quantity(
2222
height: float = 2.5,
2323
cmap: str | matplotlib.colors.Colormap = "viridis",
2424
alpha: float = 0.9,
25+
s: float = 8.0,
26+
marker: str = "o",
2527
label: str = None,
2628
label_fontsize: int = 14,
2729
tick_fontsize: int = 12,
@@ -94,6 +96,10 @@ def pairs_quantity(
9496
The colormap for the plot.
9597
alpha : float in [0, 1], optional, default: 0.9
9698
The opacity of the plot
99+
s : float, optional, default: 8.0
100+
The marker size in points**2 for the scatter plot.
101+
marker : str, optional, default: 'o'
102+
The marker for the scatter plot.
97103
label : str, optional, default: None
98104
Label for the dataset to plot.
99105
label_fontsize : int, optional, default: 14
@@ -177,7 +183,17 @@ def pairs_quantity(
177183

178184
if i == j:
179185
ax = g.axes[i, j].twinx()
180-
ax.scatter(targets[:, i], values[:, i], c=row_values, cmap=cmap, s=4, vmin=vmin, vmax=vmax, alpha=alpha)
186+
ax.scatter(
187+
targets[:, i],
188+
values[:, i],
189+
c=row_values,
190+
cmap=cmap,
191+
s=s,
192+
marker=marker,
193+
vmin=vmin,
194+
vmax=vmax,
195+
alpha=alpha,
196+
)
181197
ax.spines["left"].set_visible(False)
182198
ax.spines["top"].set_visible(False)
183199
ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
@@ -197,10 +213,11 @@ def pairs_quantity(
197213
targets[:, i],
198214
c=row_values,
199215
cmap=cmap,
200-
s=4,
216+
s=s,
201217
vmin=vmin,
202218
vmax=vmax,
203219
alpha=alpha,
220+
marker=marker,
204221
)
205222

206223
def inches_to_figure(fig, values):

bayesflow/diagnostics/plots/plot_quantity.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def plot_quantity(
2626
title_fontsize: int = 18,
2727
tick_fontsize: int = 12,
2828
color: str = "#132a70",
29+
s: float = 25.0,
30+
marker: str = "o",
31+
alpha: float = 0.5,
2932
xlabel: str = "Ground truth",
3033
ylabel: str = "",
3134
num_col: int = None,
@@ -90,6 +93,12 @@ def plot_quantity(
9093
The font size of the axis ticklabels
9194
color : str, optional, default: '#8f2727'
9295
The color for the true vs. estimated scatter points and error bars
96+
s : float, optional, default: 25.0
97+
The marker size in points**2 for the scatter plot.
98+
marker : str, optional, default: 'o'
99+
The marker for the scatter plot.
100+
alpha : float, default: 0.5
101+
The opacity for the scatter plot
93102
num_row : int, optional, default: None
94103
The number of rows for the subplots. Dynamically determined if None.
95104
num_col : int, optional, default: None
@@ -143,7 +152,7 @@ def plot_quantity(
143152
if i >= num_variables:
144153
break
145154

146-
ax.scatter(targets[:, i], values[:, i], color=color, alpha=0.5)
155+
ax.scatter(targets[:, i], values[:, i], color=color, alpha=alpha, s=s, marker=marker)
147156

148157
prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize)
149158

0 commit comments

Comments
 (0)