diff --git a/autointent/_embedder.py b/autointent/_embedder.py
index be45ced91..81a6b94c4 100644
--- a/autointent/_embedder.py
+++ b/autointent/_embedder.py
@@ -87,7 +87,7 @@ def __hash__(self) -> int:
         hasher = Hasher()
         for parameter in self.embedding_model.parameters():
             hasher.update(parameter.detach().cpu().numpy())
-        hasher.update(self.config.max_length)
+        hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()
 
     def clear_ram(self) -> None:
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
             model_name=str(self.config.model_name),
             device=self.config.device,
             batch_size=self.config.batch_size,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             use_cache=self.config.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
@@ -137,6 +137,10 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
         else:
             kwargs = metadata  # type: ignore[assignment]
 
+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(EmbedderConfig(**kwargs))
 
     def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
@@ -162,12 +166,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
             self.config.model_name,
             self.config.batch_size,
-            str(self.config.max_length),
+            str(self.config.tokenizer_config.max_length),
             self.config.device,
         )
 
-        if self.config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.max_length
+        if self.config.tokenizer_config.max_length is not None:
+            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
 
         embeddings = self.embedding_model.encode(
             utterances,
diff --git a/autointent/_ranker.py b/autointent/_ranker.py
index 798cd1a9a..43774d5dd 100644
--- a/autointent/_ranker.py
+++ b/autointent/_ranker.py
@@ -113,7 +113,7 @@ def __init__(
             self.config.model_name,
             trust_remote_code=True,
             device=self.config.device,
-            max_length=self.config.max_length,  # type: ignore[arg-type]
+            max_length=self.config.tokenizer_config.max_length,  # type: ignore[arg-type]
         )
         self._train_head = False
         self._clf = classifier_head
@@ -252,7 +252,7 @@ def save(self, path: str) -> None:
             model_name=self.config.model_name,
             train_head=self._train_head,
             device=self.config.device,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             batch_size=self.config.batch_size,
         )
 
@@ -282,6 +282,10 @@ def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) ->
         else:
             kwargs = metadata  # type: ignore[assignment]
 
