@@ -490,35 +490,71 @@ def sort_features_to_classify(self, features):
490490 features_sorted = features [classifier_columns ]
491491 return features_sorted
492492
def predict(
    self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
) -> np.ndarray:
    """predict classes for a given set of features

    The classifier is always run on all of ``features``; ``frame_indexes``
    only controls which of the resulting predictions are kept.

    Args:
        features: DataFrame of feature data to classify
        frame_indexes: positions in the prediction vector whose predictions
            are kept; every other position is set to -1. With the default
            (None) the predictions are returned unchanged.

    Returns:
        predicted class vector as returned by the classifier, with -1 at
        the positions not selected by ``frame_indexes`` (when given)
    """
    if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            # XGBoost and CatBoost can handle NaN, just replace infinities
            result = self._classifier.predict(
                self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
            )
    else:
        # Other classifier types can't handle NAs & infs, so fill them with 0s
        result = self._classifier.predict(
            self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
        )

    # Insert -1s into class prediction when no prediction is made
    if frame_indexes is not None:
        # Pick a dtype wide enough for both the classifier's labels and the
        # -1 sentinel, so labels are never silently truncated by the mask.
        masked_dtype = np.result_type(result.dtype, np.int8)
        result_adjusted = np.full(result.shape, -1, dtype=masked_dtype)
        result_adjusted[frame_indexes] = result[frame_indexes]
        result = result_adjusted

    return result
507525
def predict_proba(
    self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
) -> np.ndarray:
    """predict probabilities for a given set of features.

    The classifier is always run on all of ``features``; ``frame_indexes``
    only controls which of the resulting probabilities are kept.

    Args:
        features: DataFrame of feature data to classify
        frame_indexes: positions along the first axis of the probability
            matrix whose values are kept; every other position is set to 0.
            With the default (None) the probabilities are returned unchanged.

    Returns:
        prediction probability matrix as returned by the classifier, with 0
        at the positions not selected by ``frame_indexes`` (when given)
    """
    if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            # XGBoost and CatBoost can handle NaN, just replace infinities
            result = self._classifier.predict_proba(
                self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
            )
    else:
        # Other classifier types can't handle NAs & infs, so fill them with 0s
        result = self._classifier.predict_proba(
            self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
        )

    # Insert 0 probabilities when no prediction is made
    if frame_indexes is not None:
        # Keep the classifier's own dtype so masking never reduces precision
        # and the return dtype does not depend on whether a mask was given.
        result_adjusted = np.zeros(result.shape, dtype=result.dtype)
        result_adjusted[frame_indexes] = result[frame_indexes]
        result = result_adjusted

    return result
522558
523559 def save (self , path : Path ):
524560 """save the classifier to a file
0 commit comments