Skip to content

Commit 7851660

Browse files
authored
Merge pull request #286 from KumarLabJax/refactor-classifier
some refactoring in classifier.py
2 parents 0209d0a + 850524c commit 7851660

File tree

2 files changed

+126
-96
lines changed

2 files changed

+126
-96
lines changed

src/jabs/classifier/classifier.py

Lines changed: 64 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33
import re
44
import typing
55
import warnings
6-
from importlib import import_module
76
from pathlib import Path
87

98
import joblib
109
import numpy as np
1110
import pandas as pd
12-
from catboost import CatBoostClassifier
13-
from sklearn.ensemble import RandomForestClassifier
1411
from sklearn.exceptions import InconsistentVersionWarning
1512
from sklearn.metrics import (
1613
accuracy_score,
@@ -27,57 +24,33 @@
2724
from jabs.core.utils import hash_file
2825
from jabs.project import Project, TrackLabels, load_training_data
2926

30-
_VERSION = 10
31-
32-
try:
33-
_xgboost = import_module("xgboost")
34-
except ImportError:
35-
# we were unable to import the xgboost module. It's either not
36-
# installed (it should be if the user installed JABS as a package)
37-
# or it may have been unable to be imported due to a missing
38-
# libomp. Either way, we won't add it to the available choices so
39-
# we can otherwise ignore this exception
40-
_xgboost = None
41-
logging.warning(
42-
"Unable to import xgboost. XGBoost support will be unavailable. "
43-
"You may need to install xgboost and/or libomp."
44-
)
45-
46-
47-
# Classifier factory helpers and mapping
48-
def _make_random_forest(n_jobs: int, random_seed: int | None):
49-
"""Factory function to construct a RandomForest classifier."""
50-
return RandomForestClassifier(n_jobs=n_jobs, random_state=random_seed)
51-
52-
53-
def _make_catboost(n_jobs: int, random_seed: int | None):
54-
"""Factory function to construct a CatBoost classifier."""
55-
return CatBoostClassifier(
56-
thread_count=n_jobs,
57-
random_state=random_seed,
58-
verbose=False, # Suppress training output
59-
allow_writing_files=False, # Don't write intermediate files
60-
)
61-
62-
63-
def _make_xgboost(n_jobs: int, random_seed: int | None):
64-
"""Factory function to construct an XGBoost classifier."""
65-
if _xgboost is None:
66-
raise RuntimeError(
67-
"XGBoost classifier requested but 'xgboost' is not available in this environment."
68-
)
69-
return _xgboost.XGBClassifier(n_jobs=n_jobs, random_state=random_seed)
27+
from .factories import make_catboost, make_random_forest, make_xgboost
7028

29+
_VERSION = 11
7130

7231
# _CLASSIFIER_FACTORIES serves as both the single source of truth for classifiers
7332
# supported by the current JABS environment, in addition to the mapping of ClassifierTypes
7433
# to factory functions that produce instantiated classifiers for that type
7534
_CLASSIFIER_FACTORIES: dict[ClassifierType, typing.Callable[[int, int | None], typing.Any]] = {
76-
ClassifierType.RANDOM_FOREST: _make_random_forest,
77-
ClassifierType.CATBOOST: _make_catboost,
35+
ClassifierType.RANDOM_FOREST: make_random_forest,
36+
ClassifierType.CATBOOST: make_catboost,
7837
}
79-
if _xgboost is not None:
80-
_CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = _make_xgboost
38+
39+
# Attempt to register XGBoost if available. While it will be installed, because it is a
40+
# package dependency, it might not be usable on macOS due to missing libomp. In this case
41+
# xgboost will raise an ImportError, so we try importing and catch ImportError to see if it
42+
# is usable in the current environment.
43+
# Users will be warned if XGBoost support is unavailable. This can be resolved by installing
44+
# libomp using homebrew.
45+
try:
46+
import xgboost # noqa F401
47+
except ImportError:
48+
logging.warning(
49+
"Unable to import xgboost. XGBoost support will be unavailable. "
50+
"You may need to install xgboost and/or libomp."
51+
)
52+
else:
53+
_CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = make_xgboost
8154

8255

8356
class Classifier:
@@ -428,39 +401,31 @@ def train(self, data, random_seed: int | None = None):
428401
features, labels = self.downsample_balance(features, labels, random_seed)
429402

430403
classifier = self._create_classifier(random_seed=random_seed)
431-
432-
if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
433-
with warnings.catch_warnings():
434-
warnings.simplefilter("ignore", category=FutureWarning)
435-
# XGBoost and CatBoost natively support NaN as a marker for missing values and handle them
436-
# during tree construction. For these classifiers we therefore convert infinite values to NaN
437-
# and leave them as missing, instead of imputing them with 0. This differs from the
438-
# Random Forest path below, where both infinities and NaN are
439-
# replaced with 0.
440-
cleaned_features = features.replace([np.inf, -np.inf], np.nan)
441-
self._classifier = classifier.fit(cleaned_features, labels)
442-
else:
443-
# RandomForestClassifier (and most other sklearn estimators) do not natively support NaN
444-
# values, so here we replace infinite values and NaNs with 0 before fitting.
445-
cleaned_features = features.replace([np.inf, -np.inf], 0).fillna(0)
404+
cleaned_features = self._clean_features(features)
405+
with warnings.catch_warnings():
406+
warnings.simplefilter("ignore", category=FutureWarning)
446407
self._classifier = classifier.fit(cleaned_features, labels)
447408

448409
# Classifier may have been re-used from a prior training, blank the logging attributes
449410
self._classifier_file = None
450411
self._classifier_hash = None
451412
self._classifier_source = None
452413

453-
def sort_features_to_classify(self, features):
454-
"""sorts features to match the current classifier"""
455-
if self._classifier_type == ClassifierType.XGBOOST:
414+
def get_features_to_classify(self, features: pd.DataFrame) -> pd.DataFrame:
    """Return *features* restricted to the columns the trained model expects, in model order.

    Args:
        features: DataFrame containing (at least) every feature column the
            underlying trained classifier was fit on.

    Returns:
        A DataFrame view of *features* reordered/subset to the classifier's
        training-time feature names.

    Raises:
        RuntimeError: if the underlying model exposes no recognized
            feature-name attribute.
    """
    if self.classifier_type == ClassifierType.XGBOOST:
        # XGBoost exposes its training-time feature names via the booster.
        columns = self._classifier.get_booster().feature_names
    else:
        # sklearn-style estimators expose feature_names_in_; CatBoost
        # exposes feature_names_. Probe in that order.
        for attr in ("feature_names_in_", "feature_names_"):
            if hasattr(self._classifier, attr):
                columns = list(getattr(self._classifier, attr))
                break
        else:
            raise RuntimeError("Error obtaining feature names from classifier.")
    return features[columns]
464429

465430
def predict(
466431
self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
@@ -474,18 +439,10 @@ def predict(
474439
Returns:
475440
predicted class vector
476441
"""
477-
if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
478-
with warnings.catch_warnings():
479-
warnings.simplefilter("ignore", category=FutureWarning)
480-
# XGBoost and CatBoost can handle NaN, just replace infinities
481-
result = self._classifier.predict(
482-
self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
483-
)
484-
else:
485-
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
486-
result = self._classifier.predict(
487-
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
488-
)
442+
cleaned_features = self.get_features_to_classify(self._clean_features(features))
443+
with warnings.catch_warnings():
444+
warnings.simplefilter("ignore", category=FutureWarning)
445+
result = self._classifier.predict(cleaned_features)
489446

490447
# Insert -1s into class prediction when no prediction is made
491448
if frame_indexes is not None:
@@ -507,18 +464,10 @@ def predict_proba(
507464
Returns:
508465
prediction probability matrix
509466
"""
510-
if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
511-
with warnings.catch_warnings():
512-
warnings.simplefilter("ignore", category=FutureWarning)
513-
# XGBoost and CatBoost can handle NaN, just replace infinities
514-
result = self._classifier.predict_proba(
515-
self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
516-
)
517-
else:
518-
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
519-
result = self._classifier.predict_proba(
520-
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
521-
)
467+
cleaned_features = self.get_features_to_classify(self._clean_features(features))
468+
with warnings.catch_warnings():
469+
warnings.simplefilter("ignore", category=FutureWarning)
470+
result = self._classifier.predict_proba(cleaned_features)
522471

523472
# Insert 0 probabilities when no prediction is made
524473
if frame_indexes is not None:
@@ -564,7 +513,7 @@ def load(self, path: Path):
564513

565514
if c.version != _VERSION:
566515
raise ValueError(
567-
f"Error deserializing classifier. File version {c.version}, expected {_VERSION}."
516+
f"Unable to deserialize pickled classifier. File version {c.version}, expected {_VERSION}."
568517
)
569518

570519
# make sure the value passed for the classifier parameter is valid
@@ -741,3 +690,22 @@ def label_threshold_met(
741690
def _supported_classifier_choices() -> set[ClassifierType]:
    """Return the classifier types usable in the current JABS environment.

    The factory registry is the single source of truth: a type is supported
    exactly when a factory was registered for it at import time.
    """
    return set(_CLASSIFIER_FACTORIES)
693+
694+
def _clean_features(self, features: pd.DataFrame) -> pd.DataFrame:
    """Sanitize feature values prior to fitting or prediction.

    Args:
        features: DataFrame of feature data to clean.

    Returns:
        A new DataFrame with infinite (and, where required, missing) values
        replaced according to what the configured classifier type tolerates.
    """
    boosted_types = (ClassifierType.XGBOOST, ClassifierType.CATBOOST)
    if self._classifier_type in boosted_types:
        # Gradient-boosted trees treat NaN as "missing" natively, so only
        # infinities need to be mapped to NaN.
        return features.replace([np.inf, -np.inf], np.nan)
    # Random forest (sklearn) rejects both NaN and inf, so map both to 0.
    return features.replace([np.inf, -np.inf], 0).fillna(0)

src/jabs/classifier/factories.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Factory functions for various classifiers."""
2+
3+
from catboost import CatBoostClassifier
4+
from sklearn.base import ClassifierMixin
5+
from sklearn.ensemble import RandomForestClassifier
6+
7+
8+
def make_random_forest(n_jobs: int, random_seed: int | None) -> RandomForestClassifier:
    """Build an unfitted RandomForest classifier for JABS.

    Args:
        n_jobs (int): Number of parallel jobs to use during fitting.
        random_seed (int | None): Seed controlling reproducibility, or None.

    Returns:
        RandomForestClassifier: A freshly constructed, unfitted estimator.
    """
    return RandomForestClassifier(
        n_jobs=n_jobs,
        random_state=random_seed,
    )
19+
20+
21+
def make_catboost(n_jobs: int, random_seed: int | None) -> CatBoostClassifier:
    """Build an unfitted CatBoost classifier for JABS.

    Args:
        n_jobs (int): Number of parallel jobs to use during fitting.
        random_seed (int | None): Seed controlling reproducibility, or None.

    Returns:
        CatBoostClassifier: A freshly constructed, unfitted estimator.
    """
    # verbose=False suppresses per-iteration training output;
    # allow_writing_files=False keeps CatBoost from dropping temp files on disk.
    options = {
        "thread_count": n_jobs,
        "random_state": random_seed,
        "verbose": False,
        "allow_writing_files": False,
    }
    return CatBoostClassifier(**options)
37+
38+
39+
def make_xgboost(n_jobs: int, random_seed: int | None) -> ClassifierMixin:
    """Build an unfitted XGBoost classifier for JABS.

    The import is deferred to call time because xgboost may fail to load in
    some environments (e.g. macOS without libomp installed), and we want that
    failure surfaced only when an XGBoost classifier is actually requested.

    Args:
        n_jobs (int): Number of parallel jobs.
        random_seed (int | None): Random seed for reproducibility.

    Returns:
        An instance of XGBClassifier. Note: type hint is ClassifierMixin to avoid
        direct dependency on xgboost in type hints.

    Raises:
        RuntimeError: If XGBoost is not available.
    """
    try:
        import xgboost
    except ImportError as e:
        message = "XGBoost classifier requested but 'xgboost' is not available in this environment."
        raise RuntimeError(message) from e

    return xgboost.XGBClassifier(n_jobs=n_jobs, random_state=random_seed)

0 commit comments

Comments
 (0)