Skip to content

Commit 2702e41

Browse files
authored
Testing/dump load for each module (#192)
* add tests for decision modules
* fix codestyle
* adjust how transformers and tokenizers are saved
* add test for generic dumper
* add tests for embedding modules
* add tests for scoring modules
* add test for regex
* make `load` a classmethod
* try to fix file exists error
* try to fix pydantic error
* try to fix windows cleanup error
1 parent 1014e83 commit 2702e41

31 files changed

+405
-63
lines changed

autointent/_dump_tools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from autointent import Embedder, Ranker, VectorIndex
2222
from autointent.configs import CrossEncoderConfig, EmbedderConfig
23+
from autointent.context.optimization_info import Artifact
2324
from autointent.schemas import TagsList
2425

2526
ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
@@ -83,7 +84,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
8384
Dumper.make_subdirectories(path, exists_ok)
8485

8586
for key, val in attrs.items():
86-
if exclude and isinstance(val, tuple(exclude)):
87+
if isinstance(val, Artifact) or (exclude and isinstance(val, tuple(exclude))):
8788
continue
8889
if isinstance(val, TagsList):
8990
val.dump(path / Dumper.tags / key)

autointent/modules/base/_base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010
import numpy.typing as npt
11-
from typing_extensions import assert_never
11+
from typing_extensions import Self, assert_never
1212

1313
from autointent._dump_tools import Dumper
1414
from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -83,20 +83,23 @@ def dump(self, path: str) -> None:
8383
"""
8484
Dumper.dump(self, Path(path))
8585

86+
@classmethod
8687
def load(
87-
self,
88+
cls,
8889
path: str,
8990
embedder_config: EmbedderConfig | None = None,
9091
cross_encoder_config: CrossEncoderConfig | None = None,
91-
) -> None:
92+
) -> Self:
9293
"""Load data from file system.
9394
9495
Args:
9596
path: Path to load
9697
embedder_config: one can override presaved settings
9798
cross_encoder_config: one can override presaved settings
9899
"""
99-
Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
100+
instance = cls()
101+
Dumper.load(instance, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
102+
return instance
100103

101104
@abstractmethod
102105
def predict(

autointent/modules/embedding/_logreg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class LogregAimedEmbedding(BaseEmbedding):
5050

5151
def __init__(
5252
self,
53-
embedder_config: EmbedderConfig | str | dict[str, Any],
53+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
5454
cv: PositiveInt = 3,
5555
) -> None:
5656
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
@@ -64,7 +64,7 @@ def __init__(
6464
def from_context(
6565
cls,
6666
context: Context,
67-
embedder_config: EmbedderConfig | str,
67+
embedder_config: EmbedderConfig | str | None = None,
6868
cv: PositiveInt = 3,
6969
) -> "LogregAimedEmbedding":
7070
"""Create a LogregAimedEmbedding instance using a Context object.

autointent/modules/embedding/_retrieval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class RetrievalAimedEmbedding(BaseEmbedding):
4646

4747
def __init__(
4848
self,
49-
embedder_config: EmbedderConfig | str | dict[str, Any],
49+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
5050
k: PositiveInt = 10,
5151
) -> None:
5252
self.k = k
@@ -61,7 +61,7 @@ def __init__(
6161
def from_context(
6262
cls,
6363
context: Context,
64-
embedder_config: EmbedderConfig | str,
64+
embedder_config: EmbedderConfig | str | None = None,
6565
k: PositiveInt = 10,
6666
) -> "RetrievalAimedEmbedding":
6767
"""Create an instance using a Context object.

autointent/modules/regex/_simple.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,17 @@ def dump(self, path: str) -> None:
253253
with (dump_dir / "regex_patterns.json").open("w") as file:
254254
json.dump(serialized, file, indent=4, ensure_ascii=False)
255255

256+
@classmethod
256257
def load(
257-
self,
258+
cls,
258259
path: str,
259260
embedder_config: EmbedderConfig | None = None,
260261
cross_encoder_config: CrossEncoderConfig | None = None,
261-
) -> None:
262+
) -> "SimpleRegex":
263+
instance = cls()
264+
262265
with (Path(path) / "regex_patterns.json").open() as file:
263266
serialized: list[dict[str, Any]] = json.load(file)
264267

265-
self._compile_regex_patterns(serialized)
268+
instance._compile_regex_patterns(serialized) # noqa: SLF001
269+
return instance

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class DNNCScorer(BaseScorer):
6161

6262
def __init__(
6363
self,
64-
k: PositiveInt,
64+
k: PositiveInt = 5,
6565
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
6666
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6767
) -> None:
@@ -77,7 +77,7 @@ def __init__(
7777
def from_context(
7878
cls,
7979
context: Context,
80-
k: PositiveInt,
80+
k: PositiveInt = 5,
8181
cross_encoder_config: CrossEncoderConfig | str | None = None,
8282
embedder_config: EmbedderConfig | str | None = None,
8383
) -> "DNNCScorer":

autointent/modules/scoring/_knn/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class KNNScorer(BaseScorer):
5656

5757
def __init__(
5858
self,
59-
k: PositiveInt,
59+
k: PositiveInt = 5,
6060
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6161
weights: WeightType = "distance",
6262
) -> None:
@@ -76,7 +76,7 @@ def __init__(
7676
def from_context(
7777
cls,
7878
context: Context,
79-
k: PositiveInt,
79+
k: PositiveInt = 5,
8080
weights: WeightType = "distance",
8181
embedder_config: EmbedderConfig | str | None = None,
8282
) -> "KNNScorer":

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class RerankScorer(KNNScorer):
3636

3737
def __init__(
3838
self,
39-
k: int,
39+
k: PositiveInt = 5,
4040
weights: WeightType = "distance",
4141
use_crosencoder_scores: bool = False,
42-
m: int | None = None,
42+
m: PositiveInt | None = None,
4343
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
4444
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
4545
) -> None:
@@ -62,7 +62,7 @@ def __init__(
6262
def from_context(
6363
cls,
6464
context: Context,
65-
k: int,
65+
k: PositiveInt = 5,
6666
weights: WeightType = "distance",
6767
m: PositiveInt | None = None,
6868
cross_encoder_config: CrossEncoderConfig | str | None = None,

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class MLKnnScorer(BaseScorer):
6363

6464
def __init__(
6565
self,
66-
k: PositiveInt,
66+
k: PositiveInt = 5,
6767
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6868
s: float = 1.0,
6969
ignore_first_neighbours: int = 0,
@@ -84,7 +84,7 @@ def __init__(
8484
def from_context(
8585
cls,
8686
context: Context,
87-
k: PositiveInt,
87+
k: PositiveInt = 5,
8888
s: PositiveFloat = 1.0,
8989
ignore_first_neighbours: NonNegativeInt = 0,
9090
embedder_config: EmbedderConfig | str | None = None,

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class SklearnScorer(BaseScorer):
5959

6060
def __init__(
6161
self,
62-
clf_name: str,
62+
clf_name: str = "LogisticRegression",
6363
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
6464
**clf_args: Any, # noqa: ANN401
6565
) -> None:
@@ -83,7 +83,7 @@ def __init__(
8383
def from_context(
8484
cls,
8585
context: Context,
86-
clf_name: str,
86+
clf_name: str = "LogisticRegression",
8787
embedder_config: EmbedderConfig | str | None = None,
8888
**clf_args: float | str | bool,
8989
) -> Self:

0 commit comments

Comments
 (0)