16 changes: 12 additions & 4 deletions autointent/_dump_tools.py
@@ -99,7 +99,12 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
         np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
 
     @staticmethod
-    def load(obj: Any, path: Path) -> None:  # noqa: ANN401, PLR0912, C901, PLR0915
+    def load(  # noqa: PLR0912, C901, PLR0915
+        obj: Any,  # noqa: ANN401
+        path: Path,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
         """Load attributes from file system."""
         tags: dict[str, Any] = {}
         simple_attrs: dict[str, Any] = {}
@@ -119,15 +124,18 @@ def load(obj: Any, path: Path) -> None:  # noqa: ANN401, PLR0912, C901, PLR0915
             elif child.name == Dumper.arrays:
                 arrays = dict(np.load(child))
             elif child.name == Dumper.embedders:
-                # TODO propagate custom loading params (such as device, batch size etc) to this line
-                embedders = {embedder_dump.name: Embedder.load(embedder_dump) for embedder_dump in child.iterdir()}
+                embedders = {
+                    embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
+                    for embedder_dump in child.iterdir()
+                }
             elif child.name == Dumper.indexes:
                 indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
             elif child.name == Dumper.estimators:
                 estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
             elif child.name == Dumper.cross_encoders:
                 cross_encoders = {
-                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump) for cross_encoder_dump in child.iterdir()
+                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
+                    for cross_encoder_dump in child.iterdir()
                 }
             elif child.name == Dumper.pydantic_models:
                 for model_file in child.iterdir():
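Note: these keyword arguments resolve the removed TODO about propagating custom loading parameters: Dumper.load now threads the override configs down to every Embedder.load and Ranker.load call it makes. A minimal sketch of a direct call, assuming a fitted module previously saved with Dumper.dump (the module object and dump path are hypothetical; in practice Dumper.load runs inside the base module's load, shown further below):

from pathlib import Path

from autointent._dump_tools import Dumper
from autointent.configs import EmbedderConfig

module = ...  # hypothetical: any fitted module previously saved via Dumper.dump(module, path)
Dumper.load(
    module,
    Path("runs/scoring_module"),  # hypothetical dump directory
    embedder_config=EmbedderConfig(device="cpu", batch_size=16),
)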
22 changes: 10 additions & 12 deletions autointent/_embedder.py
@@ -40,7 +40,7 @@ def get_embeddings_path(filename: str) -> Path:
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
 
-    model_name_or_path: str
+    model_name: str
     """Name of the hugging face model or a local path to sentence transformers dump."""
     device: str | None
     """Torch notation for CPU or CUDA."""
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
"""
self.dump_dir = path
metadata = EmbedderDumpMetadata(
model_name_or_path=str(self.model_name),
model_name=str(self.model_name),
device=self.device,
batch_size=self.batch_size,
max_length=self.max_length,
@@ -125,24 +125,22 @@ def dump(self, path: Path) -> None:
             json.dump(metadata, file, indent=4)
 
     @classmethod
-    def load(cls, path: Path | str) -> "Embedder":
+    def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
         """Load the embedding model and metadata from disk.
 
         Args:
             path: Path to the directory where the model is stored.
+            override_config: one can override presaved settings
         """
         with (Path(path) / cls.metadata_dict_name).open() as file:
             metadata: EmbedderDumpMetadata = json.load(file)
 
-        return cls(
-            EmbedderConfig(
-                model_name=metadata["model_name_or_path"],
-                device=metadata["device"],
-                batch_size=metadata["batch_size"],
-                max_length=metadata["max_length"],
-                use_cache=metadata["use_cache"],
-            )
-        )
+        if override_config is not None:
+            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+        else:
+            kwargs = metadata  # type: ignore[assignment]
+
+        return cls(EmbedderConfig(**kwargs))
 
     def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
         """Calculate embeddings for a list of utterances.
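Note: the merge {**metadata, **override_config.model_dump(exclude_unset=True)} leans on pydantic v2 semantics: model_dump(exclude_unset=True) emits only the fields the caller explicitly passed, so an override replaces exactly those keys and leaves the rest of the saved metadata untouched. A self-contained sketch of that behavior (ConfigSketch is a simplified stand-in, not autointent's real EmbedderConfig):

from pydantic import BaseModel


class ConfigSketch(BaseModel):  # simplified stand-in for EmbedderConfig
    model_name: str = "default-model"
    device: str | None = None
    batch_size: int = 32
    max_length: int | None = None


saved_metadata = {"model_name": "bge-m3", "device": "cuda", "batch_size": 8, "max_length": 128}
override = ConfigSketch(batch_size=1)  # only batch_size was explicitly set

merged = {**saved_metadata, **override.model_dump(exclude_unset=True)}
print(merged)  # {'model_name': 'bge-m3', 'device': 'cuda', 'batch_size': 1, 'max_length': 128}

One consequence worth noting: override fields left at their defaults are not merged, so passing EmbedderConfig(batch_size=1) cannot accidentally reset device or max_length.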
23 changes: 18 additions & 5 deletions autointent/_pipeline/_pipeline.py
@@ -266,20 +266,33 @@ def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
         return cls(nodes)
 
     @classmethod
-    def load(cls, path: str | Path) -> "Pipeline":
+    def load(
+        cls,
+        path: str | Path,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> "Pipeline":
         """Load pipeline in inference mode.
 
         This method loads fitted modules and tuned hyperparameters.
 
         Args:
             path: Path to load
+            embedder_config: one can override presaved settings
+            cross_encoder_config: one can override presaved settings
 
         Returns:
             Inference pipeline
         """
         with (Path(path) / "inference_config.yaml").open() as file:
-            inference_dict_config = yaml.safe_load(file)
-        return cls.from_dict_config(inference_dict_config["nodes_configs"])
+            inference_dict_config: dict[str, Any] = yaml.safe_load(file)
+
+        inference_config = [
+            InferenceNodeConfig(
+                **node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config
+            )
+            for node_config in inference_dict_config["nodes_configs"]
+        ]
+
+        return cls.from_config(inference_config)
 
     def predict(self, utterances: list[str]) -> ListOfGenericLabels:
         """Predict the labels for the utterances.
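Note: with this change the public entry point for restoring a pipeline accepts the same overrides. A hedged usage sketch (the dump directory is hypothetical and must have been produced by a prior optimization run with dump_modules=True; the device and batch_size fields mirror the metadata shown elsewhere in this diff):

from autointent import Pipeline
from autointent.configs import CrossEncoderConfig, EmbedderConfig

pipeline = Pipeline.load(
    "runs/my_pipeline",  # hypothetical dump directory
    embedder_config=EmbedderConfig(device="cpu"),
    cross_encoder_config=CrossEncoderConfig(batch_size=4),
)
print(pipeline.predict(["hello world"]))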
16 changes: 8 additions & 8 deletions autointent/_ranker.py
@@ -266,11 +266,12 @@ def save(self, path: str) -> None:
         joblib.dump(self._clf, dump_dir / self.classifier_file_name)
 
     @classmethod
-    def load(cls, path: Path) -> "Ranker":
+    def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) -> "Ranker":
         """Load the model and classifier from disk.
 
         Args:
             path: Directory path containing the saved model and classifier
+            override_config: one can override presaved settings
 
         Returns:
             Initialized Ranker instance
@@ -280,14 +281,13 @@ def load(cls, path: Path) -> "Ranker":
         with (path / cls.metadata_file_name).open() as file:
             metadata: CrossEncoderMetadata = json.load(file)
 
+        if override_config is not None:
+            kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
+        else:
+            kwargs = metadata  # type: ignore[assignment]
+
         return cls(
-            CrossEncoderConfig(
-                model_name=metadata["model_name"],
-                device=metadata["device"],
-                max_length=metadata["max_length"],
-                batch_size=metadata["batch_size"],
-                train_head=metadata["train_classifier"],
-            ),
+            CrossEncoderConfig(**kwargs),
             classifier_head=clf,
         )
 
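Note: Ranker.load now follows the same pattern as Embedder.load: explicitly-set override fields win, everything else comes from the saved metadata. A usage sketch (import path inferred from this file's location; the dump directory is hypothetical and must have been written by Ranker.save):

from pathlib import Path

from autointent._ranker import Ranker
from autointent.configs import CrossEncoderConfig

ranker = Ranker.load(
    Path("runs/ranker"),  # hypothetical directory produced by Ranker.save
    override_config=CrossEncoderConfig(device="cpu", batch_size=4),
)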
6 changes: 6 additions & 0 deletions autointent/configs/_inference_node.py
@@ -5,6 +5,8 @@

 from autointent.custom_types import NodeType
 
+from ._transformers import CrossEncoderConfig, EmbedderConfig
+
 
 @dataclass
 class InferenceNodeConfig:
@@ -18,3 +20,7 @@ class InferenceNodeConfig:
"""Configuration of the module"""
load_path: str | None = None
"""Path to the module dump. If None, the module will be trained from scratch"""
embedder_config: EmbedderConfig | None = None
"""One can override presaved embedder config while loading from file system."""
cross_encoder_config: CrossEncoderConfig | None = None
"""One can override presaved cross encoder config while loading from file system."""
14 changes: 11 additions & 3 deletions autointent/modules/base/_base.py
@@ -11,6 +11,7 @@
 from typing_extensions import assert_never
 
 from autointent._dump_tools import Dumper
+from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context import Context
 from autointent.context.optimization_info import Artifact
 from autointent.custom_types import ListOfGenericLabels, ListOfLabels
@@ -88,13 +89,20 @@ def dump(self, path: str) -> None:
"""
Dumper.dump(self, Path(path))

def load(self, path: str) -> None:
"""Load data from dump.
def load(
self,
path: str,
embedder_config: EmbedderConfig | None = None,
cross_encoder_config: CrossEncoderConfig | None = None,
) -> None:
"""Load data from file system.

Args:
path: Path to load
embedder_config: one can override presaved settings
cross_encoder_config: one can override presaved settings
"""
Dumper.load(self, Path(path))
Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)

@abstractmethod
def predict(
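Note: every module thus exposes the overrides on its public load and simply forwards them to Dumper.load. A minimal sketch (the module instance and path are hypothetical; any fitted module previously saved with module.dump(path) would do):

from autointent.configs import CrossEncoderConfig, EmbedderConfig

module = ...  # hypothetical: a fitted autointent module to restore in place
module.load(
    "runs/scoring_node",  # hypothetical dump path
    embedder_config=EmbedderConfig(device="cpu"),
    cross_encoder_config=CrossEncoderConfig(device="cpu"),
)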
6 changes: 5 additions & 1 deletion autointent/nodes/_inference_node.py
@@ -33,7 +33,11 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
         node_info = NODES_INFO[config.node_type]
         module = node_info.modules_available[config.module_name](**config.module_config)
         if config.load_path is not None:
-            module.load(config.load_path)
+            module.load(
+                config.load_path,
+                embedder_config=config.embedder_config,
+                cross_encoder_config=config.cross_encoder_config,
+            )
         return cls(module, config.node_type)
 
     def clear_cache(self) -> None:
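Note: from_config forwards the new config fields whenever a load_path is present. A sketch continuing the InferenceNodeConfig example above (the import path is assumed from this file's location, autointent/nodes/_inference_node.py):

from autointent.nodes import InferenceNode

node = InferenceNode.from_config(node_config)  # node_config from the earlier sketch
print(node.module.predict(["hello world"]))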
42 changes: 34 additions & 8 deletions tests/pipeline/test_inference.py
@@ -1,44 +1,44 @@
 import pytest
 
 from autointent import Pipeline
-from autointent.configs import LoggingConfig
+from autointent.configs import EmbedderConfig, LoggingConfig
 from autointent.custom_types import NodeType
 from tests.conftest import get_search_space, setup_environment
 
 
 @pytest.mark.parametrize(
     "task_type",
     ["multiclass", "multilabel", "description"],
 )
-def test_inference_config(dataset, task_type):
+def test_inference_from_config(dataset, task_type):
     project_dir = setup_environment()
     search_space = get_search_space(task_type)
 
     pipeline_optimizer = Pipeline.from_search_space(search_space)
 
-    pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
+    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
+    pipeline_optimizer.set_config(logging_config)
 
     if task_type == "multilabel":
         dataset = dataset.to_multilabel()
 
     context = pipeline_optimizer.fit(dataset)
-    inference_config = context.optimization_info.get_inference_nodes_config()
+    context.dump()
 
-    inference_pipeline = Pipeline.from_config(inference_config)
+    inference_pipeline = Pipeline.load(logging_config.dirpath)
     utterances = ["123", "hello world"]
     prediction = inference_pipeline.predict(utterances)
     assert len(prediction) == 2
 
     rich_outputs = inference_pipeline.predict_with_metadata(utterances)
     assert len(rich_outputs.predictions) == len(utterances)
 
-    context.dump()
-
 
 @pytest.mark.parametrize(
     "task_type",
     ["multiclass", "multilabel", "description"],
 )
-def test_inference_context(dataset, task_type):
+def test_inference_on_the_fly(dataset, task_type):
     project_dir = setup_environment()
     search_space = get_search_space(task_type)
 
@@ -59,3 +59,29 @@ def test_inference_context(dataset, task_type):
     assert len(rich_outputs.predictions) == len(utterances)
 
     context.dump()
+
+
+def test_load_with_overrided_params(dataset):
+    project_dir = setup_environment()
+    search_space = get_search_space("light")
+
+    pipeline_optimizer = Pipeline.from_search_space(search_space)
+
+    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
+    pipeline_optimizer.set_config(logging_config)
+
+    context = pipeline_optimizer.fit(dataset)
+    context.dump()
+
+    inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
+    utterances = ["123", "hello world"]
+    prediction = inference_pipeline.predict(utterances)
+    assert len(prediction) == 2
+
+    rich_outputs = inference_pipeline.predict_with_metadata(utterances)
+    assert len(rich_outputs.predictions) == len(utterances)
+
+    assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8
+
+
+# TODO Pipeline.dump()