@@ -118,6 +118,7 @@ def from_context(
118118 ** catboost_kwargs ,
119119 )
120120
121+
121122 def _init_catboost_text_tools (self ) -> None :
122123 if not hasattr (self , "_tokenizer" ):
123124 self ._tokenizer = Tokenizer (lowercasing = True , separator_type = "BySense" , token_types = ["Word" , "Number" ])
@@ -127,7 +128,9 @@ def _init_catboost_text_tools(self) -> None:
127128 self ._dictionary_fitted = False
128129
129130 def get_embedder_config (self ) -> dict [str , Any ]:
130- return self .embedder_config .model_dump ()
131+ if self ._use_embedder :
132+ return self .embedder_config .model_dump ()
133+ return {}
131134
132135 def get_implicit_initialization_params (self ) -> dict [str , Any ]:
133136 return {
@@ -192,9 +195,14 @@ def clear_cache(self) -> None:
192195 del self ._model
193196 if hasattr (self , "_embedder" ):
194197 del self ._embedder
198+ if hasattr (self , "_tokenizer" ):
199+ del self ._tokenizer
200+ if hasattr (self , "_dictionary" ):
201+ del self ._dictionary
202+ if hasattr (self , "_dictionary_fitted" ):
203+ del self ._dictionary_fitted
195204
196205 def dump (self , path : str ) -> None :
197- """Save scorer and all artefacts needed for inference to path."""
198206 root = Path (path )
199207 if root .exists ():
200208 shutil .rmtree (root )
@@ -235,15 +243,11 @@ def load(
235243 path : str ,
236244 embedder_config : EmbedderConfig | None = None ,
237245 ) -> "CatBoostScorer" :
238- """Load scorer dumped with :pymeth:`dump`."""
239246 root = Path (path )
240247 simple_attrs = json .loads ((root / "simple_attrs.json" ).read_text (encoding = "utf-8" ))
241248
242- cfg_dict = embedder_config .model_dump () if embedder_config else simple_attrs ["classification_model_config" ]
243- cfg = EmbedderConfig .model_validate (cfg_dict )
244-
245249 scorer = cls (
246- embedder_config = cfg ,
250+ embedder_config = embedder_config ,
247251 iterations = simple_attrs ["iterations" ],
248252 learning_rate = simple_attrs ["learning_rate" ],
249253 loss_function = simple_attrs ["loss_function" ],
@@ -256,8 +260,10 @@ def load(
256260 scorer ._n_classes = simple_attrs .get ("_n_classes" ) # noqa: SLF001
257261 scorer ._multilabel = simple_attrs .get ("_multilabel" ) # noqa: SLF001
258262
259- if not scorer ._use_embedder : # noqa: SLF001
260- scorer ._init_catboost_text_tools () # noqa: SLF001
263+ if scorer ._use_embedder :
264+ scorer ._embedder = Embedder (scorer .embedder_config )
265+ else :
266+ scorer ._init_catboost_text_tools ()
261267 dict_file = root / "dictionary" / "dictionary.tsv"
262268 if dict_file .exists ():
263269 scorer ._dictionary .load (str (dict_file )) # noqa: SLF001
0 commit comments