@@ -481,7 +481,7 @@ def log_precision_recall(
                 Trial Component as an output artifact. If False will be an input artifact.

         Raises:
-            ValueError: If mismatch between y_true and predicted_probabilities.
+            ValueError: If length mismatch between y_true and predicted_probabilities.
         """

         if len(y_true) != len(predicted_probabilities):
@@ -542,7 +542,7 @@ def log_roc_curve(
         """

         if len(y_true) != len(y_score):
-            raise ValueError("Mismatch between actual labels and predicted scores.")
+            raise ValueError("Length mismatch between actual labels and predicted scores.")

         get_module("sklearn")
         from sklearn.metrics import roc_curve, auc
@@ -561,6 +561,50 @@ def log_roc_curve(
         }
         self._log_graph_artifact(title, data, "ROCCurve", output_artifact)

+    def log_confusion_matrix(
+        self,
+        y_true,
+        y_pred,
+        title=None,
+        output_artifact=True,
+    ):
+        """Log a confusion matrix artifact, which will be displayed in
+        Studio. Requires sklearn.
+
+        Note that this method must be run from a SageMaker context, such as Studio or a
+        training job, due to restrictions on the CreateArtifact API.
+
+        Examples
+            .. code-block:: python
+
+                y_true = [2, 0, 2, 2, 0, 1]
+                y_pred = [0, 0, 2, 2, 0, 2]
+
+                my_tracker.log_confusion_matrix(y_true, y_pred)
+
+
+        Args:
+            y_true (array): True labels.
+            y_pred (array): Predicted labels.
+            title (str, optional): Title of the graph. Defaults to None.
+            output_artifact (boolean, optional): Determines if the artifact is associated with the
+                Trial Component as an output artifact. If False will be an input artifact.
+
+        Raises:
+            ValueError: If length mismatch between y_true and y_pred.
+        """
+
+        if len(y_true) != len(y_pred):
+            raise ValueError("Length mismatch between actual labels and predicted labels.")
+
+        get_module("sklearn")
+        from sklearn.metrics import confusion_matrix
+
+        matrix = confusion_matrix(y_true, y_pred)
+
+        data = {"type": "ConfusionMatrix", "version": 0, "title": title, "confusionMatrix": matrix.tolist()}
+        self._log_graph_artifact(title, data, "ConfusionMatrix", output_artifact)
+
     def _log_graph_artifact(self, name, data, graph_type, output_artifact):
         """Logs an artifact by uploading data to S3, creating an artifact, and associating that
         artifact with the tracker's Trial Component.
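
For reviewers, a minimal usage sketch of the new method, assuming the existing `smexperiments.tracker.Tracker` API (`Tracker.create` used as a context manager) and assuming the code runs from a SageMaker context such as a Studio notebook or a training job, as the docstring notes; the display name and title below are hypothetical:

```python
from smexperiments.tracker import Tracker

# True and predicted labels for a small multi-class example (taken from the docstring).
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

# Create a tracker and log the confusion matrix as an output artifact of its Trial Component.
# "confusion-matrix-demo" is a hypothetical display name.
with Tracker.create(display_name="confusion-matrix-demo") as my_tracker:
    my_tracker.log_confusion_matrix(y_true, y_pred, title="demo-confusion-matrix")
```

Under the hood the method converts the sklearn matrix with `tolist()` so the payload is JSON-serializable before `_log_graph_artifact` uploads it to S3 and associates the resulting `ConfusionMatrix` artifact with the Trial Component.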