+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(
             CrossEncoderConfig(**kwargs),
             classifier_head=clf,
diff --git a/autointent/_vector_index.py b/autointent/_vector_index.py
index 1f16b102d..92a313015 100644
--- a/autointent/_vector_index.py
+++ b/autointent/_vector_index.py
@@ -15,7 +15,7 @@
 import numpy.typing as npt
 
 from autointent import Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, TokenizerConfig
 from autointent.custom_types import ListOfLabels
 
 
@@ -195,7 +195,7 @@ def dump(self, dir_path: Path) -> None:
             json.dump(data, file, indent=4, ensure_ascii=False)
 
         metadata = VectorIndexMetadata(
-            embedder_max_length=self.embedder.config.max_length,
+            embedder_max_length=self.embedder.config.tokenizer_config.max_length,
             embedder_model_name=str(self.embedder.config.model_name),
             embedder_device=self.embedder.config.device,
             embedder_batch_size=self.embedder.config.batch_size,
@@ -229,7 +229,7 @@ def load(
                 model_name=metadata["embedder_model_name"],
                 device=embedder_device or metadata["embedder_device"],
                 batch_size=embedder_batch_size or metadata["embedder_batch_size"],
-                max_length=metadata["embedder_max_length"],
+                tokenizer_config=TokenizerConfig(max_length=metadata["embedder_max_length"]),
                 use_cache=embedder_use_cache or metadata["embedder_use_cache"],
             )
         )
diff --git a/autointent/configs/__init__.py b/autointent/configs/__init__.py
index b939a5395..410d3e15f 100644
--- a/autointent/configs/__init__.py
+++ b/autointent/configs/__init__.py
@@ -2,14 +2,16 @@
 
 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, LoggingConfig
-from ._transformers import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
+from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig
 
 __all__ = [
     "CrossEncoderConfig",
     "DataConfig",
     "EmbedderConfig",
+    "HFModelConfig",
     "InferenceNodeConfig",
     "InferenceNodeConfig",
     "LoggingConfig",
     "TaskTypeEnum",
+    "TokenizerConfig",
 ]
diff --git a/autointent/configs/_transformers.py b/autointent/configs/_transformers.py
index 2cbe98efb..bfc88839d 100644
--- a/autointent/configs/_transformers.py
+++ b/autointent/configs/_transformers.py
@@ -1,19 +1,24 @@
 from enum import Enum
-from typing import Any
+from typing import Any, Literal
 
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
 from typing_extensions import Self, assert_never
 
 
-class ModelConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
+class TokenizerConfig(BaseModel):
+    padding: bool | Literal["longest", "max_length", "do_not_pad"] = True
+    truncation: bool = True
     max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")
 
 
-class STModelConfig(ModelConfig):
-    model_name: str
+class HFModelConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    model_name: str = Field(
+        "prajjwal1/bert-tiny", description="Name of the hugging face repository with transformer model."
+    )
+    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
     device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
+    tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
 
     @classmethod
     def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -26,7 +31,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
             Model configuration.
         """
         if values is None:
-            return cls()  # type: ignore[call-arg]
+            return cls()
         if isinstance(values, BaseModel):
             return values  # type: ignore[return-value]
         if isinstance(values, str):
@@ -45,7 +50,7 @@ class TaskTypeEnum(Enum):
     sts = "sts"
 
 
-class EmbedderConfig(STModelConfig):
+class EmbedderConfig(HFModelConfig):
     model_name: str = Field("sentence-transformers/all-MiniLM-L6-v2", description="Name of the hugging face model.")
     default_prompt: str | None = Field(
         None, description="Default prompt for the model. This is used when no task specific prompt is not provided."
@@ -105,7 +110,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
     use_cache: bool = Field(False, description="Whether to use embeddings caching.")
 
 
-class CrossEncoderConfig(STModelConfig):
+class CrossEncoderConfig(HFModelConfig):
     model_name: str = Field("cross-encoder/ms-marco-MiniLM-L-6-v2", description="Name of the hugging face model.")
     train_head: bool = Field(
         False, description="Whether to train the head of the model. If False, LogReg will be trained."
diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py
index 5eefed1e7..212d886b1 100644
--- a/autointent/modules/__init__.py
+++ b/autointent/modules/__init__.py
@@ -12,7 +12,16 @@
 )
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
-from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
+from .scoring import (
+    BertScorer,
+    DescriptionScorer,
+    DNNCScorer,
+    KNNScorer,
+    LinearScorer,
+    MLKnnScorer,
+    RerankScorer,
+    SklearnScorer,
+)
 
 
 T = TypeVar("T", bound=BaseModule)
@@ -36,6 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         RerankScorer,
         SklearnScorer,
         MLKnnScorer,
+        BertScorer,
     ]
 )
 
diff --git a/autointent/modules/scoring/__init__.py b/autointent/modules/scoring/__init__.py
index fa06f40ce..c8d7fde95 100644
--- a/autointent/modules/scoring/__init__.py
+++ b/autointent/modules/scoring/__init__.py
@@ -1,3 +1,4 @@
+from ._bert import BertScorer
 from ._description import DescriptionScorer
 from ._dnnc import DNNCScorer
 from ._knn import KNNScorer, RerankScorer
@@ -6,6 +7,7 @@
 from ._sklearn import SklearnScorer
 
 __all__ = [
+    "BertScorer",
     "DNNCScorer",
     "DescriptionScorer",
     "KNNScorer",
diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py
new file mode 100644
index 000000000..a2e3eb946
--- /dev/null
+++ b/autointent/modules/scoring/_bert.py
@@ -0,0 +1,148 @@
+"""BertScorer class for transformer-based classification."""
+
+import tempfile
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import torch
+from datasets import Dataset
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    Trainer,
+    TrainingArguments,
+)
+
+from autointent import Context
+from autointent._callbacks import REPORTERS_NAMES
+from autointent.configs import HFModelConfig
+from autointent.custom_types import ListOfLabels
+from autointent.modules.base import BaseScorer
+
+
+class BertScorer(BaseScorer):
+    name = "transformer"
+    supports_multiclass = True
+    supports_multilabel = True
+    _model: Any
+    _tokenizer: Any
+
+    def __init__(
+        self,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
+    ) -> None:
+        self.model_config = HFModelConfig.from_search_config(model_config)
+        self.num_train_epochs = num_train_epochs
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.seed = seed
+        self.report_to = report_to
+
+    @classmethod
+    def from_context(
+        cls,
+        context: Context,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+    ) -> "BertScorer":
+        if model_config is None:
+            model_config = context.resolve_embedder()
+
+        report_to = context.logging_config.report_to
+
+        return cls(
+            model_config=model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,
+        )
+
+    def get_embedder_config(self) -> dict[str, Any]:
+        return self.model_config.model_dump()
+
+    def fit(
+        self,
+        utterances: list[str],
+        labels: ListOfLabels,
+    ) -> None:
+        if hasattr(self, "_model"):
+            self.clear_cache()
+
+        self._validate_task(labels)
+
+        model_name = self.model_config.model_name
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
+
+        use_cpu = self.model_config.device == "cpu"
+
+        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
+            return self._tokenizer(  # type: ignore[no-any-return]
+                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+            )
+
+        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
+        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = TrainingArguments(
+                output_dir=tmp_dir,
+                num_train_epochs=self.num_train_epochs,
+                per_device_train_batch_size=self.batch_size,
+                learning_rate=self.learning_rate,
+                seed=self.seed,
+                save_strategy="no",
+                logging_strategy="steps",
+                logging_steps=10,
+                report_to=self.report_to,
+                use_cpu=use_cpu,
+            )
+
+            trainer = Trainer(
+                model=self._model,
+                args=training_args,
+                train_dataset=tokenized_dataset,
+                tokenizer=self._tokenizer,
+                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
+            )
+
+            trainer.train()
+
+        self._model.eval()
+
+    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
+        if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
+            msg = "Model is not trained. Call fit() first."
+            raise RuntimeError(msg)
+
+        all_predictions = []
+        for i in range(0, len(utterances), self.batch_size):
+            batch = utterances[i : i + self.batch_size]
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            with torch.no_grad():
+                outputs = self._model(**inputs)
+            logits = outputs.logits
+            if self._multilabel:
+                batch_predictions = torch.sigmoid(logits).numpy()
+            else:
+                batch_predictions = torch.softmax(logits, dim=1).numpy()
+            all_predictions.append(batch_predictions)
+        return np.vstack(all_predictions) if all_predictions else np.array([])
+
+    def clear_cache(self) -> None:
+        if hasattr(self, "_model"):
+            del self._model
+        if hasattr(self, "_tokenizer"):
+            del self._tokenizer
diff --git a/autointent/modules/scoring/_mlknn/mlknn.py b/autointent/modules/scoring/_mlknn/mlknn.py
index d6e8f9057..306453010 100644
--- a/autointent/modules/scoring/_mlknn/mlknn.py
+++ b/autointent/modules/scoring/_mlknn/mlknn.py
@@ -140,7 +140,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             ),
         )
diff --git a/autointent/modules/scoring/_sklearn/sklearn_scorer.py b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
index 19e1a635f..8e5ed51c2 100644
--- a/autointent/modules/scoring/_sklearn/sklearn_scorer.py
+++ b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
@@ -128,7 +128,7 @@ def fit(
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             )
         )
diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json
index aabd685ac..d6a3b595a 100644
--- a/docs/optimizer_config.schema.json
+++ b/docs/optimizer_config.schema.json
@@ -3,6 +3,12 @@
     "CrossEncoderConfig": {
       "additionalProperties": false,
       "properties": {
+        "model_name": {
+          "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+          "description": "Name of the hugging face model.",
+          "title": "Model Name",
+          "type": "string"
+        },
         "batch_size": {
           "default": 32,
           "description": "Batch size for model inference.",
@@ -10,26 +16,6 @@
           "title": "Batch Size",
           "type": "integer"
         },
-        "max_length": {
-          "anyOf": [
-            {
-              "exclusiveMinimum": 0,
-              "type": "integer"
-            },
-            {
-              "type": "null"
-            }
-          ],
-          "default": null,
-          "description": "Maximum length of input sequences.",
-          "title": "Max Length"
-        },
-        "model_name": {
-          "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-          "description": "Name of the hugging face model.",
-          "title": "Model Name",
-          "type": "string"
-        },
         "device": {
           "anyOf": [
             {
@@ -43,6 +29,9 @@
           "description": "Torch notation for CPU or CUDA.",
           "title": "Device"
         },
+        "tokenizer_config": {
+          "$ref": "#/$defs/TokenizerConfig"
+        },
         "train_head": {
           "default": false,
           "description": "Whether to train the head of the model. If False, LogReg will be trained.",
@@ -104,6 +93,12 @@
     "EmbedderConfig": {
      "additionalProperties": false,
       "properties": {
+        "model_name": {
+          "default": "sentence-transformers/all-MiniLM-L6-v2",
+          "description": "Name of the hugging face model.",
+          "title": "Model Name",
+          "type": "string"
+        },
         "batch_size": {
           "default": 32,
           "description": "Batch size for model inference.",
@@ -111,26 +106,6 @@
           "title": "Batch Size",
           "type": "integer"
         },
-        "max_length": {
-          "anyOf": [
-            {
-              "exclusiveMinimum": 0,
-              "type": "integer"
-            },
-            {
-              "type": "null"
-            }
-          ],
-          "default": null,
-          "description": "Maximum length of input sequences.",
-          "title": "Max Length"
-        },
-        "model_name": {
-          "default": "sentence-transformers/all-MiniLM-L6-v2",
-          "description": "Name of the hugging face model.",
-          "title": "Model Name",
-          "type": "string"
-        },
         "device": {
           "anyOf": [
             {
@@ -144,6 +119,9 @@
           "description": "Torch notation for CPU or CUDA.",
           "title": "Device"
         },
+        "tokenizer_config": {
+          "$ref": "#/$defs/TokenizerConfig"
+        },
         "default_prompt": {
           "anyOf": [
             {
@@ -301,6 +279,48 @@
       },
       "title": "LoggingConfig",
       "type": "object"
+    },
+    "TokenizerConfig": {
+      "properties": {
+        "padding": {
+          "anyOf": [
+            {
+              "type": "boolean"
+            },
+            {
+              "enum": [
+                "longest",
+                "max_length",
+                "do_not_pad"
+              ],
+              "type": "string"
+            }
+          ],
+          "default": true,
+          "title": "Padding"
+        },
+        "truncation": {
+          "default": true,
+          "title": "Truncation",
+          "type": "boolean"
+        },
+        "max_length": {
+          "anyOf": [
+            {
+              "exclusiveMinimum": 0,
+              "type": "integer"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "description": "Maximum length of input sequences.",
+          "title": "Max Length"
+        }
+      },
+      "title": "TokenizerConfig",
+      "type": "object"
     }
   },
   "description": "Configuration for the optimization process.\n\nOne can use it to customize optimization beyond choosing different preset.\nInstantiate it and pass to :py:meth:`autointent.Pipeline.from_optimization_config`.",
@@ -334,10 +354,14 @@
     "embedder_config": {
       "$ref": "#/$defs/EmbedderConfig",
       "default": {
-        "batch_size": 32,
-        "max_length": null,
         "model_name": "sentence-transformers/all-MiniLM-L6-v2",
+        "batch_size": 32,
         "device": null,
+        "tokenizer_config": {
+          "max_length": null,
+          "padding": true,
+          "truncation": true
+        },
         "default_prompt": null,
         "classifier_prompt": null,
         "cluster_prompt": null,
@@ -350,10 +374,14 @@
     "cross_encoder_config": {
       "$ref": "#/$defs/CrossEncoderConfig",
       "default": {
-        "batch_size": 32,
-        "max_length": null,
         "model_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+        "batch_size": 32,
         "device": null,
+        "tokenizer_config": {
+          "max_length": null,
+          "padding": true,
+          "truncation": true
+        },
         "train_head": false
       }
     },
diff --git a/pyproject.toml b/pyproject.toml
index 9a2fad329..316bd2878 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
     "datasets (>=3.2.0,<4.0.0)",
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
+    "transformers[torch] (>=4.49.0,<5.0.0)",
 ]
 
 [project.urls]
diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py
index 986b5d348..bbb094f4a 100644
--- a/tests/callback/test_callback.py
+++ b/tests/callback/test_callback.py
@@ -125,7 +125,7 @@ def test_pipeline_callbacks(dataset):
             "cluster_prompt": None,
             "default_prompt": None,
             "device": None,
-            "max_length": None,
+            "tokenizer_config": {"max_length": None, "truncation": True, "padding": True},
             "model_name": "sergeyzh/rubert-tiny-turbo",
             "passage_prompt": None,
             "query_prompt": None,
@@ -151,7 +151,7 @@ def test_pipeline_callbacks(dataset):
"cluster_prompt": None, "default_prompt": None, "device": None, - "max_length": None, + "tokenizer_config": {"max_length": None, "truncation": True, "padding": True}, "model_name": "sergeyzh/rubert-tiny-turbo", "passage_prompt": None, "query_prompt": None, @@ -177,7 +177,7 @@ def test_pipeline_callbacks(dataset): "cluster_prompt": None, "default_prompt": None, "device": None, - "max_length": None, + "tokenizer_config": {"max_length": None, "truncation": True, "padding": True}, "model_name": "sergeyzh/rubert-tiny-turbo", "passage_prompt": None, "query_prompt": None, diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py new file mode 100644 index 000000000..3ef319703 --- /dev/null +++ b/tests/modules/scoring/test_bert.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from autointent.context.data_handler import DataHandler +from autointent.modules import BertScorer + + +def test_bert_prediction(dataset): + """Test that the transformer model can fit and make predictions.""" + data_handler = DataHandler(dataset) + + scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + + scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) + + test_data = [ + "why is there a hold on my american saving bank account", + "i am nost sure why my account is blocked", + "why is there a hold on my capital one checking account", + "i think my account is blocked but i do not know the reason", + "can you tell me why is my bank account frozen", + ] + + predictions = scorer.predict(test_data) + + # Verify prediction shape + assert predictions.shape[0] == len(test_data) + assert predictions.shape[1] == len(set(data_handler.train_labels(0))) + + # Verify predictions are probabilities + assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0 + + # Verify probabilities sum to 1 for multiclass + if not scorer._multilabel: + for pred_row in predictions: + np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5) + + # Test metadata function if available + if hasattr(scorer, "predict_with_metadata"): + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert metadata is None + + +def test_bert_cache_clearing(dataset): + """Test that the transformer model properly handles cache clearing.""" + data_handler = DataHandler(dataset) + + scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + + scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) + + test_data = ["test text"] + + # Should work before clearing cache + scorer.predict(test_data) + + # Clear the cache + scorer.clear_cache() + + # Verify model and tokenizer are removed + assert not hasattr(scorer, "_model") or scorer._model is None + assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None + + # Should raise exception after clearing cache + with pytest.raises(RuntimeError): + scorer.predict(test_data) diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index a5e33c180..683112b49 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -1,7 +1,7 @@ import pytest from autointent import Pipeline -from autointent.configs import EmbedderConfig, LoggingConfig +from autointent.configs import EmbedderConfig, LoggingConfig, TokenizerConfig from autointent.custom_types import NodeType from tests.conftest import get_search_space, setup_environment @@ -99,7 +99,9 @@ def 
     context.dump()
 
     # case 1: simple inference from file system
-    inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
+    inference_pipeline = Pipeline.load(
+        logging_config.dirpath, embedder_config=EmbedderConfig(tokenizer_config=TokenizerConfig(max_length=8))
+    )
     utterances = ["123", "hello world"]
     prediction = inference_pipeline.predict(utterances)
     assert len(prediction) == 2
@@ -107,17 +109,19 @@ def test_load_with_overrided_params(dataset):
     # case 2: rich inference from file system
     rich_outputs = inference_pipeline.predict_with_metadata(utterances)
     assert len(rich_outputs.predictions) == len(utterances)
-    assert inference_pipeline.nodes[NodeType.scoring].module._embedder.config.max_length == 8
+    assert inference_pipeline.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8
 
     del inference_pipeline
 
     # case 3: dump and then load pipeline
     pipeline_optimizer.dump()
     del pipeline_optimizer
-    loaded_pipe = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
+    loaded_pipe = Pipeline.load(
+        logging_config.dirpath, embedder_config=EmbedderConfig(tokenizer_config=TokenizerConfig(max_length=8))
+    )
     prediction_v2 = loaded_pipe.predict(utterances)
     assert prediction == prediction_v2
-    assert loaded_pipe.nodes[NodeType.scoring].module._embedder.config.max_length == 8
+    assert loaded_pipe.nodes[NodeType.scoring].module._embedder.config.tokenizer_config.max_length == 8
 
 
 def test_no_saving(dataset):
diff --git a/user_guides/basic_usage/03_automl.py b/user_guides/basic_usage/03_automl.py
index 23caa8a6e..96b3ebde3 100644
--- a/user_guides/basic_usage/03_automl.py
+++ b/user_guides/basic_usage/03_automl.py
@@ -82,10 +82,12 @@
 """
 
 # %%
-from autointent.configs import EmbedderConfig, CrossEncoderConfig
+from autointent.configs import EmbedderConfig, CrossEncoderConfig, TokenizerConfig
 
 custom_pipeline.set_config(EmbedderConfig(model_name="prajjwal1/bert-tiny", device="cpu"))
-custom_pipeline.set_config(CrossEncoderConfig(model_name="cross-encoder/ms-marco-MiniLM-L2-v2", max_length=8))
+custom_pipeline.set_config(
+    CrossEncoderConfig(model_name="cross-encoder/ms-marco-MiniLM-L2-v2", tokenizer_config=TokenizerConfig(max_length=8))
+)
 
 # %% [markdown]
 """