diff --git a/autointent/_dump_tools.py b/autointent/_dump_tools.py
index dc932c850..13a7ba953 100644
--- a/autointent/_dump_tools.py
+++ b/autointent/_dump_tools.py
@@ -99,7 +99,12 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
         np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
 
     @staticmethod
-    def load(obj: Any, path: Path) -> None:  # noqa: ANN401, PLR0912, C901, PLR0915
+    def load(  # noqa: PLR0912, C901, PLR0915
+        obj: Any,  # noqa: ANN401
+        path: Path,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
         """Load attributes from file system."""
         tags: dict[str, Any] = {}
         simple_attrs: dict[str, Any] = {}
@@ -119,15 +124,18 @@ def load(obj: Any, path: Path) -> None:  # noqa: ANN401, PLR0912, C901, PLR0915
             elif child.name == Dumper.arrays:
                 arrays = dict(np.load(child))
             elif child.name == Dumper.embedders:
-                # TODO propagate custom loading params (such as device, batch size etc) to this line
-                embedders = {embedder_dump.name: Embedder.load(embedder_dump) for embedder_dump in child.iterdir()}
+                embedders = {
+                    embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
+                    for embedder_dump in child.iterdir()
+                }
             elif child.name == Dumper.indexes:
                 indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
             elif child.name == Dumper.estimators:
                 estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
             elif child.name == Dumper.cross_encoders:
                 cross_encoders = {
-                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump) for cross_encoder_dump in child.iterdir()
+                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
+                    for cross_encoder_dump in child.iterdir()
                 }
             elif child.name == Dumper.pydantic_models:
                 for model_file in child.iterdir():
diff --git a/autointent/_embedder.py b/autointent/_embedder.py
index fddd65149..f1ceeb36e 100644
--- a/autointent/_embedder.py
+++ b/autointent/_embedder.py
@@ -40,7 +40,7 @@ def get_embeddings_path(filename: str) -> Path:
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
 
-    model_name_or_path: str
+    model_name: str
     """Name of the hugging face model or a local path to sentence transformers dump."""
     device: str | None
     """Torch notation for CPU or CUDA."""
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
         """
         self.dump_dir = path
         metadata = EmbedderDumpMetadata(
-            model_name_or_path=str(self.model_name),
+            model_name=str(self.model_name),
             device=self.device,
             batch_size=self.batch_size,
             max_length=self.max_length,
@@ -125,24 +125,22 @@ def dump(self, path: Path) -> None:
             json.dump(metadata, file, indent=4)
 
     @classmethod
-    def load(cls, path: Path | str) -> "Embedder":
+    def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
         """Load the embedding model and metadata from disk.
 
         Args:
             path: Path to the directory where the model is stored.
+            override_config: if provided, overrides the settings saved on disk
 
         """
         with (Path(path) / cls.metadata_dict_name).open() as file:
             metadata: EmbedderDumpMetadata = json.load(file)
-        return cls(
-            EmbedderConfig(
-                model_name=metadata["model_name_or_path"],
-                device=metadata["device"],
-                batch_size=metadata["batch_size"],
-                max_length=metadata["max_length"],
-                use_cache=metadata["use_cache"],
-            )
-        )
+        if override_config is not None:
+            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+        else:
+            kwargs = metadata  # type: ignore[assignment]
+
+        return cls(EmbedderConfig(**kwargs))
 
     def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
         """Calculate embeddings for a list of utterances.
diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index 096d29224..fc00981ac 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -266,20 +266,33 @@ def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
         return cls(nodes)
 
     @classmethod
-    def load(cls, path: str | Path) -> "Pipeline":
+    def load(
+        cls,
+        path: str | Path,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> "Pipeline":
         """Load pipeline in inference mode.
 
-        This method loads fitted modules and tuned hyperparameters.
-
         Args:
             path: Path to load
+            embedder_config: if provided, overrides the embedder settings saved on disk
+            cross_encoder_config: if provided, overrides the cross-encoder settings saved on disk
 
         Returns:
             Inference pipeline
         """
         with (Path(path) / "inference_config.yaml").open() as file:
-            inference_dict_config = yaml.safe_load(file)
-        return cls.from_dict_config(inference_dict_config["nodes_configs"])
+            inference_dict_config: dict[str, Any] = yaml.safe_load(file)
+
+        inference_config = [
+            InferenceNodeConfig(
+                **node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config
+            )
+            for node_config in inference_dict_config["nodes_configs"]
+        ]
+
+        return cls.from_config(inference_config)
 
     def predict(self, utterances: list[str]) -> ListOfGenericLabels:
         """Predict the labels for the utterances.
diff --git a/autointent/_ranker.py b/autointent/_ranker.py
index fb72ea50a..9bb96f627 100644
--- a/autointent/_ranker.py
+++ b/autointent/_ranker.py
@@ -266,11 +266,12 @@ def save(self, path: str) -> None:
         joblib.dump(self._clf, dump_dir / self.classifier_file_name)
 
     @classmethod
-    def load(cls, path: Path) -> "Ranker":
+    def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) -> "Ranker":
         """Load the model and classifier from disk.
 
         Args:
             path: Directory path containing the saved model and classifier
+            override_config: if provided, overrides the settings saved on disk
 
         Returns:
             Initialized Ranker instance
@@ -280,14 +281,13 @@ def load(cls, path: Path) -> "Ranker":
         with (path / cls.metadata_file_name).open() as file:
             metadata: CrossEncoderMetadata = json.load(file)
 
+        if override_config is not None:
+            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+        else:
+            kwargs = metadata  # type: ignore[assignment]
+
         return cls(
-            CrossEncoderConfig(
-                model_name=metadata["model_name"],
-                device=metadata["device"],
-                max_length=metadata["max_length"],
-                batch_size=metadata["batch_size"],
-                train_head=metadata["train_classifier"],
-            ),
+            CrossEncoderConfig(**kwargs),
             classifier_head=clf,
         )
 
diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py
index b99c684db..09fe1ed47 100644
--- a/autointent/configs/_inference_node.py
+++ b/autointent/configs/_inference_node.py
@@ -5,6 +5,8 @@
 
 from autointent.custom_types import NodeType
 
+from ._transformers import CrossEncoderConfig, EmbedderConfig
+
 
 @dataclass
 class InferenceNodeConfig:
@@ -18,3 +20,7 @@
     """Configuration of the module"""
     load_path: str | None = None
     """Path to the module dump. If None, the module will be trained from scratch"""
+    embedder_config: EmbedderConfig | None = None
+    """If provided, overrides the embedder config saved on disk when loading from the file system."""
+    cross_encoder_config: CrossEncoderConfig | None = None
+    """If provided, overrides the cross-encoder config saved on disk when loading from the file system."""
diff --git a/autointent/modules/base/_base.py b/autointent/modules/base/_base.py
index 2614426f5..2e250295c 100644
--- a/autointent/modules/base/_base.py
+++ b/autointent/modules/base/_base.py
@@ -11,6 +11,7 @@
 from typing_extensions import assert_never
 
 from autointent._dump_tools import Dumper
+from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context import Context
 from autointent.context.optimization_info import Artifact
 from autointent.custom_types import ListOfGenericLabels, ListOfLabels
@@ -88,13 +89,20 @@ def dump(self, path: str) -> None:
         """
         Dumper.dump(self, Path(path))
 
-    def load(self, path: str) -> None:
-        """Load data from dump.
+    def load(
+        self,
+        path: str,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
+        """Load data from the file system.
 
         Args:
             path: Path to load
+            embedder_config: if provided, overrides the embedder settings saved on disk
+            cross_encoder_config: if provided, overrides the cross-encoder settings saved on disk
         """
-        Dumper.load(self, Path(path))
+        Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
 
     @abstractmethod
     def predict(
diff --git a/autointent/nodes/_inference_node.py b/autointent/nodes/_inference_node.py
index 21ecded95..9df15dda8 100644
--- a/autointent/nodes/_inference_node.py
+++ b/autointent/nodes/_inference_node.py
@@ -33,7 +33,11 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
         node_info = NODES_INFO[config.node_type]
         module = node_info.modules_available[config.module_name](**config.module_config)
         if config.load_path is not None:
-            module.load(config.load_path)
+            module.load(
+                config.load_path,
+                embedder_config=config.embedder_config,
+                cross_encoder_config=config.cross_encoder_config,
+            )
         return cls(module, config.node_type)
 
     def clear_cache(self) -> None:
diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py
index 856ba81ba..3e864fd03 100644
--- a/tests/pipeline/test_inference.py
+++ b/tests/pipeline/test_inference.py
@@ -1,7 +1,8 @@
 import pytest
 
 from autointent import Pipeline
-from autointent.configs import LoggingConfig
+from autointent.configs import EmbedderConfig, LoggingConfig
+from autointent.custom_types import NodeType
 from tests.conftest import get_search_space, setup_environment
 
 
@@ -9,21 +10,22 @@
     "task_type",
     ["multiclass", "multilabel", "description"],
 )
-def test_inference_config(dataset, task_type):
+def test_inference_from_config(dataset, task_type):
     project_dir = setup_environment()
     search_space = get_search_space(task_type)
 
     pipeline_optimizer = Pipeline.from_search_space(search_space)
 
-    pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
+    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
+    pipeline_optimizer.set_config(logging_config)
 
     if task_type == "multilabel":
         dataset = dataset.to_multilabel()
 
     context = pipeline_optimizer.fit(dataset)
-    inference_config = context.optimization_info.get_inference_nodes_config()
+    context.dump()
 
-    inference_pipeline = Pipeline.from_config(inference_config)
+    inference_pipeline = Pipeline.load(logging_config.dirpath)
     utterances = ["123", "hello world"]
     prediction = inference_pipeline.predict(utterances)
     assert len(prediction) == 2
@@ -31,14 +33,12 @@
     rich_outputs = inference_pipeline.predict_with_metadata(utterances)
     assert len(rich_outputs.predictions) == len(utterances)
 
-    context.dump()
-
 
 
 @pytest.mark.parametrize(
     "task_type",
     ["multiclass", "multilabel", "description"],
 )
-def test_inference_context(dataset, task_type):
+def test_inference_on_the_fly(dataset, task_type):
     project_dir = setup_environment()
     search_space = get_search_space(task_type)
 
@@ -59,3 +59,29 @@
     assert len(rich_outputs.predictions) == len(utterances)
 
     context.dump()
+
+
+def test_load_with_overridden_params(dataset):
+    project_dir = setup_environment()
+    search_space = get_search_space("light")
+
+    pipeline_optimizer = Pipeline.from_search_space(search_space)
+
+    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
+    pipeline_optimizer.set_config(logging_config)
+
+    context = pipeline_optimizer.fit(dataset)
+    context.dump()
+
+    inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
+    utterances = ["123", "hello world"]
+    prediction = inference_pipeline.predict(utterances)
+    assert len(prediction) == 2
+
+    rich_outputs = inference_pipeline.predict_with_metadata(utterances)
+    assert len(rich_outputs.predictions) == len(utterances)
+
+    assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8
+
+
+# TODO Pipeline.dump()
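
A minimal usage sketch of the override behavior this patch adds. The dump path and config values below are hypothetical; only fields explicitly set on the override configs replace the values saved on disk, because the loaders merge them via `model_dump(exclude_unset=True)`.

```python
from autointent import Pipeline
from autointent.configs import CrossEncoderConfig, EmbedderConfig

# Load a previously dumped pipeline, forcing CPU inference.
# Fields left unset (e.g. model_name, use_cache) keep their saved values:
# the loaders merge {**saved_metadata, **override.model_dump(exclude_unset=True)}.
pipeline = Pipeline.load(
    "runs/my_project/pipeline_dump",  # hypothetical dump directory
    embedder_config=EmbedderConfig(device="cpu", batch_size=16),
    cross_encoder_config=CrossEncoderConfig(device="cpu"),
)

print(pipeline.predict(["hello world"]))
```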