|
1 | 1 | """Utility methods for CheckMates objectives."""
|
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | + |
2 | 6 | from checkmates import objectives
|
3 | 7 | from checkmates.exceptions import ObjectiveCreationError, ObjectiveNotFoundError
|
4 | 8 | from checkmates.objectives.objective_base import ObjectiveBase
|
5 |
| -from checkmates.problem_types import handle_problem_types |
| 9 | +from checkmates.problem_types import ProblemTypes, handle_problem_types |
6 | 10 | from checkmates.utils.gen_utils import _get_subclasses
|
| 11 | +from checkmates.utils.logger import get_logger |
| 12 | + |
| 13 | +logger = get_logger(__file__) |
7 | 14 |
|
8 | 15 |
|
9 | 16 | def get_non_core_objectives():
|
@@ -90,6 +97,35 @@ def get_objective(objective, return_instance=False, **kwargs):
|
90 | 97 | return objective_class
|
91 | 98 |
|
92 | 99 |
|
| 100 | +def get_problem_type( |
| 101 | + input_problem_type: Optional[str], |
| 102 | + target_data: pd.Series, |
| 103 | +) -> ProblemTypes: |
| 104 | + """Helper function to determine if classification problem is binary or multiclass dependent on target variable values.""" |
| 105 | + if not input_problem_type: |
| 106 | + raise ValueError("problem type is required") |
| 107 | + if input_problem_type.lower() == "classification": |
| 108 | + values: pd.Series = target_data.value_counts() |
| 109 | + if values.size == 2: |
| 110 | + return ProblemTypes.BINARY |
| 111 | + elif values.size > 2: |
| 112 | + return ProblemTypes.MULTICLASS |
| 113 | + else: |
| 114 | + message: str = "The target field contains less than two unique values. It cannot be used for modeling." |
| 115 | + logger.error(message, exc_info=True) |
| 116 | + raise ValueError(message) |
| 117 | + |
| 118 | + if input_problem_type.lower() == "regression": |
| 119 | + return ProblemTypes.REGRESSION |
| 120 | + |
| 121 | + if input_problem_type.lower() == "time series regression": |
| 122 | + return ProblemTypes.TIME_SERIES_REGRESSION |
| 123 | + |
| 124 | + message = f"Unexpected problem type provided in configuration: {input_problem_type}" |
| 125 | + logger.error(message, exc_info=True) |
| 126 | + raise ValueError(message) |
| 127 | + |
| 128 | + |
93 | 129 | def get_default_primary_search_objective(problem_type):
|
94 | 130 | """Get the default primary search objective for a problem type.
|
95 | 131 |
|
|
0 commit comments