Skip to content

Commit 624eec7

Browse files
committed
minor fix
1 parent: c381bbc; commit: 624eec7

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@ def from_context(
118118
**catboost_kwargs,
119119
)
120120

121+
121122
def _init_catboost_text_tools(self) -> None:
122123
if not hasattr(self, "_tokenizer"):
123124
self._tokenizer = Tokenizer(lowercasing=True, separator_type="BySense", token_types=["Word", "Number"])
@@ -127,7 +128,9 @@ def _init_catboost_text_tools(self) -> None:
127128
self._dictionary_fitted = False
128129

129130
def get_embedder_config(self) -> dict[str, Any]:
130-
return self.embedder_config.model_dump()
131+
if self._use_embedder:
132+
return self.embedder_config.model_dump()
133+
return {}
131134

132135
def get_implicit_initialization_params(self) -> dict[str, Any]:
133136
return {
@@ -192,9 +195,14 @@ def clear_cache(self) -> None:
192195
del self._model
193196
if hasattr(self, "_embedder"):
194197
del self._embedder
198+
if hasattr(self, "_tokenizer"):
199+
del self._tokenizer
200+
if hasattr(self, "_dictionary"):
201+
del self._dictionary
202+
if hasattr(self, "_dictionary_fitted"):
203+
del self._dictionary_fitted
195204

196205
def dump(self, path: str) -> None:
197-
"""Save scorer and all artefacts needed for inference to path."""
198206
root = Path(path)
199207
if root.exists():
200208
shutil.rmtree(root)
@@ -235,15 +243,11 @@ def load(
235243
path: str,
236244
embedder_config: EmbedderConfig | None = None,
237245
) -> "CatBoostScorer":
238-
"""Load scorer dumped with :pymeth:`dump`."""
239246
root = Path(path)
240247
simple_attrs = json.loads((root / "simple_attrs.json").read_text(encoding="utf-8"))
241248

242-
cfg_dict = embedder_config.model_dump() if embedder_config else simple_attrs["classification_model_config"]
243-
cfg = EmbedderConfig.model_validate(cfg_dict)
244-
245249
scorer = cls(
246-
embedder_config=cfg,
250+
embedder_config=embedder_config,
247251
iterations=simple_attrs["iterations"],
248252
learning_rate=simple_attrs["learning_rate"],
249253
loss_function=simple_attrs["loss_function"],
@@ -256,8 +260,10 @@ def load(
256260
scorer._n_classes = simple_attrs.get("_n_classes") # noqa: SLF001
257261
scorer._multilabel = simple_attrs.get("_multilabel") # noqa: SLF001
258262

259-
if not scorer._use_embedder: # noqa: SLF001
260-
scorer._init_catboost_text_tools() # noqa: SLF001
263+
if scorer._use_embedder:
264+
scorer._embedder = Embedder(scorer.embedder_config)
265+
else:
266+
scorer._init_catboost_text_tools()
261267
dict_file = root / "dictionary" / "dictionary.tsv"
262268
if dict_file.exists():
263269
scorer._dictionary.load(str(dict_file)) # noqa: SLF001

0 commit comments

Comments
 (0)