 import re
 import typing
 import warnings
-from importlib import import_module
 from pathlib import Path

 import joblib
 import numpy as np
 import pandas as pd
-from catboost import CatBoostClassifier
-from sklearn.ensemble import RandomForestClassifier
 from sklearn.exceptions import InconsistentVersionWarning
 from sklearn.metrics import (
     accuracy_score,
 from jabs.core.utils import hash_file
 from jabs.project import Project, TrackLabels, load_training_data

-_VERSION = 10
-
-try:
-    _xgboost = import_module("xgboost")
-except ImportError:
-    # we were unable to import the xgboost module. It's either not
-    # installed (it should be if the user installed JABS as a package)
-    # or it may have been unable to be imported due to a missing
-    # libomp. Either way, we won't add it to the available choices so
-    # we can otherwise ignore this exception
-    _xgboost = None
-    logging.warning(
-        "Unable to import xgboost. XGBoost support will be unavailable. "
-        "You may need to install xgboost and/or libomp."
-    )
-
-
-# Classifier factory helpers and mapping
-def _make_random_forest(n_jobs: int, random_seed: int | None):
-    """Factory function to construct a RandomForest classifier."""
-    return RandomForestClassifier(n_jobs=n_jobs, random_state=random_seed)
-
-
-def _make_catboost(n_jobs: int, random_seed: int | None):
-    """Factory function to construct a CatBoost classifier."""
-    return CatBoostClassifier(
-        thread_count=n_jobs,
-        random_state=random_seed,
-        verbose=False,  # Suppress training output
-        allow_writing_files=False,  # Don't write intermediate files
-    )
-
-
-def _make_xgboost(n_jobs: int, random_seed: int | None):
-    """Factory function to construct an XGBoost classifier."""
-    if _xgboost is None:
-        raise RuntimeError(
-            "XGBoost classifier requested but 'xgboost' is not available in this environment."
-        )
-    return _xgboost.XGBClassifier(n_jobs=n_jobs, random_state=random_seed)
+from .factories import make_catboost, make_random_forest, make_xgboost

+_VERSION = 11

 # _CLASSIFIER_FACTORIES serves as both the single source of truth for the
 # classifiers supported by the current JABS environment and the mapping from
 # ClassifierType to the factory function that produces an instance of that type
 _CLASSIFIER_FACTORIES: dict[ClassifierType, typing.Callable[[int, int | None], typing.Any]] = {
-    ClassifierType.RANDOM_FOREST: _make_random_forest,
-    ClassifierType.CATBOOST: _make_catboost,
+    ClassifierType.RANDOM_FOREST: make_random_forest,
+    ClassifierType.CATBOOST: make_catboost,
 }
-if _xgboost is not None:
-    _CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = _make_xgboost
+
+# Attempt to register XGBoost if available. Although it is always installed as a
+# package dependency, it may not be importable on macOS when libomp is missing,
+# in which case xgboost raises an ImportError. We try the import here to
+# determine whether it is usable in the current environment.
+# Users are warned when XGBoost support is unavailable; on macOS this can be
+# resolved by installing libomp with Homebrew.
+try:
+    import xgboost  # noqa: F401
+except ImportError:
+    logging.warning(
+        "Unable to import xgboost. XGBoost support will be unavailable. "
+        "You may need to install xgboost and/or libomp."
+    )
+else:
+    _CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = make_xgboost


 class Classifier:
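For context, the registration above implies a lookup-and-call pattern when a classifier is instantiated. A minimal sketch under that assumption (`_create_classifier_sketch` is a hypothetical name; the real method is `Classifier._create_classifier`, shown in the next hunk):

def _create_classifier_sketch(classifier_type, n_jobs=1, random_seed=None):
    # A missing key means the type is unsupported in this environment
    # (e.g. xgboost failed to import, so it was never registered).
    factory = _CLASSIFIER_FACTORIES.get(classifier_type)
    if factory is None:
        raise ValueError(f"Unsupported classifier type: {classifier_type}")
    # Every registered factory shares the (n_jobs, random_seed) signature.
    return factory(n_jobs, random_seed)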
@@ -428,39 +401,31 @@ def train(self, data, random_seed: int | None = None):
         features, labels = self.downsample_balance(features, labels, random_seed)

         classifier = self._create_classifier(random_seed=random_seed)
-
-        if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore", category=FutureWarning)
-                # XGBoost and CatBoost natively support NaN as a marker for missing values
-                # and handle them during tree construction. For these classifiers we therefore
-                # convert infinite values to NaN and leave them as missing, instead of imputing
-                # them with 0. This differs from the Random Forest path below, where both
-                # infinities and NaN are replaced with 0.
-                cleaned_features = features.replace([np.inf, -np.inf], np.nan)
-                self._classifier = classifier.fit(cleaned_features, labels)
-        else:
-            # RandomForestClassifier (and most other sklearn estimators) do not natively
-            # support NaN values, so here we replace infinite values and NaNs with 0
-            # before fitting.
-            cleaned_features = features.replace([np.inf, -np.inf], 0).fillna(0)
+        cleaned_features = self._clean_features(features)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
             self._classifier = classifier.fit(cleaned_features, labels)

         # Classifier may have been re-used from a prior training, blank the logging attributes
         self._classifier_file = None
         self._classifier_hash = None
         self._classifier_source = None

-    def sort_features_to_classify(self, features):
-        """sorts features to match the current classifier"""
-        if self._classifier_type == ClassifierType.XGBOOST:
+    def get_features_to_classify(self, features: pd.DataFrame) -> pd.DataFrame:
+        """gets features for classification, handling classifier-specific quirks."""
+        if self.classifier_type == ClassifierType.XGBOOST:
+            # XGBoost feature names are obtained from the booster
             classifier_columns = self._classifier.get_booster().feature_names
-        elif self._classifier_type == ClassifierType.CATBOOST:
-            classifier_columns = self._classifier.feature_names_
         else:
-            # sklearn places feature names in feature_names_in_
-            classifier_columns = self._classifier.feature_names_in_
-        features_sorted = features[classifier_columns]
-        return features_sorted
+            # For other classifiers, use the feature names from the underlying model
+            if hasattr(self._classifier, "feature_names_in_"):
+                classifier_columns = list(self._classifier.feature_names_in_)
+            elif hasattr(self._classifier, "feature_names_"):
+                classifier_columns = list(self._classifier.feature_names_)
+            else:
+                raise RuntimeError("Error obtaining feature names from classifier.")
+
+        return features[classifier_columns]

     def predict(
         self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
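The column alignment performed by `get_features_to_classify` comes down to pandas label-based indexing, which both subsets and reorders in one step. A small, self-contained illustration with made-up feature names:

import pandas as pd

# Hypothetical feature names captured at training time.
trained_columns = ["angle", "speed", "distance"]

# Incoming features may carry extra columns in a different order.
features = pd.DataFrame(
    {"distance": [1.0], "extra": [9.9], "speed": [2.0], "angle": [0.5]}
)

# Indexing with the training column list drops "extra" and restores
# the training order, which is exactly what the method returns.
aligned = features[trained_columns]
print(list(aligned.columns))  # ['angle', 'speed', 'distance']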
@@ -474,18 +439,10 @@ def predict(
         Returns:
             predicted class vector
         """
-        if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore", category=FutureWarning)
-                # XGBoost and CatBoost can handle NaN, just replace infinities
-                result = self._classifier.predict(
-                    self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
-                )
-        else:
-            # Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
-            result = self._classifier.predict(
-                self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
-            )
+        cleaned_features = self.get_features_to_classify(self._clean_features(features))
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
+            result = self._classifier.predict(cleaned_features)

         # Insert -1s into class prediction when no prediction is made
         if frame_indexes is not None:
@@ -507,18 +464,10 @@ def predict_proba(
         Returns:
             prediction probability matrix
         """
-        if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore", category=FutureWarning)
-                # XGBoost and CatBoost can handle NaN, just replace infinities
-                result = self._classifier.predict_proba(
-                    self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
-                )
-        else:
-            # Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
-            result = self._classifier.predict_proba(
-                self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
-            )
+        cleaned_features = self.get_features_to_classify(self._clean_features(features))
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=FutureWarning)
+            result = self._classifier.predict_proba(cleaned_features)

         # Insert 0 probabilities when no prediction is made
         if frame_indexes is not None:
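After this change, `predict` and `predict_proba` share one clean-then-align pipeline. A condensed sketch of that flow (the helper name `predict_flow_sketch` is hypothetical, and it reaches into private attributes purely for illustration):

import numpy as np
import pandas as pd

def predict_flow_sketch(clf, features: pd.DataFrame) -> np.ndarray:
    # 1. Replace infinities (and, on the sklearn path, NaNs) so the
    #    estimator never sees values it cannot handle.
    cleaned = clf._clean_features(features)
    # 2. Reorder/subset columns to match the order used at training time.
    aligned = clf.get_features_to_classify(cleaned)
    # 3. Delegate to the wrapped estimator; predict_proba is identical
    #    apart from the final call.
    return clf._classifier.predict(aligned)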
@@ -564,7 +513,7 @@ def load(self, path: Path):

         if c.version != _VERSION:
             raise ValueError(
-                f"Error deserializing classifier. File version {c.version}, expected {_VERSION}."
+                f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}."
             )

         # make sure the value passed for the classifier parameter is valid
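For reference, bumping `_VERSION` from 10 to 11 in this change makes the guard above reject classifiers pickled by earlier releases, forcing a retrain rather than a load of an incompatible object. A condensed sketch of the load path, assuming (as the diff shows) the pickled object carries a `version` attribute:

import joblib

def load_sketch(path):
    c = joblib.load(path)
    # Reject pickles written under a different serialization format.
    if c.version != _VERSION:
        raise ValueError(
            f"Unable to deserialize pickled classifier. "
            f"File version {c.version}, expected {_VERSION}."
        )
    return c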
@@ -741,3 +690,22 @@ def label_threshold_met(
     def _supported_classifier_choices() -> set[ClassifierType]:
         """Determine the set of supported classifier types in the current JABS environment."""
         return set(_CLASSIFIER_FACTORIES.keys())
+
+    def _clean_features(self, features: pd.DataFrame) -> pd.DataFrame:
+        """Clean features for training or prediction, handling missing and infinite values.
+
+        Args:
+            features: DataFrame of feature data to clean.
+
+        Returns:
+            Cleaned DataFrame with missing and infinite values handled.
+        """
+        if self._classifier_type in (
+            ClassifierType.XGBOOST,
+            ClassifierType.CATBOOST,
+        ):
+            # these classifiers can handle NaN, just replace infinities
+            return features.replace([np.inf, -np.inf], np.nan)
+        else:
+            # Random forests can't handle NaNs & infs, so fill them with 0s
+            return features.replace([np.inf, -np.inf], 0).fillna(0)
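A quick illustration of the two `_clean_features` behaviors on a toy DataFrame (resulting values shown in comments):

import numpy as np
import pandas as pd

features = pd.DataFrame({"f": [1.0, np.inf, -np.inf, np.nan]})

# XGBoost/CatBoost path: infinities become NaN, and NaNs are left for the
# booster to treat as missing values during tree construction.
boosted = features.replace([np.inf, -np.inf], np.nan)
# -> [1.0, NaN, NaN, NaN]

# Random Forest path: sklearn estimators reject NaN/inf, so both are
# imputed with 0 before fitting or predicting.
forest = features.replace([np.inf, -np.inf], 0).fillna(0)
# -> [1.0, 0.0, 0.0, 0.0]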