diff --git a/src/spine/vis/train.py b/src/spine/vis/train.py index 72a33453..93ef9ad0 100644 --- a/src/spine/vis/train.py +++ b/src/spine/vis/train.py @@ -421,7 +421,12 @@ def draw( plt.ylabel(ylabel if len(metric) == 1 else "Metric") plt.gca().set_ylim(limits[metric[0]]) legend_title = model_name[model[0]] if len(model) == 1 else None - plt.legend(ncol=leg_ncols, title=legend_title) + plt.legend( + ncol=leg_ncols, + title=legend_title, + bbox_to_anchor=(1.05, 1), + loc="upper left", + ) if figure_name: plt.savefig(f"{figure_name}.png", bbox_inches="tight")