@@ -85,7 +85,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, verbose: bool = False, *
8585 :type is_abstain: `boolean`
8686 :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
8787 """
88- from scipy .stats import binom_test
88+ from scipy .stats import binomtest
8989
9090 is_abstain = kwargs .get ("is_abstain" )
9191 if is_abstain is not None and not isinstance (is_abstain , bool ): # pragma: no cover
@@ -100,12 +100,15 @@ def predict(self, x: np.ndarray, batch_size: int = 128, verbose: bool = False, *
100100 # get class counts
101101 counts_pred = self ._prediction_counts (x_i , batch_size = batch_size )
102102 top = counts_pred .argsort ()[::- 1 ]
103- count1 = np .max (counts_pred )
104- count2 = counts_pred [top [1 ]]
103+ # Conersion to int
104+ count1 = int (np .max (counts_pred ))
105+ count2 = int (counts_pred [top [1 ]])
105106
106107 # predict or abstain
107108 smooth_prediction = np .zeros (counts_pred .shape )
108- if (not is_abstain ) or (binom_test (count1 , count1 + count2 , p = 0.5 ) <= self .alpha ):
109+ #Get p value from BinomTestResult object
110+ p_value = binomtest (count1 , count1 + count2 , p = 0.5 ).pvalue
111+ if (not is_abstain ) or (p_value <= self .alpha ):
109112 smooth_prediction [np .argmax (counts_pred )] = 1
110113 elif is_abstain :
111114 n_abstained += 1
0 commit comments