diff --git a/extras/helper_functions.py b/extras/helper_functions.py index a5d604f3..30a23158 100644 --- a/extras/helper_functions.py +++ b/extras/helper_functions.py @@ -81,13 +81,13 @@ def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_s xlabel="Predicted label", ylabel="True label", xticks=np.arange(n_classes), # create enough axis slots for each class - yticks=np.arange(n_classes), - xticklabels=labels, # axes will labeled with class names (if they exist) or ints + xticklabels=labels, yticklabels=labels) # Make x-axis labels appear on bottom ax.xaxis.set_label_position("bottom") ax.xaxis.tick_bottom() + ax.set_xticklabels(labels, rotation=90) # axes will labeled with class names (if they exist) or ints # Set the threshold for different colors threshold = (cm.max() + cm.min()) / 2.