 import json
 import shutil
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
-
-if TYPE_CHECKING:
-    from collections.abc import Sequence
+from typing import Any, cast
 
 import numpy as np
 import numpy.typing as npt
-import torch
 from catboost import CatBoostClassifier, Pool  # type: ignore[import-untyped]
 from catboost.text_processing import Dictionary, Tokenizer  # type: ignore[import-untyped]
-from transformers import AutoModel, AutoTokenizer  # type: ignore[attr-defined]
 
-from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent import Context, Embedder
+from autointent.configs import EmbedderConfig, TaskTypeEnum
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
 
@@ -29,7 +24,7 @@ class CatBoostScorer(BaseScorer):
2924 """CatBoost scorer using either external embeddings or CatBoost's own BoW encoding.
3025
3126 Args:
32- classification_model_config : Config of the base transformer model (HFModelConfig, str, or dict)
27+ embedder_config : Config of the base transformer model (HFModelConfig, str, or dict)
3328 If None (default) the scorer relies on CatBoost's own Bag-of-Words encoding,
3429 otherwise the provided embedder is used.
3530 iterations: Number of boosting iterations.
@@ -77,43 +72,44 @@ class CatBoostScorer(BaseScorer):
 
     def __init__(
         self,
-        classification_model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         iterations: int = 100,
         learning_rate: float = 0.1,
         loss_function: str | None = None,
         random_seed: int = 0,
         verbose: bool = False,
         **catboost_kwargs: Any,  # noqa: ANN401
     ) -> None:
-        self.classification_model_config = EmbedderConfig.from_search_config(classification_model_config)
-        self._use_embedder = classification_model_config is not None
+        self._use_embedder = embedder_config is not None
+        if self._use_embedder:
+            self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+            self._embedder = Embedder(self.embedder_config)
+        else:
+            self._init_catboost_text_tools()
         self.iterations = iterations
         self.learning_rate = learning_rate
         self.loss_function = loss_function
         self.random_seed = random_seed
         self.verbose = verbose
         self.catboost_kwargs = catboost_kwargs
         self._model: CatBoostClassifier
-        self._embedder: Any
-        self._tokenizer: Tokenizer
-        self._dictionary: Dictionary
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        classification_model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         iterations: int = 100,
         learning_rate: float = 0.1,
         loss_function: str | None = None,
         random_seed: int = 0,
         verbose: bool = False,
         **catboost_kwargs: Any,  # noqa: ANN401
     ) -> "CatBoostScorer":
-        if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+        if embedder_config is None:
+            embedder_config = context.resolve_embedder()
         return cls(
-            classification_model_config=classification_model_config,
+            embedder_config=embedder_config,
             iterations=iterations,
             learning_rate=learning_rate,
             loss_function=loss_function,
@@ -122,68 +118,29 @@ def from_context(
             **catboost_kwargs,
         )
 
-    def get_classification_model_config(self) -> dict[str, Any]:
-        return self.classification_model_config.model_dump()
-
-    def get_implicit_initialization_params(self) -> dict[str, Any]:
-        return {
-            "classification_model_config": self.classification_model_config.model_dump(),
-        }
-
-    def _load_embedder(self) -> Any:  # noqa: ANN401
-        if getattr(self, "_embedder", None) is not None:
-            return self._embedder
-        cfg = self.classification_model_config
-        if hasattr(cfg, "encode"):
-            self._embedder = cfg
-            return self._embedder
-
-        model_name = getattr(cfg, "model_name", None)
-        if model_name is None and hasattr(cfg, "model_dump"):
-            model_name = cfg.model_dump().get("model_name")
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModel.from_pretrained(model_name)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device).eval()
-
-        raw_max = getattr(tokenizer, "model_max_length", None)
-        max_len = (
-            DEFAULT_TOKEN_LENGTH
-            if not isinstance(raw_max, int) or raw_max <= 0 or raw_max > MAX_TOKEN_LENGTH
-            else raw_max
-        )
-
-        def encode(texts: list[str]) -> npt.NDArray[np.float32]:
-            with torch.no_grad():
-                batch = tokenizer(
-                    texts,
-                    padding=True,
-                    truncation=True,
-                    max_length=max_len,
-                    return_tensors="pt",
-                )
-                batch = {k: v.to(device) for k, v in batch.items()}
-                outputs = model(**batch)
-                embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
-            return np.array(embeddings, dtype=np.float32)
-
-        self._embedder = encode
-        return self._embedder
-
-    def _init_text_tools(self) -> None:
+    def _init_catboost_text_tools(self) -> None:
         if not hasattr(self, "_tokenizer"):
             self._tokenizer = Tokenizer(lowercasing=True, separator_type="BySense", token_types=["Word", "Number"])
         if not hasattr(self, "_dictionary"):
             self._dictionary = Dictionary(occurence_lower_bound=1, gram_order=1)
+        if not hasattr(self, "_dictionary_fitted"):
+            self._dictionary_fitted = False
+
+    def get_embedder_config(self) -> dict[str, Any]:
+        return self.embedder_config.model_dump()
+
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {
+            "embedder_config": self.embedder_config.model_dump(),
+        }
 
     def _encode_utterances(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         if self._use_embedder:
-            embedder = self._load_embedder()
-            vecs = embedder.encode(utterances) if hasattr(embedder, "encode") else embedder(utterances)
+            vecs = self._embedder.embed(utterances, task_type=TaskTypeEnum.classification)
             return np.asarray(vecs, dtype=np.float32)
-        self._init_text_tools()
+
         tokenized = [self._tokenizer.tokenize(u) for u in utterances]
-        if not hasattr(self, "_dictionary_fitted"):
+        if not self._dictionary_fitted:
             self._dictionary.fit(tokenized)
             self._dictionary_fitted = True
 
@@ -205,15 +162,7 @@ def fit(
         self._validate_task(labels)
 
         x = self._encode_utterances(utterances)
-        y: npt.NDArray[np.float32] | npt.NDArray[np.int64]
-        if self._multilabel:
-            y_mat = np.zeros((len(labels), self._n_classes), dtype=np.float32)
-            for i, lbls in enumerate(cast("Sequence[Sequence[int]]", labels)):
-                for class_i, lbl in enumerate(lbls):
-                    y_mat[i, class_i] = lbl
-            y = y_mat
-        else:
-            y = np.asarray(cast("Sequence[int]", labels), dtype=np.int64)
+        y = np.asarray(labels, dtype=np.float32)
 
         default_loss = (
             "MultiLogloss"
@@ -243,12 +192,8 @@ def clear_cache(self) -> None:
             del self._model
         if hasattr(self, "_embedder"):
             del self._embedder
-        if hasattr(self, "_tokenizer"):
-            del self._tokenizer
-        if hasattr(self, "_dictionary"):
-            del self._dictionary
 
-    def dump(self, path: str) -> None:  # noqa: C901
+    def dump(self, path: str) -> None:
         """Save scorer and all artefacts needed for inference to path."""
         root = Path(path)
         if root.exists():
@@ -257,7 +202,7 @@ def dump(self, path: str) -> None: # noqa: C901
 
         simple_attrs: dict[str, Any] = {}
         for k, v in vars(self).items():
-            if k in {"_model", "_dictionary", "_tokenizer"}:
+            if k in {"_model", "_dictionary", "_tokenizer", "_embedder"}:
                 continue
             if isinstance(v, EmbedderConfig):
                 simple_attrs[k] = v.model_dump()
@@ -270,25 +215,19 @@ def dump(self, path: str) -> None: # noqa: C901
         if hasattr(self, "_model"):
             self._model.save_model(str(root / "model.cbm"))
 
-        if hasattr(self, "_dictionary"):
-            dict_dir = root / "dictionary"
-            dict_dir.mkdir()
-            self._dictionary.save(str(dict_dir / "dictionary.tsv"))
-
-        if hasattr(self, "_tokenizer"):
-            tok_params = {
-                "lowercasing": getattr(self._tokenizer, "lowercasing", True),
-                "separator_type": getattr(self._tokenizer, "separator_type", "BySense"),
-                "token_types": getattr(self._tokenizer, "token_types", ["Word", "Number"]),
-            }
-            (root / "tokenizer_params.json").write_text(json.dumps(tok_params), encoding="utf-8")
-
-        if self._use_embedder and hasattr(self, "_embedder"):
-            obj = getattr(self._embedder, "__self__", self._embedder)
-            if hasattr(obj, "save_pretrained"):
-                obj.save_pretrained(str(root / "hf_model"))
-            if hasattr(self, "_tokenizer") and hasattr(self._tokenizer, "save_pretrained"):
-                self._tokenizer.save_pretrained(str(root / "hf_tokenizer"))
+        if not self._use_embedder:
+            if hasattr(self, "_dictionary"):
+                dict_dir = root / "dictionary"
+                dict_dir.mkdir()
+                self._dictionary.save(str(dict_dir / "dictionary.tsv"))
+
+            if hasattr(self, "_tokenizer"):
+                tok_params = {
+                    "lowercasing": getattr(self._tokenizer, "lowercasing", True),
+                    "separator_type": getattr(self._tokenizer, "separator_type", "BySense"),
+                    "token_types": getattr(self._tokenizer, "token_types", ["Word", "Number"]),
+                }
+                (root / "tokenizer_params.json").write_text(json.dumps(tok_params), encoding="utf-8")
 
     @classmethod
     def load(
@@ -304,7 +243,7 @@ def load(
             cfg = EmbedderConfig.model_validate(cfg_dict)
 
         scorer = cls(
-            classification_model_config=cfg,
+            embedder_config=cfg,
             iterations=simple_attrs["iterations"],
             learning_rate=simple_attrs["learning_rate"],
             loss_function=simple_attrs["loss_function"],
@@ -317,50 +256,21 @@ def load(
         scorer._n_classes = simple_attrs.get("_n_classes")  # noqa: SLF001
         scorer._multilabel = simple_attrs.get("_multilabel")  # noqa: SLF001
 
+        if not scorer._use_embedder:  # noqa: SLF001
+            scorer._init_catboost_text_tools()  # noqa: SLF001
+            dict_file = root / "dictionary" / "dictionary.tsv"
+            if dict_file.exists():
+                scorer._dictionary.load(str(dict_file))  # noqa: SLF001
+                scorer._dictionary_fitted = simple_attrs.get("_dictionary_fitted", True)  # noqa: SLF001
+
+            tok_params_file = root / "tokenizer_params.json"
+            if tok_params_file.exists():
+                tok_params = json.loads(tok_params_file.read_text(encoding="utf-8"))
+                scorer._tokenizer = Tokenizer(**tok_params)  # noqa: SLF001
+
         model_file = root / "model.cbm"
         if model_file.exists():
             scorer._model = CatBoostClassifier()  # noqa: SLF001
             scorer._model.load_model(str(model_file))  # noqa: SLF001
 
-        dict_file = root / "dictionary" / "dictionary.tsv"
-        if dict_file.exists():
-            scorer._dictionary = Dictionary()  # noqa: SLF001
-            scorer._dictionary.load(str(dict_file))  # noqa: SLF001
-            scorer._dictionary_fitted = simple_attrs.get("_dictionary_fitted", True)  # noqa: SLF001
-
-        tok_params_file = root / "tokenizer_params.json"
-        if tok_params_file.exists():
-            tok_params = json.loads(tok_params_file.read_text(encoding="utf-8"))
-            scorer._tokenizer = Tokenizer(**tok_params)  # noqa: SLF001
-
-        if scorer._use_embedder:  # noqa: SLF001
-            emb_dir = root / "hf_model"
-            if emb_dir.exists():
-                tok_dir = root / "hf_tokenizer"
-                scorer._tokenizer = AutoTokenizer.from_pretrained(str(tok_dir if tok_dir.exists() else emb_dir))  # noqa: SLF001
-                model = AutoModel.from_pretrained(str(emb_dir)).to(
-                    torch.device("cuda" if torch.cuda.is_available() else "cpu")
-                )
-                model.eval()
-
-                raw_max = getattr(scorer._tokenizer, "model_max_length", None)  # noqa: SLF001
-                max_len = (
-                    DEFAULT_TOKEN_LENGTH
-                    if not isinstance(raw_max, int) or raw_max <= 0 or raw_max > MAX_TOKEN_LENGTH
-                    else raw_max
-                )
-
-                def encode(texts: list[str]) -> npt.NDArray[np.float32]:
-                    with torch.no_grad():
-                        batch = scorer._tokenizer(  # noqa: SLF001
-                            texts,
-                            padding=True,
-                            truncation=True,
-                            max_length=max_len,
-                            return_tensors="pt",
-                        ).to(model.device)
-                        return model(**batch).last_hidden_state[:, 0, :].cpu().numpy().astype(np.float32)  # type: ignore[no-any-return]
-
-                scorer._embedder = encode  # noqa: SLF001
-
         return scorer
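
For reference, a minimal usage sketch of the scorer after this refactor. It is an illustration only: fit, dump, and load appear in the hunks above, but the import path, the embedder model name, and the exact call signatures outside this diff are assumptions.

    from autointent.modules.scoring import CatBoostScorer  # import path assumed

    # Bag-of-Words mode: no embedder_config, so CatBoost's own Tokenizer/Dictionary handle the text.
    bow_scorer = CatBoostScorer(iterations=100, learning_rate=0.1, random_seed=0)
    bow_scorer.fit(["turn on the lights", "what is the weather"], [0, 1])  # fit(utterances, labels), as in the fit() hunk
    bow_scorer.dump("dumps/catboost_bow")  # writes model.cbm, dictionary/, tokenizer_params.json

    # Embedder mode: texts are vectorized via Embedder(...).embed(..., task_type=TaskTypeEnum.classification).
    emb_scorer = CatBoostScorer(embedder_config="sentence-transformers/all-MiniLM-L6-v2")  # model name is a placeholder
    emb_scorer.fit(["turn on the lights", "what is the weather"], [0, 1])

    # Restoring goes through the load() classmethod shown above (its full signature is not visible in this diff).
    restored = CatBoostScorer.load("dumps/catboost_bow")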