Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion autointent/_dump_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from autointent import Embedder, Ranker, VectorIndex
from autointent.configs import CrossEncoderConfig, EmbedderConfig
from autointent.context.optimization_info import Artifact
from autointent.schemas import TagsList

ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
Expand Down Expand Up @@ -83,7 +84,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
Dumper.make_subdirectories(path, exists_ok)

for key, val in attrs.items():
if exclude and isinstance(val, tuple(exclude)):
if isinstance(val, Artifact) or (exclude and isinstance(val, tuple(exclude))):
continue
if isinstance(val, TagsList):
val.dump(path / Dumper.tags / key)
Expand Down
11 changes: 7 additions & 4 deletions autointent/modules/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import numpy.typing as npt
from typing_extensions import assert_never
from typing_extensions import Self, assert_never

from autointent._dump_tools import Dumper
from autointent.configs import CrossEncoderConfig, EmbedderConfig
Expand Down Expand Up @@ -83,20 +83,23 @@ def dump(self, path: str) -> None:
"""
Dumper.dump(self, Path(path))

@classmethod
def load(
self,
cls,
path: str,
embedder_config: EmbedderConfig | None = None,
cross_encoder_config: CrossEncoderConfig | None = None,
) -> None:
) -> Self:
"""Load data from file system.

Args:
path: Path to load
embedder_config: one can override presaved settings
cross_encoder_config: one can override presaved settings
"""
Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
instance = cls()
Dumper.load(instance, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
return instance

@abstractmethod
def predict(
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/embedding/_logreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class LogregAimedEmbedding(BaseEmbedding):

def __init__(
self,
embedder_config: EmbedderConfig | str | dict[str, Any],
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
cv: PositiveInt = 3,
) -> None:
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
Expand All @@ -64,7 +64,7 @@ def __init__(
def from_context(
cls,
context: Context,
embedder_config: EmbedderConfig | str,
embedder_config: EmbedderConfig | str | None = None,
cv: PositiveInt = 3,
) -> "LogregAimedEmbedding":
"""Create a LogregAimedEmbedding instance using a Context object.
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/embedding/_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class RetrievalAimedEmbedding(BaseEmbedding):

