Skip to content

Commit 4971965

Browse files
authored
[FIX] Classification Auto-split bug fix (#1824)
* doc update, fix autosplit
* Fix pre-commit
* Add ignore
* Fix figure
* Fix minor
* Fix pre-commit
* Fix integ test
* Fix pre-commit
1 parent b293e94 commit 4971965

File tree

9 files changed

+14
-12
lines changed

9 files changed

+14
-12
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,7 @@ dmypy.json
4040
otx/**/*.c
4141
otx/**/*.html
4242
otx/**/*.so
43+
44+
45+
# Dataset made by unit-test
46+
tests/**/detcon_mask/*

docs/source/guide/tutorials/base/how_to_train/classification.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ with the following command:
5656
cd ..
5757
5858
|
59-
60-
.. image:: ../../../../../utils/images/flowers.jpg
59+
.. image:: ../../../../../utils/images/flowers_example.jpg
6160
:width: 600
6261

6362
|
@@ -121,7 +120,7 @@ Let's prepare an OpenVINO™ Training Extensions classification workspace runnin
121120
122121
(otx) ...$ cd ./otx-workspace-CLASSIFICATION
123122
124-
It will create **otx-workspace-CLASSIFICATION** with all necessary configs for MobileNet-V3-large-1x, prepared ``data.yaml`` to simplify CLI commands launch and splitted dataset.
123+
It will create **otx-workspace-CLASSIFICATION** with all necessary configs for MobileNet-V3-large-1x, prepared ``data.yaml`` to simplify CLI commands launch and split dataset named ``splitted_dataset``.
125124

126125
3. To start training we need to call ``otx train``
127126
command in our workspace:
@@ -252,4 +251,4 @@ Please note, that POT will take some time (generally less than NNCF optimization
252251
efficient model representation ready-to-use classification model.
253252

254253
The following tutorials provide further steps on how to :doc:`deploy <../deploy>` and use your model in the :doc:`demonstration mode <../demo>` and visualize results.
255-
The examples are provided with an object detection model, but it is easy to apply them for classification by substituting the object detection model with classification one.
254+
The examples are provided with an object detection model, but it is easy to apply them for classification by substituting the object detection model with classification one.
148 KB
Loading

otx/cli/manager/config_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def configure_data_config(self, update_data_yaml: bool = True) -> None:
169169
use_auto_split = data_yaml["data"]["train"]["data-roots"] and not data_yaml["data"]["val"]["data-roots"]
170170
# FIXME: Hardcoded for Self-Supervised Learning
171171
if use_auto_split and str(self.train_type).upper() != "SELFSUPERVISED":
172-
splitted_dataset = self.auto_split_data(data_yaml["data"]["train"]["data-roots"], self.task_type)
172+
splitted_dataset = self.auto_split_data(data_yaml["data"]["train"]["data-roots"], str(self.task_type))
173173
default_data_folder_name = "splitted_dataset"
174174
data_yaml = self._get_arg_data_yaml()
175175
self._save_data(splitted_dataset, default_data_folder_name, data_yaml)
@@ -202,7 +202,6 @@ def auto_task_detection(self, data_roots: str) -> str:
202202
if not data_roots:
203203
raise ValueError("Workspace must already exist or one of {task or model or train-data-roots} must exist.")
204204
self.data_format = self.dataset_manager.get_data_format(data_roots)
205-
print(f"[*] Detected dataset format: {self.data_format}")
206205
return self._get_task_type_from_data_format(self.data_format)
207206

208207
def _get_task_type_from_data_format(self, data_format: str) -> str:

otx/core/data/manager/dataset_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pylint: disable=invalid-name
88
import os
9-
from typing import List, Tuple
9+
from typing import List, Tuple, Union
1010

1111
import datumaro
1212
from datumaro.components.dataset import Dataset, DatasetSubset
@@ -31,12 +31,12 @@ def get_train_dataset(dataset: Dataset) -> DatasetSubset:
3131
raise ValueError("Can't find training data.")
3232

3333
@staticmethod
34-
def get_val_dataset(dataset: Dataset) -> DatasetSubset:
34+
def get_val_dataset(dataset: Dataset) -> Union[DatasetSubset, None]:
3535
"""Returns validation dataset."""
3636
for k, v in dataset.subsets().items():
37-
if "val" in k or "default" in k:
37+
if "val" in k:
3838
return v
39-
raise ValueError("Can't find validation data.")
39+
return None
4040

4141
@staticmethod
4242
def get_data_format(data_root: str) -> str:
@@ -57,6 +57,7 @@ def get_data_format(data_root: str) -> str:
5757
data_formats = datumaro.Environment().detect_dataset(data_root)
5858
# TODO: how to avoid hard-coded part
5959
data_format = data_formats[0] if "imagenet" not in data_formats else "imagenet"
60+
print(f"[*] Detected dataset format: {data_format}")
6061
return data_format
6162

6263
@staticmethod
631 Bytes
Loading
631 Bytes
Loading
631 Bytes
Loading

tests/unit/core/data/manager/test_dataset_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def test_get_train_dataset(self, task: List[str], subset: List[str]):
5757
@pytest.mark.parametrize("subset", AVAILABLE_SUBSETS)
5858
def test_get_val_dataset(self, task: List[str], subset: List[str]):
5959
if subset == "train":
60-
with pytest.raises(ValueError, match="Can't find validation data."):
61-
DatasetManager.get_val_dataset(self.dataset[subset][task])
60+
assert DatasetManager.get_val_dataset(self.dataset[subset][task]) is None
6261
else:
6362
val_dataset = DatasetManager.get_val_dataset(self.dataset[subset][task])
6463
assert isinstance(val_dataset, dm.DatasetSubset)

0 commit comments

Comments
 (0)