@@ -65,15 +65,15 @@ def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -
6565
6666 if apc .is_array_api_obj (thresholds ):
6767 xp = apc .array_namespace (thresholds )
68- if not xp .all ((thresholds >= 0 ) & (thresholds <= 1 )): # type: ignore
68+ if not xp .all ((thresholds >= 0 ) & (thresholds <= 1 )):
6969 raise ValueError (
7070 "Expected argument `thresholds` to be an Array of floats in the [0,1] "
7171 f"range, but got { thresholds } " ,
7272 )
73- if not thresholds .ndim == 1 : # type: ignore
73+ if not thresholds .ndim == 1 :
7474 raise ValueError (
7575 "Expected argument `thresholds` to be a 1D Array, but got an Array with "
76- f"{ thresholds .ndim } dimensions" , # type: ignore
76+ f"{ thresholds .ndim } dimensions" ,
7777 )
7878
7979
@@ -284,9 +284,9 @@ def _binary_precision_recall_curve_compute(
284284 """Compute the precision and recall for all unique thresholds."""
285285 if apc .is_array_api_obj (state ) and thresholds is not None :
286286 xp = apc .array_namespace (state , thresholds )
287- tps = state [:, 1 , 1 ] # type: ignore[call-overload]
288- fps = state [:, 0 , 1 ] # type: ignore[call-overload]
289- fns = state [:, 1 , 0 ] # type: ignore[call-overload]
287+ tps = state [:, 1 , 1 ]
288+ fps = state [:, 0 , 1 ]
289+ fns = state [:, 1 , 0 ]
290290 precision = safe_divide (tps , tps + fps )
291291 recall = safe_divide (tps , tps + fns )
292292 precision = xp .concat (
@@ -322,8 +322,8 @@ def _binary_precision_recall_curve_compute(
322322 )
323323 thresholds = xp .flip (thresholds , axis = 0 )
324324 if hasattr (thresholds , "detach" ):
325- thresholds = clone (thresholds .detach ()) # type: ignore
326- return precision , recall , thresholds # type: ignore[return-value]
325+ thresholds = clone (thresholds .detach ())
326+ return precision , recall , thresholds
327327
328328
329329def binary_precision_recall_curve (
@@ -541,7 +541,7 @@ def _multiclass_precision_recall_curve_validate_arrays(
541541 f"values in `target` but found { num_unique_values } values." ,
542542 )
543543
544- return xp # type: ignore[no-any-return]
544+ return xp
545545
546546
547547def _multiclass_precision_recall_curve_format_arrays (
@@ -618,9 +618,9 @@ def _multiclass_precision_recall_curve_compute(
618618
619619 if apc .is_array_api_obj (state ) and thresholds is not None :
620620 xp = apc .array_namespace (state , thresholds )
621- tps = state [:, :, 1 , 1 ] # type: ignore[call-overload]
622- fps = state [:, :, 0 , 1 ] # type: ignore[call-overload]
623- fns = state [:, :, 1 , 0 ] # type: ignore[call-overload]
621+ tps = state [:, :, 1 , 1 ]
622+ fps = state [:, :, 0 , 1 ]
623+ fns = state [:, :, 1 , 0 ]
624624 precision = safe_divide (tps , tps + fps )
625625 recall = safe_divide (tps , tps + fns )
626626 precision = xp .concat (
@@ -989,9 +989,9 @@ def _multilabel_precision_recall_curve_compute(
989989 """Compute the precision and recall for all unique thresholds."""
990990 if apc .is_array_api_obj (state ) and thresholds is not None :
991991 xp = apc .array_namespace (state )
992- tps = state [:, :, 1 , 1 ] # type: ignore[call-overload]
993- fps = state [:, :, 0 , 1 ] # type: ignore[call-overload]
994- fns = state [:, :, 1 , 0 ] # type: ignore[call-overload]
992+ tps = state [:, :, 1 , 1 ]
993+ fps = state [:, :, 0 , 1 ]
994+ fns = state [:, :, 1 , 0 ]
995995 precision = safe_divide (tps , tps + fps )
996996 recall = safe_divide (tps , tps + fns )
997997 precision = xp .concat (
0 commit comments