Skip to content

Commit 2d0dd56

Browse files
committed
update tests
1 parent d9f840f commit 2d0dd56

File tree

1 file changed

+34
-8
lines changed

1 file changed

+34
-8
lines changed

tests/pipeline/test_inference.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,44 @@
11
import pytest
22

33
from autointent import Pipeline
4-
from autointent.configs import LoggingConfig
4+
from autointent.configs import EmbedderConfig, LoggingConfig
5+
from autointent.custom_types import NodeType
56
from tests.conftest import get_search_space, setup_environment
67

78

89
@pytest.mark.parametrize(
910
"task_type",
1011
["multiclass", "multilabel", "description"],
1112
)
12-
def test_inference_config(dataset, task_type):
13+
def test_inference_from_config(dataset, task_type):
1314
project_dir = setup_environment()
1415
search_space = get_search_space(task_type)
1516

1617
pipeline_optimizer = Pipeline.from_search_space(search_space)
1718

18-
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
19+
logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
20+
pipeline_optimizer.set_config(logging_config)
1921

2022
if task_type == "multilabel":
2123
dataset = dataset.to_multilabel()
2224

2325
context = pipeline_optimizer.fit(dataset)
24-
inference_config = context.optimization_info.get_inference_nodes_config()
26+
context.dump()
2527

26-
inference_pipeline = Pipeline.from_config(inference_config)
28+
inference_pipeline = Pipeline.load(logging_config.dirpath)
2729
utterances = ["123", "hello world"]
2830
prediction = inference_pipeline.predict(utterances)
2931
assert len(prediction) == 2
3032

3133
rich_outputs = inference_pipeline.predict_with_metadata(utterances)
3234
assert len(rich_outputs.predictions) == len(utterances)
3335

34-
context.dump()
35-
3636

3737
@pytest.mark.parametrize(
3838
"task_type",
3939
["multiclass", "multilabel", "description"],
4040
)
41-
def test_inference_context(dataset, task_type):
41+
def test_inference_on_the_fly(dataset, task_type):
4242
project_dir = setup_environment()
4343
search_space = get_search_space(task_type)
4444

@@ -59,3 +59,29 @@ def test_inference_context(dataset, task_type):
5959
assert len(rich_outputs.predictions) == len(utterances)
6060

6161
context.dump()
62+
63+
64+
def test_load_with_overrided_params(dataset):
65+
project_dir = setup_environment()
66+
search_space = get_search_space("light")
67+
68+
pipeline_optimizer = Pipeline.from_search_space(search_space)
69+
70+
logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
71+
pipeline_optimizer.set_config(logging_config)
72+
73+
context = pipeline_optimizer.fit(dataset)
74+
context.dump()
75+
76+
inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
77+
utterances = ["123", "hello world"]
78+
prediction = inference_pipeline.predict(utterances)
79+
assert len(prediction) == 2
80+
81+
rich_outputs = inference_pipeline.predict_with_metadata(utterances)
82+
assert len(rich_outputs.predictions) == len(utterances)
83+
84+
assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8
85+
86+
87+
# TODO Pipeline.dump()

0 commit comments

Comments
 (0)