def __init__(
self,
embedder_config: EmbedderConfig | str | dict[str, Any],
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
k: PositiveInt = 10,
) -> None:
self.k = k
Expand All @@ -61,7 +61,7 @@ def __init__(
def from_context(
cls,
context: Context,
embedder_config: EmbedderConfig | str,
embedder_config: EmbedderConfig | str | None = None,
k: PositiveInt = 10,
) -> "RetrievalAimedEmbedding":
"""Create an instance using a Context object.
Expand Down
10 changes: 7 additions & 3 deletions autointent/modules/regex/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,17 @@ def dump(self, path: str) -> None:
with (dump_dir / "regex_patterns.json").open("w") as file:
json.dump(serialized, file, indent=4, ensure_ascii=False)

@classmethod
def load(
self,
cls,
path: str,
embedder_config: EmbedderConfig | None = None,
cross_encoder_config: CrossEncoderConfig | None = None,
) -> None:
) -> "SimpleRegex":
instance = cls()

with (Path(path) / "regex_patterns.json").open() as file:
serialized: list[dict[str, Any]] = json.load(file)

self._compile_regex_patterns(serialized)
instance._compile_regex_patterns(serialized) # noqa: SLF001
return instance
4 changes: 2 additions & 2 deletions autointent/modules/scoring/_dnnc/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class DNNCScorer(BaseScorer):

def __init__(
self,
k: PositiveInt,
k: PositiveInt = 5,
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
) -> None:
Expand All @@ -77,7 +77,7 @@ def __init__(
def from_context(
cls,
context: Context,
k: PositiveInt,
k: PositiveInt = 5,
cross_encoder_config: CrossEncoderConfig | str | None = None,
embedder_config: EmbedderConfig | str | None = None,
) -> "DNNCScorer":
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/scoring/_knn/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class KNNScorer(BaseScorer):

def __init__(
self,
k: PositiveInt,
k: PositiveInt = 5,
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
weights: WeightType = "distance",
) -> None:
Expand All @@ -76,7 +76,7 @@ def __init__(
def from_context(
cls,
context: Context,
k: PositiveInt,
k: PositiveInt = 5,
weights: WeightType = "distance",
embedder_config: EmbedderConfig | str | None = None,
) -> "KNNScorer":
Expand Down
6 changes: 3 additions & 3 deletions autointent/modules/scoring/_knn/rerank_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class RerankScorer(KNNScorer):

def __init__(
self,
k: int,
k: PositiveInt = 5,
weights: WeightType = "distance",
use_crosencoder_scores: bool = False,
m: int | None = None,
m: PositiveInt | None = None,
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
) -> None:
Expand All @@ -62,7 +62,7 @@ def __init__(
def from_context(
cls,
context: Context,
k: int,
k: PositiveInt = 5,
weights: WeightType = "distance",
m: PositiveInt | None = None,
cross_encoder_config: CrossEncoderConfig | str | None = None,
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/scoring/_mlknn/mlknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MLKnnScorer(BaseScorer):

def __init__(
self,
k: PositiveInt,
k: PositiveInt = 5,
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
s: float = 1.0,
ignore_first_neighbours: int = 0,
Expand All @@ -84,7 +84,7 @@ def __init__(
def from_context(
cls,
context: Context,
k: PositiveInt,
k: PositiveInt = 5,
s: PositiveFloat = 1.0,
ignore_first_neighbours: NonNegativeInt = 0,
embedder_config: EmbedderConfig | str | None = None,
Expand Down
4 changes: 2 additions & 2 deletions autointent/modules/scoring/_sklearn/sklearn_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class SklearnScorer(BaseScorer):

def __init__(
self,
clf_name: str,
clf_name: str = "LogisticRegression",
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
**clf_args: Any, # noqa: ANN401
) -> None:
Expand All @@ -83,7 +83,7 @@ def __init__(
def from_context(
cls,
context: Context,
clf_name: str,
clf_name: str = "LogisticRegression",
embedder_config: EmbedderConfig | str | None = None,
**clf_args: float | str | bool,
) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions autointent/nodes/_inference_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
config: Config to init from
"""
node_info = NODES_INFO[config.node_type]
module = node_info.modules_available[config.module_name](**config.module_config)
module.load(
module_cls = node_info.modules_available[config.module_name]
module = module_cls.load(
config.load_path,
embedder_config=getattr(config, "embedder_config", None),
cross_encoder_config=getattr(config, "cross_encoder_config", None),
Expand Down
9 changes: 7 additions & 2 deletions tests/modules/decision/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@ def test_dump_load(multilabel_fit_data):

path = setup_environment() / "adaptive_module"
predictor.dump(path)
del predictor

predictor = AdaptiveDecision.load(path)

assert hasattr(predictor, "_r")
assert predictor._r is not None
assert isinstance(predictor._r, float)

predictor = AdaptiveDecision()
predictor.load(path)
new_preds = predictor.predict(multilabel_fit_data[0])

assert all(p == n for p, n in zip(preds, new_preds, strict=True))
16 changes: 16 additions & 0 deletions tests/modules/decision/test_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from autointent.exceptions import MismatchNumClassesError, WrongClassificationError
from autointent.modules.decision import ArgmaxDecision
from tests.conftest import setup_environment


def test_multiclass(multiclass_fit_data, scores):
Expand All @@ -24,3 +25,18 @@ def test_fails_on_wrong_clf_problem(multilabel_fit_data):
predictor = ArgmaxDecision()
with pytest.raises(WrongClassificationError):
predictor.fit(*multilabel_fit_data)


def test_dump_load(multiclass_fit_data):
predictor = ArgmaxDecision()
predictor.fit(*multiclass_fit_data)
predictions = predictor.predict(multiclass_fit_data[0])

path = setup_environment() / "argmax_module"
predictor.dump(path)
del predictor

predictor = ArgmaxDecision.load(path)
new_predictions = predictor.predict(multiclass_fit_data[0])

assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
21 changes: 21 additions & 0 deletions tests/modules/decision/test_jinoos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from autointent.exceptions import MismatchNumClassesError, WrongClassificationError
from autointent.modules import JinoosDecision
from tests.conftest import setup_environment


def detect_oos(scores: npt.NDArray[Any], labels: npt.NDArray[Any], thresh: float):
Expand Down Expand Up @@ -40,3 +41,23 @@ def test_fails_on_wrong_clf_problem(multilabel_fit_data):
predictor = JinoosDecision()
with pytest.raises(WrongClassificationError):
predictor.fit(*multilabel_fit_data)


def test_dump_load(multiclass_fit_data):
predictor = JinoosDecision()
predictor.fit(*multiclass_fit_data)
predictions = predictor.predict(multiclass_fit_data[0])

path = setup_environment() / "jinoos_module"
predictor.dump(path)
del predictor

predictor = JinoosDecision.load(path)

assert hasattr(predictor, "_thresh")
assert predictor._thresh is not None
assert isinstance(predictor._thresh, float)

new_predictions = predictor.predict(multiclass_fit_data[0])

assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
22 changes: 22 additions & 0 deletions tests/modules/decision/test_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,25 @@ def test_fails_on_wrong_n_classes_fit(multiclass_fit_data):
predictor = ThresholdDecision(thresh=[0.5])
with pytest.raises(MismatchNumClassesError):
predictor.fit(*multiclass_fit_data)


@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
def test_dump_load(fit_fixture, request, tmp_path):
fit_data = request.getfixturevalue(fit_fixture)
predictor = ThresholdDecision(thresh=0.3)
predictor.fit(*fit_data)
predictions = predictor.predict(fit_data[0])

predictor.dump(tmp_path)
del predictor

predictor = ThresholdDecision.load(tmp_path)

assert hasattr(predictor, "thresh")
assert predictor.thresh is not None
assert predictor.thresh == 0.3
assert isinstance(predictor.thresh, float)

new_predictions = predictor.predict(fit_data[0])

assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
20 changes: 20 additions & 0 deletions tests/modules/decision/test_tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,23 @@ def test_fails_on_wrong_n_classes_predict(multiclass_fit_data):
scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
with pytest.raises(MismatchNumClassesError):
predictor.predict(scores)


@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
def test_dump_load(fit_fixture, request, tmp_path):
fit_data = request.getfixturevalue(fit_fixture)
predictor = TunableDecision()
predictor.fit(*fit_data)
predictions = predictor.predict(fit_data[0])

predictor.dump(tmp_path)
del predictor

predictor = TunableDecision.load(tmp_path)
assert hasattr(predictor, "thresh")
assert predictor.thresh is not None
assert isinstance(predictor.thresh, np.ndarray)

new_predictions = predictor.predict(fit_data[0])

assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
19 changes: 19 additions & 0 deletions tests/modules/embedding/test_logreg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import numpy as np

from autointent.modules.embedding import LogregAimedEmbedding
from tests.conftest import setup_environment


def test_get_assets_returns_correct_artifact_for_logreg():
Expand Down Expand Up @@ -31,3 +34,19 @@ def test_predict_evaluates_model():
assert len(probas) == 2
assert probas[0][0] > probas[0][1]
assert probas[1][1] > probas[1][0]


def test_dump_load():
module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo")
utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
module.fit(utterances, labels)
predictions = module.predict(["hello", "bye"])

dump_path = setup_environment()

module.dump(dump_path)
del module
module = LogregAimedEmbedding.load(dump_path)
predictions_loaded = module.predict(["hello", "bye"])
assert np.allclose(predictions, predictions_loaded)
Loading
Loading