
Commit c381bbc

refactor catboost scorer

1 parent 1ccb642 commit c381bbc

File tree

1 file changed: +58 -148 lines changed

autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 58 additions & 148 deletions
@@ -3,20 +3,15 @@
 import json
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
-
-if TYPE_CHECKING:
-    from collections.abc import Sequence
+from typing import Any, cast

 import numpy as np
 import numpy.typing as npt
-import torch
 from catboost import CatBoostClassifier, Pool  # type: ignore[import-untyped]
 from catboost.text_processing import Dictionary, Tokenizer  # type: ignore[import-untyped]
-from transformers import AutoModel, AutoTokenizer  # type: ignore[attr-defined]

-from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent import Context, Embedder
+from autointent.configs import EmbedderConfig, TaskTypeEnum
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer

@@ -29,7 +24,7 @@ class CatBoostScorer(BaseScorer):
     """CatBoost scorer using either external embeddings or CatBoost's own BoW encoding.

     Args:
-        classification_model_config: Config of the base transformer model (HFModelConfig, str, or dict)
+        embedder_config: Config of the base transformer model (HFModelConfig, str, or dict)
            If None (default) the scorer relies on CatBoost's own Bag-of-Words encoding,
            otherwise the provided embedder is used.
        iterations: Number of boosting iterations.
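
The reworded docstring describes the two feature modes the refactor keeps: CatBoost's own Bag-of-Words encoding when no config is given, an external embedder otherwise. A minimal usage sketch, assuming the public import path matches the file location and using an illustrative model name:

    from autointent.modules.scoring import CatBoostScorer  # import path assumed from the file location

    # No embedder_config: the scorer falls back to CatBoost's Bag-of-Words text encoding.
    bow_scorer = CatBoostScorer(iterations=100, learning_rate=0.1)

    # With embedder_config: utterances are embedded by the configured transformer model.
    embedded_scorer = CatBoostScorer(
        embedder_config="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
        iterations=200,
    )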
@@ -77,43 +72,44 @@ class CatBoostScorer(BaseScorer):

     def __init__(
         self,
-        classification_model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         iterations: int = 100,
         learning_rate: float = 0.1,
         loss_function: str | None = None,
         random_seed: int = 0,
         verbose: bool = False,
         **catboost_kwargs: Any,  # noqa: ANN401
     ) -> None:
-        self.classification_model_config = EmbedderConfig.from_search_config(classification_model_config)
-        self._use_embedder = classification_model_config is not None
+        self._use_embedder = embedder_config is not None
+        if self._use_embedder:
+            self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+            self._embedder = Embedder(self.embedder_config)
+        else:
+            self._init_catboost_text_tools()
         self.iterations = iterations
         self.learning_rate = learning_rate
         self.loss_function = loss_function
         self.random_seed = random_seed
         self.verbose = verbose
         self.catboost_kwargs = catboost_kwargs
         self._model: CatBoostClassifier
-        self._embedder: Any
-        self._tokenizer: Tokenizer
-        self._dictionary: Dictionary

     @classmethod
     def from_context(
         cls,
         context: Context,
-        classification_model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         iterations: int = 100,
         learning_rate: float = 0.1,
         loss_function: str | None = None,
         random_seed: int = 0,
         verbose: bool = False,
         **catboost_kwargs: Any,  # noqa: ANN401
     ) -> "CatBoostScorer":
-        if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+        if embedder_config is None:
+            embedder_config = context.resolve_embedder()
         return cls(
-            classification_model_config=classification_model_config,
+            embedder_config=embedder_config,
             iterations=iterations,
             learning_rate=learning_rate,
             loss_function=loss_function,
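
A side effect of the eager branch in `__init__` is that each instance carries only one set of private attributes: `_embedder` in embedder mode, `_tokenizer`/`_dictionary` in BoW mode. The `hasattr` checks in `clear_cache`, `dump`, and `load` below rely on exactly that. A rough check of the invariant, using a hypothetical dict config whose `model_name` key mirrors the one read by the removed `_load_embedder` helper:

    from autointent.modules.scoring import CatBoostScorer  # import path assumed from the file location

    bow = CatBoostScorer()  # BoW mode: CatBoost text tools are created, no embedder
    assert not bow._use_embedder
    assert hasattr(bow, "_tokenizer") and hasattr(bow, "_dictionary")

    emb = CatBoostScorer(embedder_config={"model_name": "intfloat/multilingual-e5-small"})  # hypothetical config
    assert emb._use_embedder
    assert hasattr(emb, "_embedder") and not hasattr(emb, "_tokenizer")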
@@ -122,68 +118,29 @@ def from_context(
             **catboost_kwargs,
         )

-    def get_classification_model_config(self) -> dict[str, Any]:
-        return self.classification_model_config.model_dump()
-
-    def get_implicit_initialization_params(self) -> dict[str, Any]:
-        return {
-            "classification_model_config": self.classification_model_config.model_dump(),
-        }
-
-    def _load_embedder(self) -> Any:  # noqa: ANN401
-        if getattr(self, "_embedder", None) is not None:
-            return self._embedder
-        cfg = self.classification_model_config
-        if hasattr(cfg, "encode"):
-            self._embedder = cfg
-            return self._embedder
-
-        model_name = getattr(cfg, "model_name", None)
-        if model_name is None and hasattr(cfg, "model_dump"):
-            model_name = cfg.model_dump().get("model_name")
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModel.from_pretrained(model_name)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device).eval()
-
-        raw_max = getattr(tokenizer, "model_max_length", None)
-        max_len = (
-            DEFAULT_TOKEN_LENGTH
-            if not isinstance(raw_max, int) or raw_max <= 0 or raw_max > MAX_TOKEN_LENGTH
-            else raw_max
-        )
-
-        def encode(texts: list[str]) -> npt.NDArray[np.float32]:
-            with torch.no_grad():
-                batch = tokenizer(
-                    texts,
-                    padding=True,
-                    truncation=True,
-                    max_length=max_len,
-                    return_tensors="pt",
-                )
-                batch = {k: v.to(device) for k, v in batch.items()}
-                outputs = model(**batch)
-                embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
-            return np.array(embeddings, dtype=np.float32)
-
-        self._embedder = encode
-        return self._embedder
-
-    def _init_text_tools(self) -> None:
+    def _init_catboost_text_tools(self) -> None:
         if not hasattr(self, "_tokenizer"):
             self._tokenizer = Tokenizer(lowercasing=True, separator_type="BySense", token_types=["Word", "Number"])
         if not hasattr(self, "_dictionary"):
             self._dictionary = Dictionary(occurence_lower_bound=1, gram_order=1)
+        if not hasattr(self, "_dictionary_fitted"):
+            self._dictionary_fitted = False
+
+    def get_embedder_config(self) -> dict[str, Any]:
+        return self.embedder_config.model_dump()
+
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {
+            "embedder_config": self.embedder_config.model_dump(),
+        }

     def _encode_utterances(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         if self._use_embedder:
-            embedder = self._load_embedder()
-            vecs = embedder.encode(utterances) if hasattr(embedder, "encode") else embedder(utterances)
+            vecs = self._embedder.embed(utterances, task_type=TaskTypeEnum.classification)
             return np.asarray(vecs, dtype=np.float32)
-        self._init_text_tools()
+
         tokenized = [self._tokenizer.tokenize(u) for u in utterances]
-        if not hasattr(self, "_dictionary_fitted"):
+        if not self._dictionary_fitted:
             self._dictionary.fit(tokenized)
             self._dictionary_fitted = True

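The BoW branch keeps using CatBoost's own text-processing primitives, which the refactor now initializes up front via `_init_catboost_text_tools`. A standalone sketch of those primitives on toy data; the final `apply` call is an assumption about how the fitted dictionary is consumed, since that part of `_encode_utterances` lies outside this hunk:

    from catboost.text_processing import Dictionary, Tokenizer  # type: ignore[import-untyped]

    tokenizer = Tokenizer(lowercasing=True, separator_type="BySense", token_types=["Word", "Number"])
    dictionary = Dictionary(occurence_lower_bound=1, gram_order=1)

    utterances = ["book a flight to berlin", "play some jazz please"]
    tokenized = [tokenizer.tokenize(u) for u in utterances]

    dictionary.fit(tokenized)                # done once, guarded by _dictionary_fitted in the scorer
    token_ids = dictionary.apply(tokenized)  # per-utterance token ids (assumed downstream usage)
    print(token_ids)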
@@ -205,15 +162,7 @@ def fit(
         self._validate_task(labels)

         x = self._encode_utterances(utterances)
-        y: npt.NDArray[np.float32] | npt.NDArray[np.int64]
-        if self._multilabel:
-            y_mat = np.zeros((len(labels), self._n_classes), dtype=np.float32)
-            for i, lbls in enumerate(cast("Sequence[Sequence[int]]", labels)):
-                for class_i, lbl in enumerate(lbls):
-                    y_mat[i, class_i] = lbl
-            y = y_mat
-        else:
-            y = np.asarray(cast("Sequence[int]", labels), dtype=np.int64)
+        y = np.asarray(labels, dtype=np.float32)

         default_loss = (
             "MultiLogloss"
@@ -243,12 +192,8 @@ def clear_cache(self) -> None:
             del self._model
         if hasattr(self, "_embedder"):
             del self._embedder
-        if hasattr(self, "_tokenizer"):
-            del self._tokenizer
-        if hasattr(self, "_dictionary"):
-            del self._dictionary

-    def dump(self, path: str) -> None:  # noqa: C901
+    def dump(self, path: str) -> None:
         """Save scorer and all artefacts needed for inference to path."""
         root = Path(path)
         if root.exists():
@@ -257,7 +202,7 @@ def dump(self, path: str) -> None:  # noqa: C901

         simple_attrs: dict[str, Any] = {}
         for k, v in vars(self).items():
-            if k in {"_model", "_dictionary", "_tokenizer"}:
+            if k in {"_model", "_dictionary", "_tokenizer", "_embedder"}:
                 continue
             if isinstance(v, EmbedderConfig):
                 simple_attrs[k] = v.model_dump()
@@ -270,25 +215,19 @@ def dump(self, path: str) -> None:  # noqa: C901
         if hasattr(self, "_model"):
             self._model.save_model(str(root / "model.cbm"))

-        if hasattr(self, "_dictionary"):
-            dict_dir = root / "dictionary"
-            dict_dir.mkdir()
-            self._dictionary.save(str(dict_dir / "dictionary.tsv"))
-
-        if hasattr(self, "_tokenizer"):
-            tok_params = {
-                "lowercasing": getattr(self._tokenizer, "lowercasing", True),
-                "separator_type": getattr(self._tokenizer, "separator_type", "BySense"),
-                "token_types": getattr(self._tokenizer, "token_types", ["Word", "Number"]),
-            }
-            (root / "tokenizer_params.json").write_text(json.dumps(tok_params), encoding="utf-8")
-
-        if self._use_embedder and hasattr(self, "_embedder"):
-            obj = getattr(self._embedder, "__self__", self._embedder)
-            if hasattr(obj, "save_pretrained"):
-                obj.save_pretrained(str(root / "hf_model"))
-            if hasattr(self, "_tokenizer") and hasattr(self._tokenizer, "save_pretrained"):
-                self._tokenizer.save_pretrained(str(root / "hf_tokenizer"))
+        if not self._use_embedder:
+            if hasattr(self, "_dictionary"):
+                dict_dir = root / "dictionary"
+                dict_dir.mkdir()
+                self._dictionary.save(str(dict_dir / "dictionary.tsv"))
+
+            if hasattr(self, "_tokenizer"):
+                tok_params = {
+                    "lowercasing": getattr(self._tokenizer, "lowercasing", True),
+                    "separator_type": getattr(self._tokenizer, "separator_type", "BySense"),
+                    "token_types": getattr(self._tokenizer, "token_types", ["Word", "Number"]),
+                }
+                (root / "tokenizer_params.json").write_text(json.dumps(tok_params), encoding="utf-8")

     @classmethod
     def load(
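
In BoW mode `dump` now nests the dictionary and tokenizer artefacts under a single `if not self._use_embedder:` guard; in embedder mode nothing HF-specific is written anymore, because the embedder is rebuilt from `embedder_config` on load. A usage sketch of the resulting layout; the `fit` call shape is taken from the `fit(...)` hunk above, and any additional JSON files are outside this diff:

    from pathlib import Path

    from autointent.modules.scoring import CatBoostScorer  # import path assumed from the file location

    scorer = CatBoostScorer()  # BoW mode
    scorer.fit(["hi there", "cancel my order"], [0, 1])
    scorer.dump("catboost_artifacts")

    for p in sorted(Path("catboost_artifacts").rglob("*")):
        print(p)
    # expected to include: model.cbm, dictionary/dictionary.tsv, tokenizer_params.json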
@@ -304,7 +243,7 @@ def load(
         cfg = EmbedderConfig.model_validate(cfg_dict)

         scorer = cls(
-            classification_model_config=cfg,
+            embedder_config=cfg,
             iterations=simple_attrs["iterations"],
             learning_rate=simple_attrs["learning_rate"],
             loss_function=simple_attrs["loss_function"],
@@ -317,50 +256,21 @@ def load(
         scorer._n_classes = simple_attrs.get("_n_classes")  # noqa: SLF001
         scorer._multilabel = simple_attrs.get("_multilabel")  # noqa: SLF001

+        if not scorer._use_embedder:  # noqa: SLF001
+            scorer._init_catboost_text_tools()  # noqa: SLF001
+            dict_file = root / "dictionary" / "dictionary.tsv"
+            if dict_file.exists():
+                scorer._dictionary.load(str(dict_file))  # noqa: SLF001
+                scorer._dictionary_fitted = simple_attrs.get("_dictionary_fitted", True)  # noqa: SLF001
+
+            tok_params_file = root / "tokenizer_params.json"
+            if tok_params_file.exists():
+                tok_params = json.loads(tok_params_file.read_text(encoding="utf-8"))
+                scorer._tokenizer = Tokenizer(**tok_params)  # noqa: SLF001
+
         model_file = root / "model.cbm"
         if model_file.exists():
             scorer._model = CatBoostClassifier()  # noqa: SLF001
             scorer._model.load_model(str(model_file))  # noqa: SLF001

-        dict_file = root / "dictionary" / "dictionary.tsv"
-        if dict_file.exists():
-            scorer._dictionary = Dictionary()  # noqa: SLF001
-            scorer._dictionary.load(str(dict_file))  # noqa: SLF001
-            scorer._dictionary_fitted = simple_attrs.get("_dictionary_fitted", True)  # noqa: SLF001
-
-        tok_params_file = root / "tokenizer_params.json"
-        if tok_params_file.exists():
-            tok_params = json.loads(tok_params_file.read_text(encoding="utf-8"))
-            scorer._tokenizer = Tokenizer(**tok_params)  # noqa: SLF001
-
-        if scorer._use_embedder:  # noqa: SLF001
-            emb_dir = root / "hf_model"
-            if emb_dir.exists():
-                tok_dir = root / "hf_tokenizer"
-                scorer._tokenizer = AutoTokenizer.from_pretrained(str(tok_dir if tok_dir.exists() else emb_dir))  # noqa: SLF001
-                model = AutoModel.from_pretrained(str(emb_dir)).to(
-                    torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                )
-                model.eval()
-
-                raw_max = getattr(scorer._tokenizer, "model_max_length", None)  # noqa: SLF001
-                max_len = (
-                    DEFAULT_TOKEN_LENGTH
-                    if not isinstance(raw_max, int) or raw_max <= 0 or raw_max > MAX_TOKEN_LENGTH
-                    else raw_max
-                )
-
-                def encode(texts: list[str]) -> npt.NDArray[np.float32]:
-                    with torch.no_grad():
-                        batch = scorer._tokenizer(  # noqa: SLF001
-                            texts,
-                            padding=True,
-                            truncation=True,
-                            max_length=max_len,
-                            return_tensors="pt",
-                        ).to(model.device)
-                        return model(**batch).last_hidden_state[:, 0, :].cpu().numpy().astype(np.float32)  # type: ignore[no-any-return]
-
-            scorer._embedder = encode  # noqa: SLF001
-
         return scorer
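
`load` mirrors the new `dump`: the CatBoost text tools are reconstructed only in BoW mode, while the embedder path relies on `embedder_config` alone, which is why the HF snapshot reload logic could be dropped. A round-trip sketch, assuming `load` accepts the same directory path passed to `dump` and using an illustrative model name:

    from autointent.modules.scoring import CatBoostScorer  # import path assumed from the file location

    scorer = CatBoostScorer(embedder_config="sentence-transformers/all-MiniLM-L6-v2")  # illustrative model
    scorer.fit(["reset my password", "what is my balance"], [0, 1])
    scorer.dump("catboost_embedder_artifacts")

    restored = CatBoostScorer.load("catboost_embedder_artifacts")  # rebuilds the Embedder from the saved config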
