Skip to content

Commit c52d151

Browse files
committed
Standardize figure settings
1 parent c1a4fec commit c52d151

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

gui/plots.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
2020
if len(df) == 0 or "Equation" not in df.columns:
2121
return fig
2222

23-
# Plotting the data
2423
ax.loglog(
2524
df["Complexity"],
2625
df["Loss"],
@@ -31,23 +30,12 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
3130
markersize=6,
3231
)
3332

34-
# Set the axis limits
3533
ax.set_xlim(0.5, maxsize + 1)
3634
ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
3735
ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
3836
ax.set_ylim(ybottom, ytop)
3937

40-
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
41-
ax.spines["top"].set_visible(False)
42-
ax.spines["right"].set_visible(False)
43-
44-
# Range-frame the plot
45-
for direction in ["bottom", "left"]:
46-
ax.spines[direction].set_position(("outward", 10))
47-
48-
# Delete far ticks
49-
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
50-
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
38+
stylize_axis(ax)
5139

5240
ax.set_xlabel("Complexity")
5341
ax.set_ylabel("Loss")
@@ -57,14 +45,23 @@ def plot_pareto_curve(df: pd.DataFrame, maxsize: int):
5745

5846

5947
def plot_example_data(test_equation, num_points, noise_level, data_seed):
48+
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
49+
6050
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
6151
x = X["x"]
6252

63-
plt.rcParams["font.family"] = "IBM Plex Mono"
64-
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
65-
6653
ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
6754

55+
stylize_axis(ax)
56+
57+
ax.set_xlabel("x")
58+
ax.set_ylabel("y")
59+
fig.tight_layout(pad=2)
60+
61+
return fig
62+
63+
64+
def stylize_axis(ax):
6865
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
6966
ax.spines["top"].set_visible(False)
7067
ax.spines["right"].set_visible(False)
@@ -76,9 +73,3 @@ def plot_example_data(test_equation, num_points, noise_level, data_seed):
7673
# Delete far ticks
7774
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
7875
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
79-
80-
ax.set_xlabel("x")
81-
ax.set_ylabel("y")
82-
fig.tight_layout(pad=2)
83-
84-
return fig

0 commit comments

Comments
 (0)