1717import numpy as np
1818import pandas as pd
1919import pytest
20- import sklearn .metrics as sklearn_metrics # type: ignore
2120
2221import bigframes
2322from bigframes .ml import metrics
@@ -66,6 +65,7 @@ def test_r2_score_force_finite(session):
6665
6766
6867def test_r2_score_ok_fit_matches_sklearn (session ):
68+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
6969 pd_df = pd .DataFrame ({"y_true" : [1 , 2 , 3 , 4 , 5 ], "y_pred" : [2 , 3 , 4 , 3 , 6 ]})
7070
7171 df = session .read_pandas (pd_df )
@@ -113,6 +113,7 @@ def test_accuracy_score_not_normailze(session):
113113
114114
115115def test_accuracy_score_fit_matches_sklearn (session ):
116+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
116117 pd_df = pd .DataFrame ({"y_true" : [1 , 2 , 3 , 4 , 5 ], "y_pred" : [2 , 3 , 4 , 3 , 6 ]})
117118
118119 df = session .read_pandas (pd_df )
@@ -203,6 +204,7 @@ def test_roc_curve_binary_classification_prediction_returns_expected(session):
203204
204205
205206def test_roc_curve_binary_classification_prediction_matches_sklearn (session ):
207+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
206208 pd_df = pd .DataFrame (
207209 {
208210 "y_true" : [0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 , 1 , 1 ],
@@ -294,6 +296,7 @@ def test_roc_curve_binary_classification_decision_returns_expected(session):
294296
295297
296298def test_roc_curve_binary_classification_decision_matches_sklearn (session ):
299+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
297300 # Instead of operating on probabilities, assume a 70% decision threshold
298301 # has been applied, and operate on the final output
299302 y_score = [0.1 , 0.4 , 0.35 , 0.8 , 0.65 , 0.9 , 0.5 , 0.3 , 0.6 , 0.45 ]
@@ -420,6 +423,7 @@ def test_roc_auc_score_returns_expected(session):
420423
421424
422425def test_roc_auc_score_returns_matches_sklearn (session ):
426+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
423427 pd_df = pd .DataFrame (
424428 {
425429 "y_true" : [0 , 0 , 1 , 1 , 0 , 1 , 0 , 1 , 1 , 1 ],
@@ -525,6 +529,7 @@ def test_confusion_matrix_column_index(session):
525529
526530
527531def test_confusion_matrix_matches_sklearn (session ):
532+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
528533 pd_df = pd .DataFrame (
529534 {
530535 "y_true" : [2 , 3 , 3 , 3 , 4 , 1 ],
@@ -543,6 +548,7 @@ def test_confusion_matrix_matches_sklearn(session):
543548
544549
545550def test_confusion_matrix_str_matches_sklearn (session ):
551+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
546552 pd_df = pd .DataFrame (
547553 {
548554 "y_true" : ["cat" , "ant" , "cat" , "cat" , "ant" , "bird" ],
@@ -603,6 +609,7 @@ def test_recall_score(session):
603609
604610
605611def test_recall_score_matches_sklearn (session ):
612+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
606613 pd_df = pd .DataFrame (
607614 {
608615 "y_true" : [2 , 0 , 2 , 2 , 0 , 1 ],
@@ -620,6 +627,7 @@ def test_recall_score_matches_sklearn(session):
620627
621628
622629def test_recall_score_str_matches_sklearn (session ):
630+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
623631 pd_df = pd .DataFrame (
624632 {
625633 "y_true" : ["cat" , "ant" , "cat" , "cat" , "ant" , "bird" ],
@@ -673,6 +681,7 @@ def test_precision_score(session):
673681
674682
675683def test_precision_score_matches_sklearn (session ):
684+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
676685 pd_df = pd .DataFrame (
677686 {
678687 "y_true" : [2 , 0 , 2 , 2 , 0 , 1 ],
@@ -695,6 +704,7 @@ def test_precision_score_matches_sklearn(session):
695704
696705
697706def test_precision_score_str_matches_sklearn (session ):
707+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
698708 pd_df = pd .DataFrame (
699709 {
700710 "y_true" : ["cat" , "ant" , "cat" , "cat" , "ant" , "bird" ],
@@ -752,6 +762,7 @@ def test_f1_score(session):
752762
753763
754764def test_f1_score_matches_sklearn (session ):
765+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
755766 pd_df = pd .DataFrame (
756767 {
757768 "y_true" : [2 , 0 , 2 , 2 , 0 , 1 ],
@@ -769,6 +780,7 @@ def test_f1_score_matches_sklearn(session):
769780
770781
771782def test_f1_score_str_matches_sklearn (session ):
783+ sklearn_metrics = pytest .importorskip ("sklearn.metrics" )
772784 pd_df = pd .DataFrame (
773785 {
774786 "y_true" : ["cat" , "ant" , "cat" , "cat" , "ant" , "bird" ],
0 commit comments