diff --git a/autointent/_dump_tools.py b/autointent/_dump_tools.py
index 645e2e2c0..730302b90 100644
--- a/autointent/_dump_tools.py
+++ b/autointent/_dump_tools.py
@@ -20,6 +20,7 @@
 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
+from autointent.context.optimization_info import Artifact
 from autointent.schemas import TagsList

 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]
@@ -83,7 +84,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
         Dumper.make_subdirectories(path, exists_ok)

         for key, val in attrs.items():
-            if exclude and isinstance(val, tuple(exclude)):
+            if isinstance(val, Artifact) or (exclude and isinstance(val, tuple(exclude))):
                 continue
             if isinstance(val, TagsList):
                 val.dump(path / Dumper.tags / key)
diff --git a/autointent/modules/base/_base.py b/autointent/modules/base/_base.py
index 9cda8b5a3..d9a69d373 100644
--- a/autointent/modules/base/_base.py
+++ b/autointent/modules/base/_base.py
@@ -8,7 +8,7 @@
 import numpy as np
 import numpy.typing as npt
-from typing_extensions import assert_never
+from typing_extensions import Self, assert_never

 from autointent._dump_tools import Dumper
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -83,12 +83,13 @@ def dump(self, path: str) -> None:
         """
         Dumper.dump(self, Path(path))

+    @classmethod
     def load(
-        self,
+        cls,
         path: str,
         embedder_config: EmbedderConfig | None = None,
         cross_encoder_config: CrossEncoderConfig | None = None,
-    ) -> None:
+    ) -> Self:
         """Load data from file system.

         Args:
@@ -96,7 +97,9 @@
             embedder_config: one can override presaved settings
             cross_encoder_config: one can override presaved settings
         """
-        Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
+        instance = cls()
+        Dumper.load(instance, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
+        return instance

     @abstractmethod
     def predict(
diff --git a/autointent/modules/embedding/_logreg.py b/autointent/modules/embedding/_logreg.py
index 7b08ae709..5b6a323a3 100644
--- a/autointent/modules/embedding/_logreg.py
+++ b/autointent/modules/embedding/_logreg.py
@@ -50,7 +50,7 @@ class LogregAimedEmbedding(BaseEmbedding):

     def __init__(
         self,
-        embedder_config: EmbedderConfig | str | dict[str, Any],
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         cv: PositiveInt = 3,
     ) -> None:
         self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
@@ -64,7 +64,7 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        embedder_config: EmbedderConfig | str,
+        embedder_config: EmbedderConfig | str | None = None,
         cv: PositiveInt = 3,
     ) -> "LogregAimedEmbedding":
         """Create a LogregAimedEmbedding instance using a Context object.
diff --git a/autointent/modules/embedding/_retrieval.py b/autointent/modules/embedding/_retrieval.py
index 04f4d04c8..f3b7cd808 100644
--- a/autointent/modules/embedding/_retrieval.py
+++ b/autointent/modules/embedding/_retrieval.py
@@ -46,7 +46,7 @@ class RetrievalAimedEmbedding(BaseEmbedding):

     def __init__(
         self,
-        embedder_config: EmbedderConfig | str | dict[str, Any],
+        embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         k: PositiveInt = 10,
     ) -> None:
         self.k = k
@@ -61,7 +61,7 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        embedder_config: EmbedderConfig | str,
+        embedder_config: EmbedderConfig | str | None = None,
         k: PositiveInt = 10,
     ) -> "RetrievalAimedEmbedding":
         """Create an instance using a Context object.
diff --git a/autointent/modules/regex/_simple.py b/autointent/modules/regex/_simple.py
index d83470740..ddc2d2385 100644
--- a/autointent/modules/regex/_simple.py
+++ b/autointent/modules/regex/_simple.py
@@ -253,13 +253,17 @@ def dump(self, path: str) -> None:
         with (dump_dir / "regex_patterns.json").open("w") as file:
             json.dump(serialized, file, indent=4, ensure_ascii=False)

+    @classmethod
     def load(
-        self,
+        cls,
         path: str,
         embedder_config: EmbedderConfig | None = None,
         cross_encoder_config: CrossEncoderConfig | None = None,
-    ) -> None:
+    ) -> "SimpleRegex":
+        instance = cls()
+
         with (Path(path) / "regex_patterns.json").open() as file:
             serialized: list[dict[str, Any]] = json.load(file)
-        self._compile_regex_patterns(serialized)
+        instance._compile_regex_patterns(serialized)  # noqa: SLF001
+        return instance
diff --git a/autointent/modules/scoring/_dnnc/dnnc.py b/autointent/modules/scoring/_dnnc/dnnc.py
index 48792dbe9..86a4b60ef 100644
--- a/autointent/modules/scoring/_dnnc/dnnc.py
+++ b/autointent/modules/scoring/_dnnc/dnnc.py
@@ -61,7 +61,7 @@ class DNNCScorer(BaseScorer):

     def __init__(
         self,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
     ) -> None:
@@ -77,7 +77,7 @@
     def from_context(
         cls,
         context: Context,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         cross_encoder_config: CrossEncoderConfig | str | None = None,
         embedder_config: EmbedderConfig | str | None = None,
     ) -> "DNNCScorer":
diff --git a/autointent/modules/scoring/_knn/knn.py b/autointent/modules/scoring/_knn/knn.py
index bad7a6ef8..598ebbe4f 100644
--- a/autointent/modules/scoring/_knn/knn.py
+++ b/autointent/modules/scoring/_knn/knn.py
@@ -56,7 +56,7 @@ class KNNScorer(BaseScorer):

     def __init__(
         self,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         weights: WeightType = "distance",
     ) -> None:
@@ -76,7 +76,7 @@
     def from_context(
         cls,
         context: Context,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         weights: WeightType = "distance",
         embedder_config: EmbedderConfig | str | None = None,
     ) -> "KNNScorer":
diff --git a/autointent/modules/scoring/_knn/rerank_scorer.py b/autointent/modules/scoring/_knn/rerank_scorer.py
index 09ecab10b..70981fb0b 100644
--- a/autointent/modules/scoring/_knn/rerank_scorer.py
+++ b/autointent/modules/scoring/_knn/rerank_scorer.py
@@ -36,10 +36,10 @@ class RerankScorer(KNNScorer):

     def __init__(
         self,
-        k: int,
+        k: PositiveInt = 5,
         weights: WeightType = "distance",
         use_crosencoder_scores: bool = False,
-        m: int | None = None,
+        m: PositiveInt | None = None,
         cross_encoder_config: CrossEncoderConfig | str | dict[str, Any] | None = None,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
     ) -> None:
@@ -62,7 +62,7 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        k: int,
+        k: PositiveInt = 5,
         weights: WeightType = "distance",
         m: PositiveInt | None = None,
         cross_encoder_config: CrossEncoderConfig | str | None = None,
diff --git a/autointent/modules/scoring/_mlknn/mlknn.py b/autointent/modules/scoring/_mlknn/mlknn.py
index 3021d9349..238ea518d 100644
--- a/autointent/modules/scoring/_mlknn/mlknn.py
+++ b/autointent/modules/scoring/_mlknn/mlknn.py
@@ -63,7 +63,7 @@ class MLKnnScorer(BaseScorer):

     def __init__(
         self,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         s: float = 1.0,
         ignore_first_neighbours: int = 0,
@@ -84,7 +84,7 @@
     def from_context(
         cls,
         context: Context,
-        k: PositiveInt,
+        k: PositiveInt = 5,
         s: PositiveFloat = 1.0,
         ignore_first_neighbours: NonNegativeInt = 0,
         embedder_config: EmbedderConfig | str | None = None,
diff --git a/autointent/modules/scoring/_sklearn/sklearn_scorer.py b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
index 62b340f94..69b4fa753 100644
--- a/autointent/modules/scoring/_sklearn/sklearn_scorer.py
+++ b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
@@ -59,7 +59,7 @@ class SklearnScorer(BaseScorer):

     def __init__(
         self,
-        clf_name: str,
+        clf_name: str = "LogisticRegression",
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         **clf_args: Any,  # noqa: ANN401
     ) -> None:
@@ -83,7 +83,7 @@
     def from_context(
         cls,
         context: Context,
-        clf_name: str,
+        clf_name: str = "LogisticRegression",
         embedder_config: EmbedderConfig | str | None = None,
         **clf_args: float | str | bool,
     ) -> Self:
diff --git a/autointent/nodes/_inference_node.py b/autointent/nodes/_inference_node.py
index 3fc7390fa..3cc2943e1 100644
--- a/autointent/nodes/_inference_node.py
+++ b/autointent/nodes/_inference_node.py
@@ -31,8 +31,8 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
             config: Config to init from
         """
         node_info = NODES_INFO[config.node_type]
-        module = node_info.modules_available[config.module_name](**config.module_config)
-        module.load(
+        module_cls = node_info.modules_available[config.module_name]
+        module = module_cls.load(
             config.load_path,
             embedder_config=getattr(config, "embedder_config", None),
             cross_encoder_config=getattr(config, "cross_encoder_config", None),
diff --git a/tests/modules/decision/test_adaptive.py b/tests/modules/decision/test_adaptive.py
index 33dcf22c9..08086f05e 100644
--- a/tests/modules/decision/test_adaptive.py
+++ b/tests/modules/decision/test_adaptive.py
@@ -37,9 +37,14 @@ def test_dump_load(multilabel_fit_data):
     path = setup_environment() / "adaptive_module"
     predictor.dump(path)
+    del predictor
+
+    predictor = AdaptiveDecision.load(path)
+
+    assert hasattr(predictor, "_r")
+    assert predictor._r is not None
+    assert isinstance(predictor._r, float)

-    predictor = AdaptiveDecision()
-    predictor.load(path)
     new_preds = predictor.predict(multilabel_fit_data[0])

     assert all(p == n for p, n in zip(preds, new_preds, strict=True))
diff --git a/tests/modules/decision/test_argmax.py b/tests/modules/decision/test_argmax.py
index c3dd87bdc..880074af7 100644
--- a/tests/modules/decision/test_argmax.py
+++ b/tests/modules/decision/test_argmax.py
@@ -3,6 +3,7 @@
 from autointent.exceptions import MismatchNumClassesError, WrongClassificationError
 from autointent.modules.decision import ArgmaxDecision
+from tests.conftest import setup_environment


 def test_multiclass(multiclass_fit_data, scores):
@@ -24,3 +25,18 @@ def test_fails_on_wrong_clf_problem(multilabel_fit_data):
     predictor = ArgmaxDecision()
     with pytest.raises(WrongClassificationError):
         predictor.fit(*multilabel_fit_data)
+
+
+def test_dump_load(multiclass_fit_data):
+    predictor = ArgmaxDecision()
+    predictor.fit(*multiclass_fit_data)
+    predictions = predictor.predict(multiclass_fit_data[0])
+
+    path = setup_environment() / "argmax_module"
+    predictor.dump(path)
+    del predictor
+
+    predictor = ArgmaxDecision.load(path)
+    new_predictions = predictor.predict(multiclass_fit_data[0])
+
+    assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
diff --git a/tests/modules/decision/test_jinoos.py b/tests/modules/decision/test_jinoos.py
index 73fda9d89..e37ecd4f3 100644
--- a/tests/modules/decision/test_jinoos.py
+++ b/tests/modules/decision/test_jinoos.py
@@ -6,6 +6,7 @@
 from autointent.exceptions import MismatchNumClassesError, WrongClassificationError
 from autointent.modules import JinoosDecision
+from tests.conftest import setup_environment


 def detect_oos(scores: npt.NDArray[Any], labels: npt.NDArray[Any], thresh: float):
@@ -40,3 +41,23 @@ def test_fails_on_wrong_clf_problem(multilabel_fit_data):
     predictor = JinoosDecision()
     with pytest.raises(WrongClassificationError):
         predictor.fit(*multilabel_fit_data)
+
+
+def test_dump_load(multiclass_fit_data):
+    predictor = JinoosDecision()
+    predictor.fit(*multiclass_fit_data)
+    predictions = predictor.predict(multiclass_fit_data[0])
+
+    path = setup_environment() / "jinoos_module"
+    predictor.dump(path)
+    del predictor
+
+    predictor = JinoosDecision.load(path)
+
+    assert hasattr(predictor, "_thresh")
+    assert predictor._thresh is not None
+    assert isinstance(predictor._thresh, float)
+
+    new_predictions = predictor.predict(multiclass_fit_data[0])
+
+    assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
diff --git a/tests/modules/decision/test_threshold.py b/tests/modules/decision/test_threshold.py
index 6b108fd6f..dcf3d348c 100644
--- a/tests/modules/decision/test_threshold.py
+++ b/tests/modules/decision/test_threshold.py
@@ -39,3 +39,25 @@ def test_fails_on_wrong_n_classes_fit(multiclass_fit_data):
     predictor = ThresholdDecision(thresh=[0.5])
     with pytest.raises(MismatchNumClassesError):
         predictor.fit(*multiclass_fit_data)
+
+
+@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
+def test_dump_load(fit_fixture, request, tmp_path):
+    fit_data = request.getfixturevalue(fit_fixture)
+    predictor = ThresholdDecision(thresh=0.3)
+    predictor.fit(*fit_data)
+    predictions = predictor.predict(fit_data[0])
+
+    predictor.dump(tmp_path)
+    del predictor
+
+    predictor = ThresholdDecision.load(tmp_path)
+
+    assert hasattr(predictor, "thresh")
+    assert predictor.thresh is not None
+    assert predictor.thresh == 0.3
+    assert isinstance(predictor.thresh, float)
+
+    new_predictions = predictor.predict(fit_data[0])
+
+    assert all(p == n for p, n in zip(predictions, new_predictions, strict=True))
diff --git a/tests/modules/decision/test_tunable.py b/tests/modules/decision/test_tunable.py
index 18b4e1d45..69a2ebe7c 100644
--- a/tests/modules/decision/test_tunable.py
+++ b/tests/modules/decision/test_tunable.py
@@ -37,3 +37,23 @@ def test_fails_on_wrong_n_classes_predict(multiclass_fit_data):
     scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
     with pytest.raises(MismatchNumClassesError):
         predictor.predict(scores)
+
+
+@pytest.mark.parametrize("fit_fixture", ["multiclass_fit_data", "multilabel_fit_data"])
["multiclass_fit_data", "multilabel_fit_data"]) +def test_dump_load(fit_fixture, request, tmp_path): + fit_data = request.getfixturevalue(fit_fixture) + predictor = TunableDecision() + predictor.fit(*fit_data) + predictions = predictor.predict(fit_data[0]) + + predictor.dump(tmp_path) + del predictor + + predictor = TunableDecision.load(tmp_path) + assert hasattr(predictor, "thresh") + assert predictor.thresh is not None + assert isinstance(predictor.thresh, np.ndarray) + + new_predictions = predictor.predict(fit_data[0]) + + assert all(p == n for p, n in zip(predictions, new_predictions, strict=True)) diff --git a/tests/modules/embedding/test_logreg.py b/tests/modules/embedding/test_logreg.py index bd62a6632..613e8c445 100644 --- a/tests/modules/embedding/test_logreg.py +++ b/tests/modules/embedding/test_logreg.py @@ -1,4 +1,7 @@ +import numpy as np + from autointent.modules.embedding import LogregAimedEmbedding +from tests.conftest import setup_environment def test_get_assets_returns_correct_artifact_for_logreg(): @@ -31,3 +34,19 @@ def test_predict_evaluates_model(): assert len(probas) == 2 assert probas[0][0] > probas[0][1] assert probas[1][1] > probas[1][0] + + +def test_dump_load(): + module = LogregAimedEmbedding(embedder_config="sergeyzh/rubert-tiny-turbo") + utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"] + labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1] + module.fit(utterances, labels) + predictions = module.predict(["hello", "bye"]) + + dump_path = setup_environment() + + module.dump(dump_path) + del module + module = LogregAimedEmbedding.load(dump_path) + predictions_loaded = module.predict(["hello", "bye"]) + assert np.allclose(predictions, predictions_loaded) diff --git a/tests/modules/embedding/test_retrieval.py b/tests/modules/embedding/test_retrieval.py index 6e85d927a..3bc97425a 100644 --- a/tests/modules/embedding/test_retrieval.py +++ b/tests/modules/embedding/test_retrieval.py @@ -1,7 +1,6 @@ -import shutil +from pathlib import Path from autointent.modules.embedding import RetrievalAimedEmbedding -from tests.conftest import setup_environment def test_get_assets_returns_correct_artifact(): @@ -10,19 +9,17 @@ def test_get_assets_returns_correct_artifact(): assert artifact.config.model_name == "sergeyzh/rubert-tiny-turbo" -def test_dump_and_load_preserves_model_state(): - project_dir = setup_environment() +def test_dump_and_load_preserves_model_state(tmp_path: Path): module = RetrievalAimedEmbedding(k=5, embedder_config="sergeyzh/rubert-tiny-turbo") utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"] labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1] module.fit(utterances, labels) + predictions = module.predict(utterances) - module.dump(project_dir) + module.dump(tmp_path) + del module - loaded_module = RetrievalAimedEmbedding(k=5, embedder_config="sergeyzh/rubert-tiny-turbo") - loaded_module.load(project_dir) - - assert loaded_module.embedder_config == module.embedder_config - - shutil.rmtree(project_dir) + loaded_module = RetrievalAimedEmbedding.load(tmp_path) + predictions_loaded = loaded_module.predict(utterances) + assert predictions == predictions_loaded diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py index a2b7a3c5d..86fbc685b 100644 --- a/tests/modules/scoring/test_bert.py +++ b/tests/modules/scoring/test_bert.py @@ -33,8 +33,7 @@ def test_bert_scorer_dump_load(dataset): scorer_original.dump(str(temp_dir_path)) # Create a 
-    scorer_loaded = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
-    scorer_loaded.load(str(temp_dir_path))
+    scorer_loaded = BertScorer.load(str(temp_dir_path))

     # Verify model and tokenizer are loaded
     assert hasattr(scorer_loaded, "_model")
diff --git a/tests/modules/scoring/test_description.py b/tests/modules/scoring/test_description.py
index d7aa4355b..42962b6cd 100644
--- a/tests/modules/scoring/test_description.py
+++ b/tests/modules/scoring/test_description.py
@@ -46,7 +46,12 @@ def test_description_scorer(dataset, expected_prediction, multilabel):
     assert len(predictions) == len(test_utterances)
     assert metadata is None

-    scorer.clear_cache()
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = DescriptionScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_utterances)
+        np.testing.assert_almost_equal(predictions, new_predictions, decimal=5)


 @pytest.mark.parametrize(
@@ -97,10 +102,7 @@ def test_description_scorer_cross_encoder(dataset, expected_prediction, multilab
     with tempfile.TemporaryDirectory() as temp_dir:
         scorer.dump(temp_dir)

-        new_scorer = DescriptionScorer(
-            cross_encoder_config="cross-encoder/ms-marco-MiniLM-L6-v2", encoder_type="cross", temperature=0.3
-        )
-        new_scorer.load(temp_dir)
+        new_scorer = DescriptionScorer.load(temp_dir)

         loaded_predictions = new_scorer.predict(test_utterances)
diff --git a/tests/modules/scoring/test_dnnc.py b/tests/modules/scoring/test_dnnc.py
index 62cb38617..93c734e92 100644
--- a/tests/modules/scoring/test_dnnc.py
+++ b/tests/modules/scoring/test_dnnc.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np
 import pytest

@@ -31,4 +33,9 @@ def test_base_dnnc(dataset, train_head, pred_score):
     assert "neighbors" in metadata[0]
     assert "scores" in metadata[0]

-    scorer.clear_cache()
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = DNNCScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        np.testing.assert_almost_equal(predictions, new_predictions, decimal=5)
diff --git a/tests/modules/scoring/test_knn.py b/tests/modules/scoring/test_knn.py
index 5f48c3f5c..538c12b88 100644
--- a/tests/modules/scoring/test_knn.py
+++ b/tests/modules/scoring/test_knn.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from autointent.context.data_handler import DataHandler
@@ -35,3 +37,10 @@ def test_base_knn(dataset):
     predictions, metadata = scorer.predict_with_metadata(test_data)
     assert len(predictions) == len(test_data)
     assert "neighbors" in metadata[0]
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = KNNScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        assert np.allclose(predictions, new_predictions)
diff --git a/tests/modules/scoring/test_linear.py b/tests/modules/scoring/test_linear.py
index 6179acefb..b02302283 100644
--- a/tests/modules/scoring/test_linear.py
+++ b/tests/modules/scoring/test_linear.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from autointent.context.data_handler import DataHandler
@@ -35,3 +37,10 @@ def test_base_linear(dataset):
     predictions, metadata = scorer.predict_with_metadata(test_data)
     assert len(predictions) == len(test_data)
     assert metadata is None
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = LinearScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        np.testing.assert_almost_equal(predictions, new_predictions, decimal=5)
diff --git a/tests/modules/scoring/test_lora.py b/tests/modules/scoring/test_lora.py
index f2a9cdfd3..f9d2725fd 100644
--- a/tests/modules/scoring/test_lora.py
+++ b/tests/modules/scoring/test_lora.py
@@ -35,10 +35,7 @@ def test_lora_scorer_dump_load(dataset):
     scorer_original.dump(str(temp_dir_path))

     # Create a new scorer and load saved model
-    scorer_loaded = BERTLoRAScorer(
-        classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8
-    )
-    scorer_loaded.load(str(temp_dir_path))
+    scorer_loaded = BERTLoRAScorer.load(str(temp_dir_path))

     # Verify model and tokenizer are loaded
     assert hasattr(scorer_loaded, "_model")
diff --git a/tests/modules/scoring/test_mlknn.py b/tests/modules/scoring/test_mlknn.py
index b4f6210f0..f4ef7c139 100644
--- a/tests/modules/scoring/test_mlknn.py
+++ b/tests/modules/scoring/test_mlknn.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from autointent.context.data_handler import DataHandler
@@ -40,3 +42,10 @@ def test_base_mlknn(dataset):
     predictions, metadata = scorer.predict_with_metadata(test_data)
     assert len(predictions) == len(test_data)
     assert "neighbors" in metadata[0]
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = MLKnnScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        assert np.allclose(predictions, new_predictions)
diff --git a/tests/modules/scoring/test_ptuning.py b/tests/modules/scoring/test_ptuning.py
index 4d6ff2385..d74fe39bf 100644
--- a/tests/modules/scoring/test_ptuning.py
+++ b/tests/modules/scoring/test_ptuning.py
@@ -33,14 +33,7 @@ def test_ptuning_scorer_dump_load(dataset):
     try:
         scorer_original.dump(str(temp_dir_path))

-        scorer_loaded = PTuningScorer(
-            classification_model_config="prajjwal1/bert-tiny",
-            num_train_epochs=1,
-            batch_size=8,
-            num_virtual_tokens=10,
-            seed=42,
-        )
-        scorer_loaded.load(str(temp_dir_path))
+        scorer_loaded = PTuningScorer.load(str(temp_dir_path))

         assert hasattr(scorer_loaded, "_model")
         assert scorer_loaded._model is not None
diff --git a/tests/modules/scoring/test_rerank_scorer.py b/tests/modules/scoring/test_rerank_scorer.py
index 1b8d28fe1..024d54e66 100644
--- a/tests/modules/scoring/test_rerank_scorer.py
+++ b/tests/modules/scoring/test_rerank_scorer.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from autointent.context.data_handler import DataHandler
@@ -41,3 +43,10 @@ def test_base_rerank_scorer(dataset):
     predictions, metadata = scorer.predict_with_metadata(test_data)
     assert len(predictions) == len(test_data)
     assert "neighbors" in metadata[0]
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = RerankScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        assert np.allclose(predictions, new_predictions)
diff --git a/tests/modules/scoring/test_sklearn.py b/tests/modules/scoring/test_sklearn.py
index b2c041f1f..62b409c6d 100644
--- a/tests/modules/scoring/test_sklearn.py
+++ b/tests/modules/scoring/test_sklearn.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from autointent.context.data_handler import DataHandler
@@ -42,3 +44,10 @@ def test_base_sklearn(dataset):
     predictions, metadata = scorer.predict_with_metadata(test_data)
     assert len(predictions) == len(test_data)
     assert metadata is None
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scorer.dump(temp_dir)
+        del scorer
+        new_scorer = SklearnScorer.load(temp_dir)
+        new_predictions = new_scorer.predict(test_data)
+        np.testing.assert_almost_equal(predictions, new_predictions, decimal=5)
diff --git a/tests/modules/test_dumper.py b/tests/modules/test_dumper.py
new file mode 100644
index 000000000..da3d458ad
--- /dev/null
+++ b/tests/modules/test_dumper.py
@@ -0,0 +1,186 @@
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from sklearn.linear_model import LogisticRegression
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from autointent import Embedder, Ranker, VectorIndex
+from autointent._dump_tools import Dumper
+from autointent.configs import CrossEncoderConfig, EmbedderConfig, TokenizerConfig
+from autointent.schemas import Tag, TagsList
+
+
+class TestSimpleAttributes:
+    def init_attributes(self):
+        self.integer = 1
+        self.float = 1.0
+        self.string = "test"
+        self.boolean = True
+        self.array = np.array([1, 2, 3])
+
+    def check_attributes(self):
+        assert self.integer == 1
+        assert self.float == 1.0
+        assert self.string == "test"
+        assert self.boolean
+        np.testing.assert_array_equal(self.array, np.array([1, 2, 3]))
+
+
+class TestTags:
+    def init_attributes(self):
+        self.tags = TagsList(tags=[Tag(name="hello", intent_ids=[0]), Tag(name="world", intent_ids=[1])])
+
+    def check_attributes(self):
+        assert self.tags == TagsList(tags=[Tag(name="hello", intent_ids=[0]), Tag(name="world", intent_ids=[1])])
+
+
+class TestTransformers:
+    def init_attributes(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        self._tokenizer_predictions = np.array(self.tokenizer(["hello", "world"]).input_ids)
+        self.transformer = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+
+        with torch.no_grad():
+            self._transformer_predictions = (
+                self.transformer(input_ids=torch.tensor(self._tokenizer_predictions)).logits.cpu().numpy()
+            )
+
+    def check_attributes(self):
+        tokenizer_predictions = self.tokenizer(["hello", "world"]).input_ids
+        np.testing.assert_array_equal(self._tokenizer_predictions, tokenizer_predictions)
+        with torch.no_grad():
+            np.testing.assert_array_equal(
+                self._transformer_predictions,
+                self.transformer(input_ids=torch.tensor(tokenizer_predictions)).logits.cpu().numpy(),
+            )
+
+
+class TestVectorIndex:
+    def init_attributes(self):
+        self.vector_index = VectorIndex(
+            embedder_config=EmbedderConfig.from_search_config("bert-base-uncased"),
+        )
+        self.vector_index.add(texts=["hello", "world"], labels=[0, 1])
+
+    def check_attributes(self):
+        assert self.vector_index.texts == ["hello", "world"]
+        assert self.vector_index.labels == [0, 1]
+
+
+class TestEmbedder:
+    def init_attributes(self):
+        self.embedder = Embedder(
+            embedder_config=EmbedderConfig.from_search_config("bert-base-uncased"),
+        )
+        self._embedder_predictions = self.embedder.embed(["hello", "world"])
+
+    def check_attributes(self):
+        np.testing.assert_array_equal(
+            self._embedder_predictions,
+            self.embedder.embed(["hello", "world"]),
+        )
+
+
+class TestSklearnEstimator:
+    def init_attributes(self):
+        self.estimator = LogisticRegression()
+        self.estimator.fit([[1, 2, 3], [4, 5, 6]], [0, 1])
+        self._estimator_predictions = self.estimator.predict([[1, 2, 3], [4, 5, 6]])
+
+    def check_attributes(self):
+        np.testing.assert_array_equal(
+            self._estimator_predictions,
+            self.estimator.predict([[1, 2, 3], [4, 5, 6]]),
+        )
+
+
+class TestRanker:
+    def init_attributes(self):
+        self.ranker = Ranker(
+            cross_encoder_config={"model_name": "cross-encoder/ms-marco-MiniLM-L6-v2", "train_head": True},
+        )
+        self.ranker.fit(
+            ["hello", "world", "bye", "earth", "hello", "world", "bye", "earth"],
+            [0, 1, 0, 1, 0, 1, 0, 1],
+        )
+        self._ranker_predictions = self.ranker.predict(
+            [("hello", "world"), ("bye", "earth")],
+        )
+
+    def check_attributes(self):
+        np.testing.assert_array_equal(
+            self._ranker_predictions,
+            self.ranker.predict([("hello", "world"), ("bye", "earth")]),
+        )
+
+
+class TestEmbedderConfig:
+    def init_attributes(self):
+        self.pydantic_model = EmbedderConfig(
+            model_name="bert-base-uncased",
+            batch_size=16,
+            device="cpu",
+            trust_remote_code=True,
+            tokenizer_config=TokenizerConfig(max_length=512, padding="longest", truncation=False),
+        )
+
+    def check_attributes(self):
+        assert self.pydantic_model.model_name == "bert-base-uncased"
+        assert self.pydantic_model.batch_size == 16
+        assert self.pydantic_model.device == "cpu"
+        assert self.pydantic_model.trust_remote_code
+        assert self.pydantic_model.tokenizer_config.max_length == 512
+        assert self.pydantic_model.tokenizer_config.padding == "longest"
+        assert not self.pydantic_model.tokenizer_config.truncation
+
+
+class TestCrossEncoderConfig:
+    def init_attributes(self):
+        self.pydantic_model = CrossEncoderConfig(
+            model_name="cross-encoder/ms-marco-MiniLM-L6-v2",
+            train_head=True,
+            device="cpu",
+            batch_size=16,
+            trust_remote_code=True,
+            tokenizer_config=TokenizerConfig(max_length=512, padding="longest", truncation=False),
+        )
+
+    def check_attributes(self):
+        assert self.pydantic_model.model_name == "cross-encoder/ms-marco-MiniLM-L6-v2"
+        assert self.pydantic_model.train_head
+        assert self.pydantic_model.device == "cpu"
+        assert self.pydantic_model.batch_size == 16
+        assert self.pydantic_model.trust_remote_code
+        assert self.pydantic_model.tokenizer_config.max_length == 512
+        assert self.pydantic_model.tokenizer_config.padding == "longest"
+        assert not self.pydantic_model.tokenizer_config.truncation
+
+
+@pytest.mark.parametrize(
+    "test_class",
+    [
+        TestSimpleAttributes,
+        TestTags,
+        TestTransformers,
+        TestVectorIndex,
+        TestEmbedder,
+        TestSklearnEstimator,
+        TestRanker,
+        TestEmbedderConfig,
+        TestCrossEncoderConfig,
+    ],
+)
+def test_dumper(test_class):
+    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
+        test_obj = test_class()
+        test_obj.init_attributes()
+
+        Dumper.dump(test_obj, Path(temp_dir))
+        del test_obj
+
+        loaded_obj = test_class()
+        Dumper.load(loaded_obj, Path(temp_dir))
+        loaded_obj.check_attributes()
diff --git a/tests/modules/test_regex.py b/tests/modules/test_regex.py
index a9fdb275b..038abd983 100644
--- a/tests/modules/test_regex.py
+++ b/tests/modules/test_regex.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import pytest

 from autointent.modules import SimpleRegex
@@ -32,3 +34,10 @@ def test_base_regex(partial_match, expected_predictions):
     assert "partial_matches" in metadata[0]
     assert "full_matches" in metadata[0]
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        matcher.dump(temp_dir)
+        del matcher
+        new_matcher = SimpleRegex.load(temp_dir)
+        new_predictions = new_matcher.predict(test_data)
+        assert predictions == new_predictions
diff --git a/user_guides/basic_usage/02_modules.py b/user_guides/basic_usage/02_modules.py
index 7c2f8732b..629d3f9b6 100644
--- a/user_guides/basic_usage/02_modules.py
+++ b/user_guides/basic_usage/02_modules.py
@@ -84,11 +84,7 @@
 """

 # %%
-loaded_scorer = KNNScorer(
-    embedder_config="sergeyzh/rubert-tiny-turbo",
-    k=5,
-)
-loaded_scorer.load(pathdir)
+loaded_scorer = KNNScorer.load(pathdir)
 loaded_scorer.predict(["hello world!"])

 # %% [markdown]
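
Usage sketch (illustrative; the import path, example data, and dump directory are assumptions): with load now a classmethod that builds and returns the instance, a dumped module is restored without repeating its constructor arguments.

    import numpy as np

    from autointent.modules import KNNScorer  # assumed import path

    # toy data mirroring the fixtures used in the tests above
    utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
    labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]

    scorer = KNNScorer(k=5, embedder_config="sergeyzh/rubert-tiny-turbo")
    scorer.fit(utterances, labels)
    scores = scorer.predict(["hello world!"])

    scorer.dump("dumps/knn_scorer")  # hypothetical dump directory
    restored = KNNScorer.load("dumps/knn_scorer")  # classmethod: creates the instance, then loads its state
    assert np.allclose(scores, restored.predict(["hello world!"]))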