Skip to content

Commit c226919

Browse files
authored
[FIX][Release1.0] Fix action task sample codes and unit tests (#1834)
Fix action task sample codes and unit tests
1 parent f550746 commit c226919

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

otx/algorithms/action/tools/sample_classification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@ def parse_args():
4949

5050
def load_test_dataset(model_template):
5151
"""Load Sample dataset for detection."""
52+
53+
algo_backend = model_template.hyper_parameters.parameter_overrides["algo_backend"]
54+
train_type = algo_backend["train_type"]["default_value"]
5255
dataset_adapter = get_dataset_adapter(
5356
model_template.task_type,
54-
model_template.train_type,
57+
train_type,
5558
train_data_roots=TRAIN_DATA_ROOTS,
5659
val_data_roots=VAL_DATA_ROOTS,
5760
)

otx/algorithms/action/tools/sample_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ def parse_args():
4646

4747
def load_test_dataset(model_template):
4848
"""Load Sample dataset for detection."""
49+
algo_backend = model_template.hyper_parameters.parameter_overrides["algo_backend"]
50+
train_type = algo_backend["train_type"]["default_value"]
4951
dataset_adapter = get_dataset_adapter(
5052
model_template.task_type,
51-
model_template.train_type,
53+
train_type,
5254
train_data_roots=TRAIN_DATA_ROOTS,
5355
val_data_roots=VAL_DATA_ROOTS,
5456
)

tests/unit/algorithms/action/tools/test_action_sample_classification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def test_load_test_dataset() -> None:
4949

5050
class MockTemplate:
5151
task_type = TaskType.ACTION_CLASSIFICATION
52-
train_type = TrainType.INCREMENTAL.value
52+
hyper_parameters = Config(
53+
{"parameter_overrides": {"algo_backend": {"train_type": {"default_value": TrainType.INCREMENTAL.value}}}}
54+
)
5355

5456
dataset, label_schema = load_test_dataset(MockTemplate())
5557
isinstance(dataset, DatasetEntity)

tests/unit/algorithms/action/tools/test_action_sample_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def test_load_test_dataset() -> None:
5050

5151
class MockTemplate:
5252
task_type = TaskType.ACTION_DETECTION
53-
train_type = TrainType.INCREMENTAL.value
53+
hyper_parameters = Config(
54+
{"parameter_overrides": {"algo_backend": {"train_type": {"default_value": TrainType.INCREMENTAL.value}}}}
55+
)
5456

5557
dataset, label_schema = load_test_dataset(MockTemplate())
5658
isinstance(dataset, DatasetEntity)

0 commit comments

Comments
 (0)