Skip to content

Commit ac1b732

Browse files
nikidukiSamoedvoorhsgithub-actions[bot]
authored
Added CatBoostScorer (#209)
* Added catboost in dependencies * Raw implementation of CatboostScorer * Update init.py files * Minor fix * Added CatBoostScorer * Added tests for CatBoostScorer * Fix init * Fix mypy * Minor fix * Fix loss function * fix multilabel prediction * refactor catboost scorer * minor fix * fix test to match * fix/wandb-final-metrics-skipped (#212) * fix * sklearn scorer proper name * fix typing errors * try to fix pydantic errors * Remove artifacts from final metrics (#216) * Update wandb.py * Update wandb.py * Update wandb.py * Update _optimization_info.py * remove print * fix few shot split (#219) * fix few shot split * lint * remove egor (#221) * Feat/bert early stopping (#223) * change how `clear_cache` is called * first version of early stopping * change mypy version * train_test_split bug fix * add `compute_metrics` and `EarlyStoppingCallback` * bug fix * fix mypy * try to fix `"eval_f1" not found` error * forgot to upd `from_context` * try to fix mypy * ty to fix "not found f1" error * refactor a little bit * disable early stopping for lora * fix typing errors * update contributing and makefile * minor change * use our metrics * add docstrings * set 3.10 for mypy * upd contributing.md * try to fix bug * try to fix typing issue * try to fix * add early stopping to ptuning * Check if metric can handle dataset type (#224) * add test for configuration * lint * satisfy mypy * add prompt logging (#220) * add prompt logging * Update optimizer_config.schema.json * fix --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * fix caching (#225) * Fix default prompt (#226) * fix default prompt * allow to use default prompt with override * some refactors * add `use_embedding_features` * fix tests * fix lint * fix strenum * Added catboost in dependencies * Raw implementation of CatboostScorer * Update init.py files * Minor fix * Added CatBoostScorer * Added tests for CatBoostScorer * Fix init * Fix mypy * Minor fix * Fix loss function * 
fix multilabel prediction * refactor catboost scorer * minor fix * fix test to match * fix loading * fix lint * fix typing * fix dumper * add early stopping * fix errors * codestyle * patch catboost with early stopping and catboost * try to fix * fix embed type --------- Co-authored-by: Roman Solomatin <[email protected]> Co-authored-by: Алексеев Илья <[email protected]> Co-authored-by: Roman Solomatin <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: voorhs <[email protected]>
1 parent 439fde2 commit ac1b732

File tree

9 files changed

+422
-0
lines changed

9 files changed

+422
-0
lines changed

autointent/_dump_tools.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import joblib
88
import numpy as np
99
import numpy.typing as npt
10+
from catboost import CatBoostClassifier
1011
from peft import PeftModel
1112
from pydantic import BaseModel
1213
from sklearn.base import BaseEstimator
@@ -47,6 +48,7 @@ class Dumper:
4748
hf_tokenizers = "hf_tokenizers"
4849
torch_models = "torch_models"
4950
ptuning_models = "ptuning_models"
51+
catboost_models = "catboost_models"
5052

5153
@staticmethod
5254
def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
@@ -67,6 +69,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
6769
path / Dumper.hf_tokenizers,
6870
path / Dumper.torch_models,
6971
path / Dumper.ptuning_models,
72+
path / Dumper.catboost_models,
7073
]
7174
for subdir in subdirectories:
7275
subdir.mkdir(parents=True, exist_ok=exists_ok)
@@ -165,6 +168,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
165168
except Exception as e:
166169
msg = f"Error dumping HF tokenizer {key}: {e}"
167170
logger.exception(msg)
171+
elif isinstance(val, CatBoostClassifier):
172+
val.save_model(str(path / Dumper.catboost_models / key), format="cbm")
168173
else:
169174
msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
170175
logger.error(msg)
@@ -192,6 +197,7 @@ def load( # noqa: C901, PLR0912, PLR0915
192197
pydantic_models: dict[str, Any] = {}
193198
hf_models: dict[str, Any] = {}
194199
hf_tokenizers: dict[str, Any] = {}
200+
catboost_models: dict[str, Any] = {}
195201
torch_models: dict[str, Any] = {}
196202

197203
for child in path.iterdir():
@@ -267,6 +273,15 @@ def load( # noqa: C901, PLR0912, PLR0915
267273
except Exception as e: # noqa: PERF203
268274
msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
269275
logger.exception(msg)
276+
elif child.name == Dumper.catboost_models:
277+
for model_file in child.iterdir():
278+
try:
279+
model = CatBoostClassifier()
280+
model.load_model(str(path / Dumper.catboost_models / model_file))
281+
catboost_models[model_file.name] = model
282+
except Exception as e: # noqa: PERF203
283+
msg = f"Error loading CatBoost model: {e}"
284+
logger.exception(msg)
270285
elif child.name == Dumper.torch_models:
271286
try:
272287
for model_dir in child.iterdir():
@@ -294,5 +309,6 @@ def load( # noqa: C901, PLR0912, PLR0915
294309
| pydantic_models
295310
| hf_models
296311
| hf_tokenizers
312+
| catboost_models
297313
| torch_models
298314
)

autointent/modules/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .scoring import (
1616
BERTLoRAScorer,
1717
BertScorer,
18+
CatBoostScorer,
1819
CNNScorer,
1920
DescriptionScorer,
2021
DNNCScorer,
@@ -41,6 +42,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4142

4243
SCORING_MODULES: dict[str, type[BaseScorer]] = _create_modules_dict(
4344
[
45+
CatBoostScorer,
4446
DNNCScorer,
4547
KNNScorer,
4648
LinearScorer,
@@ -68,6 +70,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
6870
"BaseModule",
6971
"BaseRegex",
7072
"BaseScorer",
73+
"CatBoostScorer",
7174
"DNNCScorer",
7275
"DescriptionScorer",
7376
"JinoosDecision",

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._bert import BertScorer
2+
from ._catboost import CatBoostScorer
23
from ._cnn import CNNScorer
34
from ._description import DescriptionScorer
45
from ._dnnc import DNNCScorer
@@ -13,6 +14,7 @@
1314
"BERTLoRAScorer",
1415
"BertScorer",
1516
"CNNScorer",
17+
"CatBoostScorer",
1618
"DNNCScorer",
1719
"DescriptionScorer",
1820
"KNNScorer",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .catboost_scorer import CatBoostScorer
2+
3+
__all__ = ["CatBoostScorer"]
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""CatBoostScorer class for CatBoost-based classification with switchable encoding."""
2+
3+
import logging
4+
from enum import Enum
5+
from typing import Any, cast
6+
7+
import numpy as np
8+
import numpy.typing as npt
9+
import pandas as pd
10+
from catboost import CatBoostClassifier
11+
12+
from autointent import Context, Embedder
13+
from autointent.configs import EmbedderConfig, TaskTypeEnum
14+
from autointent.custom_types import FloatFromZeroToOne, ListOfLabels
15+
from autointent.modules.base import BaseScorer
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class FeaturesType(str, Enum):
21+
"""Type of features used in CatBoostScorer."""
22+
23+
TEXT = "text"
24+
EMBEDDING = "embedding"
25+
BOTH = "both"
26+
27+
28+
class CatBoostScorer(BaseScorer):
    """CatBoost scorer using either external embeddings or CatBoost's own BoW encoding.

    Args:
        embedder_config: Config of the embedding model (EmbedderConfig, str, or dict).
            The embedder is only instantiated when ``features_type`` is "embedding"
            or "both". In :meth:`from_context` a ``None`` value is replaced by the
            embedder resolved from the context.

        features_type: Type of features used in CatBoost. Can be one of:
            - "text": Use only text features (CatBoost's BoW encoding).
            - "embedding": Use only embedding features.
            - "both": Use both text and embedding features.

        use_embedding_features: If True, the model uses CatBoost `embedding_features` otherwise
            each number will be in separate column.

        loss_function: CatBoost loss function. If None, an appropriate loss is
            chosen automatically from the task type.

        verbose: If True, CatBoost prints training progress.

        val_fraction: fraction of training data used for early stopping. Set to None to disable early stopping.
            Note: early stopping is not supported with multilabel classification.

        early_stopping_rounds: number of iterations without metric increasing waiting for early stopping.
            Ignored when ``val_fraction`` is ``None``.

        **catboost_kwargs: Any additional keyword arguments forwarded to
            :class:`catboost.CatBoostClassifier`. Please refer to
            `catboost's documentation <https://catboost.ai/docs/en/concepts/python-reference_catboostclassifier>`_

    Example:
    -------

    .. testcode::

        from autointent.modules import CatBoostScorer

        scorer = CatBoostScorer(
            iterations=50,
            learning_rate=0.05,
            depth=6,
            l2_leaf_reg=3,
            eval_metric="Accuracy",
            random_seed=42,
            verbose=False,
            features_type="embedding",  # or "text" or "both"
        )
        utterances = ["hello", "goodbye", "allo", "sayonara"]
        labels = [0, 1, 0, 1]
        scorer.fit(utterances, labels)
        test_utterances = ["hi", "bye"]
        probabilities = scorer.predict(test_utterances)
        print(probabilities)

    .. testoutput::

        [[0.41493207 0.58506793]
         [0.55036046 0.44963954]]

    """

    name = "catboost"
    supports_multiclass = True
    supports_multilabel = True

    _model: CatBoostClassifier

    # Feature modes that require computing embeddings with an Embedder.
    encoder_features_types = (FeaturesType.EMBEDDING, FeaturesType.BOTH)

    def __init__(
        self,
        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
        features_type: FeaturesType = FeaturesType.BOTH,
        use_embedding_features: bool = True,
        loss_function: str | None = None,
        verbose: bool = False,
        val_fraction: float | None = 0.2,
        early_stopping_rounds: int = 100,
        **catboost_kwargs: dict[str, Any],
    ) -> None:
        """Store hyperparameters; no model or embedder is created until :meth:`fit`."""
        self.val_fraction = val_fraction
        self.early_stopping_rounds = early_stopping_rounds
        self.features_type = features_type
        self.use_embedding_features = use_embedding_features
        if features_type == FeaturesType.TEXT and use_embedding_features:
            msg = "Only catbooost text features will be used, `use_embedding_features` is ignored."
            logger.warning(msg)

        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
        self.loss_function = loss_function
        self.verbose = verbose
        self.catboost_kwargs = catboost_kwargs or {}

    @classmethod
    def from_context(
        cls,
        context: Context,
        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
        features_type: FeaturesType = FeaturesType.BOTH,
        use_embedding_features: bool = True,
        loss_function: str | None = None,
        verbose: bool = False,
        val_fraction: FloatFromZeroToOne | None = 0.2,
        early_stopping_rounds: int = 100,
        **catboost_kwargs: dict[str, Any],
    ) -> "CatBoostScorer":
        """Build a scorer, falling back to the context's embedder when none is given.

        Args:
            context: Optimization context; only ``resolve_embedder()`` is used here.
            embedder_config: Embedder config; if None, ``context.resolve_embedder()`` is used.
            features_type: See the class docstring.
            use_embedding_features: See the class docstring.
            loss_function: See the class docstring.
            verbose: See the class docstring.
            val_fraction: See the class docstring.
            early_stopping_rounds: See the class docstring.
            **catboost_kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new, unfitted ``CatBoostScorer``.
        """
        if embedder_config is None:
            embedder_config = context.resolve_embedder()
        return cls(
            embedder_config=embedder_config,
            loss_function=loss_function,
            verbose=verbose,
            features_type=features_type,
            use_embedding_features=use_embedding_features,
            val_fraction=val_fraction,
            early_stopping_rounds=early_stopping_rounds,
            **catboost_kwargs,
        )

    def get_implicit_initialization_params(self) -> dict[str, Any]:
        """Return the dumped embedder config if embeddings are used, otherwise ``None`` for it."""
        return {
            "embedder_config": self.embedder_config.model_dump()
            if self.features_type in self.encoder_features_types
            else None,
        }

    def _prepare_data_for_fit(
        self,
        utterances: list[str],
    ) -> pd.DataFrame:
        """Convert utterances into the feature frame passed to CatBoost.

        Used by both :meth:`fit` and :meth:`predict`, so the column layout is the
        same at training and inference time.

        Args:
            utterances: Raw texts to featurize.

        Returns:
            A DataFrame with, depending on the configuration: a single "embedding"
            column holding one vector per row, or one numeric column per embedding
            dimension; plus a "text" column when ``features_type`` is "both"; or only
            a "text" column when ``features_type`` is "text".
        """
        if self.features_type in self.encoder_features_types:
            encoded_utterances = self._embedder.embed(utterances, TaskTypeEnum.classification).tolist()
            if self.use_embedding_features:
                # One object column whose cells are whole vectors (CatBoost embedding feature).
                data = pd.DataFrame({"embedding": encoded_utterances})
            else:
                # Spread the vector across plain numeric columns.
                data = pd.DataFrame(np.array(encoded_utterances))
            if self.features_type == FeaturesType.BOTH:
                data["text"] = utterances
        else:
            data = pd.DataFrame({"text": utterances})

        return data

    def get_extra_params(self) -> dict[str, Any]:
        """Build the feature-declaration kwargs for ``CatBoostClassifier``.

        Returns:
            A dict that may contain ``text_features`` and/or ``embedding_features``,
            naming the columns produced by :meth:`_prepare_data_for_fit`.

        Raises:
            ValueError: If ``features_type`` is not one of the supported values.
        """
        extra_params: dict[str, Any] = {}
        if self.features_type == FeaturesType.EMBEDDING:
            if self.use_embedding_features:  # to not raise error if embedding without embedding_features
                extra_params["embedding_features"] = ["embedding"]
        elif self.features_type in {FeaturesType.TEXT, FeaturesType.BOTH}:
            extra_params["text_features"] = ["text"]
            if self.features_type == FeaturesType.BOTH and self.use_embedding_features:
                extra_params["embedding_features"] = ["embedding"]
        else:
            msg = f"Unsupported features type: {self.features_type}"
            raise ValueError(msg)
        return extra_params

    def fit(
        self,
        utterances: list[str],
        labels: ListOfLabels,
    ) -> None:
        """Train a fresh ``CatBoostClassifier`` on the given utterances.

        Args:
            utterances: Training texts.
            labels: Training labels, validated by ``_validate_task``.
        """
        # NOTE(review): `_multilabel` and `_n_classes` are presumably set by
        # `_validate_task` in the base class - confirm against BaseScorer.
        self._validate_task(labels)

        if self.features_type in self.encoder_features_types:
            self._embedder = Embedder(self.embedder_config)

        dataset = self._prepare_data_for_fit(utterances)

        default_loss = (
            "MultiLogloss" if self._multilabel else ("MultiClass" if self._n_classes > 2 else "Logloss")  # noqa: PLR2004
        )

        # Resolve the effective validation fraction locally: the configured
        # hyperparameter on `self` must survive a multilabel fit untouched, and the
        # warning is only meaningful when early stopping was actually requested.
        val_fraction = self.val_fraction
        if self._multilabel and val_fraction is not None:
            val_fraction = None
            msg = "Disabling early stopping in CatBoostClassifier as it is not supported with multi-label task."
            logger.warning(msg)

        self._model = CatBoostClassifier(
            loss_function=self.loss_function or default_loss,
            verbose=self.verbose,
            allow_writing_files=False,
            eval_fraction=val_fraction,
            **self.catboost_kwargs,
            **self.get_extra_params(),
        )
        self._model.fit(
            dataset, labels, early_stopping_rounds=self.early_stopping_rounds if val_fraction is not None else None
        )

    def predict(self, utterances: list[str]) -> npt.NDArray[np.float64]:
        """Return the model's ``predict_proba`` output for the given utterances.

        Args:
            utterances: Texts to score.

        Returns:
            The array returned by ``CatBoostClassifier.predict_proba``.

        Raises:
            RuntimeError: If no fitted model is attached to this scorer.
        """
        if getattr(self, "_model", None) is None:
            msg = "Model is not trained. Call fit() first."
            raise RuntimeError(msg)
        data = self._prepare_data_for_fit(utterances)
        return cast("npt.NDArray[np.float64]", self._model.predict_proba(data))

    def clear_cache(self) -> None:
        """Drop the fitted model and the embedder, if present."""
        if hasattr(self, "_model"):
            del self._model
        if hasattr(self, "_embedder"):
            del self._embedder

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
"transformers[torch] (>=4.49.0,<5.0.0)",
4747
"peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
4848
"codecarbon (==2.6)",
49+
"catboost (>=1.2.8,<2.0.0)",
4950
]
5051

5152
[project.optional-dependencies]
@@ -69,6 +70,7 @@ typing = [
6970
"types-pygments (>=2.18.0.20240506,<3.0.0)",
7071
"types-setuptools (>=75.2.0.20241019,<76.0.0)",
7172
"joblib-stubs (>=1.4.2.5.20240918,<2.0.0)",
73+
"pandas-stubs (>= 2.2.3.250527, <3.0.0)",
7274
]
7375
docs = [
7476
"sphinx (>=8.1.3,<9.0.0)",
@@ -219,6 +221,7 @@ module = [
219221
"wandb",
220222
"dspy",
221223
"dspy.evaluate.auto_evaluation",
224+
"catboost",
222225
]
223226
ignore_missing_imports = true
224227

tests/assets/configs/multiclass.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@
2929
- module_name: sklearn
3030
clf_name: [RandomForestClassifier]
3131
n_estimators: [5, 10]
32+
- module_name: catboost
33+
iterations: [50, 100]
34+
learning_rate: [0.05, 0.1]
35+
depth: [1, 10]
36+
l2_leaf_reg: [1, 5]
37+
eval_metric: ["Accuracy"]
38+
random_seed: [42]
39+
features_type: ["embedding"]
40+
embedder_config:
41+
- model_name: prajjwal1/bert-tiny
3242
- module_name: bert
3343
classification_model_config:
3444
- model_name: avsolatorio/GIST-small-Embedding-v0

tests/assets/configs/multilabel.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525
- module_name: sklearn
2626
clf_name: [RandomForestClassifier]
2727
n_estimators: [5, 10]
28+
- module_name: catboost
29+
iterations: [50, 100]
30+
learning_rate: [0.05, 0.1]
31+
depth: [1, 10]
32+
l2_leaf_reg: [1, 5]
33+
loss_function: ["MultiLogloss"]
34+
random_seed: [42]
35+
embedder_config:
36+
- null
37+
- model_name: prajjwal1/bert-tiny
2838
- module_name: bert
2939
classification_model_config:
3040
- model_name: avsolatorio/GIST-small-Embedding-v0

0 commit comments

Comments
 (0)