1717 otx_train_testing ,
1818)
1919
20- args_polygon = {
20+ args = {
2121 "--train-data-roots" : "tests/assets/car_tree_bug" ,
2222 "--val-data-roots" : "tests/assets/car_tree_bug" ,
2323 "--test-data-roots" : "tests/assets/car_tree_bug" ,
3333 ],
3434}
3535
36- args_mask = {
37- "--train-data-roots" : "tests/assets/car_tree_bug" ,
38- "--val-data-roots" : "tests/assets/car_tree_bug" ,
39- "--test-data-roots" : "tests/assets/car_tree_bug" ,
40- "--input" : "tests/assets/car_tree_bug/images/train" ,
41- "train_params" : [
42- "params" ,
43- "--learning_parameters.trainer.max_epochs" ,
44- "1" ,
45- "--learning_parameters.dataset.train_batch_size" ,
46- "2" ,
47- "--learning_parameters.dataset.use_mask" ,
48- "True" ,
49- ],
50- }
51-
5236# Training params for resume, num_iters*2
5337resume_params = [
5438 "params" ,
6145otx_dir = os .getcwd ()
6246
6347
64- templates = Registry ("src/otx/algorithms/visual_prompting" ).filter (task_type = "VISUAL_PROMPTING" ).templates
48+ templates = (
49+ Registry ("src/otx/algorithms/visual_prompting" , experimental = True ).filter (task_type = "VISUAL_PROMPTING" ).templates
50+ )
6551templates_ids = [template .model_template_id for template in templates ]
6652
6753
6854class TestVisualPromptingCLI :
6955 @e2e_pytest_component
7056 @pytest .mark .parametrize ("template" , templates , ids = templates_ids )
71- @pytest .mark .parametrize ("args" , [args_polygon , args_mask ])
72- def test_otx_train (self , args , template , tmp_dir_path ):
57+ def test_otx_train (self , template , tmp_dir_path ):
7358 tmp_dir_path = tmp_dir_path / "visual_prompting"
7459 otx_train_testing (template , tmp_dir_path , otx_dir , args , deterministic = False )
7560
7661 @e2e_pytest_component
7762 @pytest .mark .parametrize ("template" , templates , ids = templates_ids )
78- @pytest .mark .parametrize ("args" , [args_polygon , args_mask ])
79- def test_otx_resume (self , args , template , tmp_dir_path ):
63+ def test_otx_resume (self , template , tmp_dir_path ):
8064 tmp_dir_path = tmp_dir_path / "visual_prompting/test_resume"
8165 otx_resume_testing (template , tmp_dir_path , otx_dir , args )
8266 template_work_dir = get_template_dir (template , tmp_dir_path )
@@ -89,7 +73,6 @@ def test_otx_resume(self, args, template, tmp_dir_path):
8973
9074 @e2e_pytest_component
9175 @pytest .mark .parametrize ("template" , templates , ids = templates_ids )
92- @pytest .mark .parametrize ("args" , [args_polygon , args_mask ])
93- def test_otx_eval (self , args , template , tmp_dir_path ):
76+ def test_otx_eval (self , template , tmp_dir_path ):
9477 tmp_dir_path = tmp_dir_path / "visual_prompting"
9578 otx_eval_testing (template , tmp_dir_path , otx_dir , args )
0 commit comments