import pytest

from autointent import Pipeline
from autointent.configs import EmbedderConfig, LoggingConfig
from autointent.custom_types import NodeType
from tests.conftest import get_search_space, setup_environment
67
78
@pytest.mark.parametrize(
    "task_type",
    ["multiclass", "multilabel", "description"],
)
def test_inference_from_config(dataset, task_type):
    """End-to-end round trip: optimize a pipeline, dump it, reload from disk, predict.

    Parametrized over multiclass, multilabel, and description task types.
    """
    project_dir = setup_environment()
    search_space = get_search_space(task_type)

    pipeline_optimizer = Pipeline.from_search_space(search_space)

    # dump_modules=True persists the fitted modules; clear_ram=True forces the
    # reloaded pipeline to read everything back from the dumped artifacts.
    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
    pipeline_optimizer.set_config(logging_config)

    if task_type == "multilabel":
        dataset = dataset.to_multilabel()

    context = pipeline_optimizer.fit(dataset)
    context.dump()

    # Restore the pipeline purely from what was written to disk.
    inference_pipeline = Pipeline.load(logging_config.dirpath)
    utterances = ["123", "hello world"]
    prediction = inference_pipeline.predict(utterances)
    assert len(prediction) == 2

    rich_outputs = inference_pipeline.predict_with_metadata(utterances)
    assert len(rich_outputs.predictions) == len(utterances)
3636
3737@pytest .mark .parametrize (
3838 "task_type" ,
3939 ["multiclass" , "multilabel" , "description" ],
4040)
41- def test_inference_context (dataset , task_type ):
41+ def test_inference_on_the_fly (dataset , task_type ):
4242 project_dir = setup_environment ()
4343 search_space = get_search_space (task_type )
4444
@@ -59,3 +59,29 @@ def test_inference_context(dataset, task_type):
5959 assert len (rich_outputs .predictions ) == len (utterances )
6060
6161 context .dump ()
62+
63+
def test_load_with_overrided_params(dataset):
    """Check that load-time parameter overrides reach the restored modules.

    Fits and dumps a pipeline, then reloads it with a custom ``EmbedderConfig``
    and verifies the override propagated into the scoring node's embedder.
    """
    project_dir = setup_environment()
    search_space = get_search_space("light")

    pipeline_optimizer = Pipeline.from_search_space(search_space)

    logging_config = LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)
    pipeline_optimizer.set_config(logging_config)

    context = pipeline_optimizer.fit(dataset)
    context.dump()

    # Override the embedder's max_length at load time.
    inference_pipeline = Pipeline.load(logging_config.dirpath, embedder_config=EmbedderConfig(max_length=8))
    utterances = ["123", "hello world"]
    prediction = inference_pipeline.predict(utterances)
    assert len(prediction) == 2

    rich_outputs = inference_pipeline.predict_with_metadata(utterances)
    assert len(rich_outputs.predictions) == len(utterances)

    # The overridden value must have propagated into the scoring module's embedder.
    assert inference_pipeline.nodes[NodeType.scoring].module._embedder.max_length == 8
85+
86+
# TODO: add test coverage for Pipeline.dump()