Skip to content

Commit 572f883

Browse files
authored
Fix visual prompting template file name (#2303)
* Fix template name * Fix template path in test helper * Fix * Fix import name * Fix * Edit integration test cases * Fix
1 parent e5b1000 commit 572f883

File tree

4 files changed

+13
-25
lines changed

4 files changed

+13
-25
lines changed

src/otx/algorithms/visual_prompting/tasks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22

33
# Copyright (C) 2023 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
5+
6+
from .inference import InferenceTask # noqa: F401
7+
from .train import TrainingTask # noqa: F401

tests/integration/cli/visual_prompting/test_visual_prompting.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
otx_train_testing,
1818
)
1919

20-
args_polygon = {
20+
args = {
2121
"--train-data-roots": "tests/assets/car_tree_bug",
2222
"--val-data-roots": "tests/assets/car_tree_bug",
2323
"--test-data-roots": "tests/assets/car_tree_bug",
@@ -33,22 +33,6 @@
3333
],
3434
}
3535

36-
args_mask = {
37-
"--train-data-roots": "tests/assets/car_tree_bug",
38-
"--val-data-roots": "tests/assets/car_tree_bug",
39-
"--test-data-roots": "tests/assets/car_tree_bug",
40-
"--input": "tests/assets/car_tree_bug/images/train",
41-
"train_params": [
42-
"params",
43-
"--learning_parameters.trainer.max_epochs",
44-
"1",
45-
"--learning_parameters.dataset.train_batch_size",
46-
"2",
47-
"--learning_parameters.dataset.use_mask",
48-
"True",
49-
],
50-
}
51-
5236
# Training params for resume, num_iters*2
5337
resume_params = [
5438
"params",
@@ -61,22 +45,22 @@
6145
otx_dir = os.getcwd()
6246

6347

64-
templates = Registry("src/otx/algorithms/visual_prompting").filter(task_type="VISUAL_PROMPTING").templates
48+
templates = (
49+
Registry("src/otx/algorithms/visual_prompting", experimental=True).filter(task_type="VISUAL_PROMPTING").templates
50+
)
6551
templates_ids = [template.model_template_id for template in templates]
6652

6753

6854
class TestVisualPromptingCLI:
6955
@e2e_pytest_component
7056
@pytest.mark.parametrize("template", templates, ids=templates_ids)
71-
@pytest.mark.parametrize("args", [args_polygon, args_mask])
72-
def test_otx_train(self, args, template, tmp_dir_path):
57+
def test_otx_train(self, template, tmp_dir_path):
7358
tmp_dir_path = tmp_dir_path / "visual_prompting"
7459
otx_train_testing(template, tmp_dir_path, otx_dir, args, deterministic=False)
7560

7661
@e2e_pytest_component
7762
@pytest.mark.parametrize("template", templates, ids=templates_ids)
78-
@pytest.mark.parametrize("args", [args_polygon, args_mask])
79-
def test_otx_resume(self, args, template, tmp_dir_path):
63+
def test_otx_resume(self, template, tmp_dir_path):
8064
tmp_dir_path = tmp_dir_path / "visual_prompting/test_resume"
8165
otx_resume_testing(template, tmp_dir_path, otx_dir, args)
8266
template_work_dir = get_template_dir(template, tmp_dir_path)
@@ -89,7 +73,6 @@ def test_otx_resume(self, args, template, tmp_dir_path):
8973

9074
@e2e_pytest_component
9175
@pytest.mark.parametrize("template", templates, ids=templates_ids)
92-
@pytest.mark.parametrize("args", [args_polygon, args_mask])
93-
def test_otx_eval(self, args, template, tmp_dir_path):
76+
def test_otx_eval(self, template, tmp_dir_path):
9477
tmp_dir_path = tmp_dir_path / "visual_prompting"
9578
otx_eval_testing(template, tmp_dir_path, otx_dir, args)

tests/unit/algorithms/visual_prompting/test_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def generate_visual_prompting_dataset(use_mask: bool = False) -> DatasetEntity:
104104

105105

106106
def init_environment(model: Optional[ModelEntity] = None):
107-
model_template = parse_model_template(os.path.join(DEFAULT_VISUAL_PROMPTING_TEMPLATE_DIR, "template.yaml"))
107+
model_template = parse_model_template(
108+
os.path.join(DEFAULT_VISUAL_PROMPTING_TEMPLATE_DIR, "template_experimental.yaml")
109+
)
108110
hyper_parameters = create(model_template.hyper_parameters.data)
109111
labels_schema = generate_otx_label_schema()
110112
environment = TaskEnvironment(

0 commit comments

Comments
 (0)