Skip to content

Commit 2122a33

Browse files
authored
Feat/load to device (#155)
* implement the idea
* fix typing
* bug fix
* update tests
1 parent 8959abd commit 2122a33

File tree

8 files changed

+104
-41
lines changed

8 files changed

+104
-41
lines changed

autointent/_dump_tools.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
9999
np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
100100

101101
@staticmethod
102-
def load(obj: Any, path: Path) -> None: # noqa: ANN401, PLR0912, C901, PLR0915
102+
def load( # noqa: PLR0912, C901, PLR0915
103+
obj: Any, # noqa: ANN401
104+
path: Path,
105+
embedder_config: EmbedderConfig | None = None,
106+
cross_encoder_config: CrossEncoderConfig | None = None,
107+
) -> None:
103108
"""Load attributes from file system."""
104109
tags: dict[str, Any] = {}
105110
simple_attrs: dict[str, Any] = {}
@@ -119,15 +124,18 @@ def load(obj: Any, path: Path) -> None: # noqa: ANN401, PLR0912, C901, PLR0915
119124
elif child.name == Dumper.arrays:
120125
arrays = dict(np.load(child))
121126
elif child.name == Dumper.embedders:
122-
# TODO propagate custom loading params (such as device, batch size etc) to this line
123-
embedders = {embedder_dump.name: Embedder.load(embedder_dump) for embedder_dump in child.iterdir()}
127+
embedders = {
128+
embedder_dump.name: Embedder.load(embedder_dump, override_config=embedder_config)
129+
for embedder_dump in child.iterdir()
130+
}
124131
elif child.name == Dumper.indexes:
125132
indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
126133
elif child.name == Dumper.estimators:
127134
estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
128135
elif child.name == Dumper.cross_encoders:
129136
cross_encoders = {
130-
cross_encoder_dump.name: Ranker.load(cross_encoder_dump) for cross_encoder_dump in child.iterdir()
137+
cross_encoder_dump.name: Ranker.load(cross_encoder_dump, override_config=cross_encoder_config)
138+
for cross_encoder_dump in child.iterdir()
131139
}
132140
elif child.name == Dumper.pydantic_models:
133141
for model_file in child.iterdir():

autointent/_embedder.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_embeddings_path(filename: str) -> Path:
4040
class EmbedderDumpMetadata(TypedDict):
4141
"""Metadata for saving and loading an Embedder instance."""
4242

43-
model_name_or_path: str
43+
model_name: str
4444
"""Name of the hugging face model or a local path to sentence transformers dump."""
4545
device: str | None
4646
"""Torch notation for CPU or CUDA."""
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
114114
"""
115115
self.dump_dir = path
116116
metadata = EmbedderDumpMetadata(
117-
model_name_or_path=str(self.model_name),
117+
model_name=str(self.model_name),
118118
device=self.device,
119119
batch_size=self.batch_size,
120120
max_length=self.max_length,
@@ -125,24 +125,22 @@ def dump(self, path: Path) -> None:
125125
json.dump(metadata, file, indent=4)
126126

127127
@classmethod
128-
def load(cls, path: Path | str) -> "Embedder":
128+
def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -> "Embedder":
129129
"""Load the embedding model and metadata from disk.
130130
131131
Args:
132132
path: Path to the directory where the model is stored.
133+
override_config: one can override presaved settings
133134
"""
134135
with (Path(path) / cls.metadata_dict_name).open() as file:
135136
metadata: EmbedderDumpMetadata = json.load(file)
136137

137-
return cls(
138-
EmbedderConfig(
139-
model_name=metadata["model_name_or_path"],
140-
device=metadata["device"],
141-
batch_size=metadata["batch_size"],
142-
max_length=metadata["max_length"],
143-
use_cache=metadata["use_cache"],
144-
)
145-
)
138+
if override_config is not None:
139+
kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
140+
else:
141+
kwargs = metadata # type: ignore[assignment]
142+
143+
return cls(EmbedderConfig(**kwargs))
146144

147145
def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
148146
"""Calculate embeddings for a list of utterances.

autointent/_pipeline/_pipeline.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,20 +266,33 @@ def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
266266
return cls(nodes)
267267

268268
@classmethod
269-
def load(cls, path: str | Path) -> "Pipeline":
269+
def load(
270+
cls,
271+
path: str | Path,
272+
embedder_config: EmbedderConfig | None = None,
273+
cross_encoder_config: CrossEncoderConfig | None = None,
274+
) -> "Pipeline":
270275
"""Load pipeline in inference mode.
271276
272-
This method loads fitted modules and tuned hyperparameters.
273-
274277
Args:
275278
path: Path to load
279+
embedder_config: one can override presaved settings
280+
cross_encoder_config: one can override presaved settings
276281
277282
Returns:
278283
Inference pipeline
279284
"""
280285
with (Path(path) / "inference_config.yaml").open() as file:
281-
inference_dict_config = yaml.safe_load(file)
282-
return cls.from_dict_config(inference_dict_config["nodes_configs"])
286+
inference_dict_config: dict[str, Any] = yaml.safe_load(file)
287+
288+
inference_config = [
289+
InferenceNodeConfig(
290+
**node_config, embedder_config=embedder_config, cross_encoder_config=cross_encoder_config
291+
)
292+
for node_config in inference_dict_config["nodes_configs"]
293+
]
294+
295+
return cls.from_config(inference_config)
283296

284297
def predict(self, utterances: list[str]) -> ListOfGenericLabels:
285298
"""Predict the labels for the utterances.

autointent/_ranker.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,12 @@ def save(self, path: str) -> None:
266266
joblib.dump(self._clf, dump_dir / self.classifier_file_name)
267267

268268
@classmethod
269-
def load(cls, path: Path) -> "Ranker":
269+
def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) -> "Ranker":
270270
"""Load the model and classifier from disk.
271271
272272
Args:
273273
path: Directory path containing the saved model and classifier
274+
override_config: one can override presaved settings
274275
275276
Returns:
276277
Initialized Ranker instance
@@ -280,14 +281,13 @@ def load(cls, path: Path) -> "Ranker":
280281
with (path / cls.metadata_file_name).open() as file:
281282
metadata: CrossEncoderMetadata = json.load(file)
282283

284+
if override_config is not None:
285+
kwargs = {**metadata, **override_config.model_dump(exclude_unset=True)}
286+
else:
287+
kwargs = metadata # type: ignore[assignment]
288+
283289
return cls(
284-
CrossEncoderConfig(
285-
model_name=metadata["model_name"],
286-
device=metadata["device"],
287-
max_length=metadata["max_length"],
288-
batch_size=metadata["batch_size"],
289-
train_head=metadata["train_classifier"],
290-
),
290+
CrossEncoderConfig(**kwargs),
291291
classifier_head=clf,
292292
)
293293

autointent/configs/_inference_node.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from autointent.custom_types import NodeType
77

8+
from ._transformers import CrossEncoderConfig, EmbedderConfig
9+
810

911
@dataclass
1012
class InferenceNodeConfig:
@@ -18,3 +20,7 @@ class InferenceNodeConfig:
1820
"""Configuration of the module"""
1921
load_path: str | None = None
2022
"""Path to the module dump. If None, the module will be trained from scratch"""
23+
embedder_config: EmbedderConfig | None = None
24+
"""One can override presaved embedder config while loading from file system."""
25+
cross_encoder_config: CrossEncoderConfig | None = None
26+
"""One can override presaved cross encoder config while loading from file system."""

autointent/modules/base/_base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing_extensions import assert_never
1212

1313
from autointent._dump_tools import Dumper
14+
from autointent.configs import CrossEncoderConfig, EmbedderConfig
1415
from autointent.context import Context
1516
from autointent.context.optimization_info import Artifact
1617
from autointent.custom_types import ListOfGenericLabels, ListOfLabels
@@ -88,13 +89,20 @@ def dump(self, path: str) -> None:
8889
"""
8990
Dumper.dump(self, Path(path))
9091

91-
def load(self, path: str) -> None:
92-
"""Load data from dump.
92+
def load(
93+
self,
94+
path: str,
95+
embedder_config: EmbedderConfig | None = None,
96+
cross_encoder_config: CrossEncoderConfig | None = None,
97+
) -> None:
98+
"""Load data from file system.
9399
94100
Args:
95101
path: Path to load
102+
embedder_config: one can override presaved settings
103+
cross_encoder_config: one can override presaved settings
96104
"""
97-
Dumper.load(self, Path(path))
105+
Dumper.load(self, Path(path), embedder_config=embedder_config, cross_encoder_config=cross_encoder_config)
98106

99107
@abstractmethod
100108
def predict(

autointent/nodes/_inference_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
3333
node_info = NODES_INFO[config.node_type]
3434
module = node_info.modules_available[config.module_name](**config.module_config)
3535
if config.load_path is not None:
36-
module.load(config.load_path)
36+
module.load(
37+
config.load_path,
38+
embedder_config=config.embedder_config,
39+
cross_encoder_config=config.cross_encoder_config,
40+
)
3741
return cls(module, config.node_type)
3842

3943
def clear_cache(self) -> None:

tests/pipeline/test_inference.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
import pytest
22

33
from autointent import Pipeline
4-
from autointent.configs import LoggingConfig
4+
from autointent.configs import EmbedderConfig, LoggingConfig
5+
from autointent.custom_types import NodeType
56
from tests.conftest import get_search_space, setup_environment
67

78

89
@pytest.mark.parametrize(
910
"task_type",
1011
["multiclass", "multilabel", "description"],
1112
)
12-
def test_inference_config(dataset, task_type):
13+
def test_inference_from_config(dataset, task_type):
1314
project_dir = setup_environment()
1415
search_space = get_search_space(task_type)
1516

1617
pipeline_optimizer = Pipeline.from_search_space(search_space)
1718

18-
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
19+
logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
20+
pipeline_optimizer.set_config(logging_config)
1921

2022
if task_type == "multilabel":
2123
dataset = dataset.to_multilabel()
2224

2325
context = pipeline_optimizer.fit(dataset)
24-
inference_config = context.optimization_info.get_inference_nodes_config()
26+
context.dump()
2527

26-
inference_pipeline = Pipeline.from_config(inference_config)
28+
inference_pipeline = Pipeline.load(logging_config.dirpath)
2729
utterances = ["123", "hello world"]
2830
prediction = inference_pipeline.predict(utterances)
2931
assert len(prediction) == 2
3032

3133
rich_outputs = inference_pipeline.predict_with_metadata(utterances)
3234
assert len(rich_outputs.predictions) == len(utterances)
3335

34-
context.dump()
35-
3636

3737
@pytest.mark.parametrize(
3838
"task_type",
3939
["multiclass", "multilabel", "description"],
4040
)
41-
def test_inference_context(dataset, task_type):
41+
def test_inference_on_the_fly(dataset, task_type):
4242
project_dir = setup_environment()
4343
search_space = get_search_space(task_type)
4444

@@ -59,3 +59,29 @@ def test_inference_context(dataset, task_type):
5959
assert len(rich_outputs.predictions) == len(utterances)
6060

6161
context.dump()
62+
63+
64+
def test_load_with_overrided_params(dataset):
65+
project_dir = setup_environment()
66+
search_space = get_search_space("light")
67+
68+
pipeline_optimizer = Pipeline.from_search_space(search_space)
69+
70+
logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
71+
pipeline_optimizer.set_config(logging_config)
72+
73+
context = pipeline_optimizer.fit(dataset)
74+
context.dump()
75+
76+
inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
77+
utterances = ["123", "hello world"]
78+
prediction = inference_pipeline.predict(utterances)
79+
assert len(prediction) == 2
80+
81+
rich_outputs = inference_pipeline.predict_with_metadata(utterances)
82+
assert len(rich_outputs.predictions) == len(utterances)
83+
84+
assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8
85+
86+
87+
# TODO Pipeline.dump()

0 commit comments

Comments
 (0)