
Commit ef55f31

sungchul1, sungmanc, and yunchu authored
Set path to save pseudo masks into workspace (#2185)
* Set path to save pseudo masks into workspace
* pylint
* unittest
* make black happy
* Fix unit tests
* Update otx/cli/manager/config_manager.py
  Co-authored-by: Yunchu Lee <[email protected]>
* Update otx/core/data/adapter/segmentation_dataset_adapter.py
  Co-authored-by: Kim, Sungchul <[email protected]>

---------

Co-authored-by: sungmanc <[email protected]>
Co-authored-by: Yunchu Lee <[email protected]>
1 parent 81c0ccd commit ef55f31

File tree

5 files changed: +60 -41 lines changed


otx/cli/manager/config_manager.py

Lines changed: 5 additions & 0 deletions
@@ -406,6 +406,11 @@ def get_dataset_config(self, subsets: List[str], hyper_parameters: Optional[Conf
         if learning_parameters:
             num_workers = getattr(learning_parameters, "num_workers", 0)
             dataset_config["cache_config"]["num_workers"] = num_workers
+        if str(self.task_type).upper() == "SEGMENTATION" and str(self.train_type).upper() == "SELFSUPERVISED":
+            # FIXME: manually set a path to save pseudo masks in workspace
+            train_type_rel_path = TASK_TYPE_TO_SUB_DIR_NAME[self.train_type]
+            train_type_dir = self.workspace_root / train_type_rel_path
+            dataset_config["pseudo_mask_dir"] = train_type_dir / "detcon_mask"
         return dataset_config
 
     def update_data_config(self, data_yaml: dict) -> None:
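For orientation, here is a minimal sketch of what the new branch in `get_dataset_config` produces: `pseudo_mask_dir` is now resolved under the workspace instead of next to the images. The workspace name and the `SELFSUPERVISED` sub-directory shown below are illustrative assumptions, not values taken from this diff.

```python
from pathlib import Path

# Hypothetical workspace layout for illustration; the real sub-directory name comes
# from TASK_TYPE_TO_SUB_DIR_NAME[self.train_type] in otx/cli/manager/config_manager.py.
workspace_root = Path("otx-workspace-SEGMENTATION")  # assumed workspace root
train_type_rel_path = "SELFSUPERVISED"               # assumed sub-dir for the Self-SL train type

dataset_config = {"cache_config": {"num_workers": 0}}
dataset_config["pseudo_mask_dir"] = workspace_root / train_type_rel_path / "detcon_mask"

print(dataset_config["pseudo_mask_dir"])  # otx-workspace-SEGMENTATION/SELFSUPERVISED/detcon_mask
```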

otx/core/data/adapter/base_dataset_adapter.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,7 @@ def __init__(
         unlabeled_data_roots: Optional[str] = None,
         unlabeled_file_list: Optional[str] = None,
         cache_config: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ):
         self.task_type = task_type
         self.domain = task_type.domain
@@ -97,6 +98,7 @@ def __init__(
             test_ann_files=test_ann_files,
             unlabeled_data_roots=unlabeled_data_roots,
             unlabeled_file_list=unlabeled_file_list,
+            **kwargs,
         )
 
         cache_config = cache_config if cache_config is not None else {}
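The `**kwargs` added here simply passes subclass-specific options through the shared constructor. Below is a simplified sketch of that pattern, assuming the base `__init__` forwards the keyword arguments to `_import_dataset` as the second hunk suggests; the classes are stand-ins, not the real adapters.

```python
# Simplified pass-through sketch (hypothetical classes, not the OTX adapters).
class BaseAdapter:
    def __init__(self, train_data_roots=None, **kwargs):
        # Extra, subclass-only options travel through the base class untouched.
        self.dataset = self._import_dataset(train_data_roots=train_data_roots, **kwargs)

    def _import_dataset(self, train_data_roots=None, **kwargs):
        return {"roots": train_data_roots}


class SelfSLAdapter(BaseAdapter):
    def _import_dataset(self, train_data_roots=None, pseudo_mask_dir=None, **kwargs):
        # The subclass picks up the option the base class ignored.
        if pseudo_mask_dir is None:
            raise ValueError("pseudo_mask_dir must be set.")
        return {"roots": train_data_roots, "masks": str(pseudo_mask_dir)}


adapter = SelfSLAdapter(train_data_roots="data/images", pseudo_mask_dir="/tmp/detcon_mask")
print(adapter.dataset)  # {'roots': 'data/images', 'masks': '/tmp/detcon_mask'}
```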

otx/core/data/adapter/segmentation_dataset_adapter.py

Lines changed: 19 additions & 19 deletions
@@ -6,6 +6,7 @@
 
 import json
 import os
+from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import cv2
@@ -53,6 +54,7 @@ def __init__(
         unlabeled_data_roots: Optional[str] = None,
         unlabeled_file_list: Optional[str] = None,
         cache_config: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ):
         super().__init__(
             task_type,
@@ -65,6 +67,7 @@ def __init__(
             unlabeled_data_roots,
             unlabeled_file_list,
             cache_config,
+            **kwargs,
         )
         self.updated_label_id: Dict[int, int] = {}
 
@@ -166,7 +169,7 @@ def _import_dataset(
         test_ann_files: Optional[str] = None,
         unlabeled_data_roots: Optional[str] = None,
         unlabeled_file_list: Optional[str] = None,
-        pseudo_mask_dir: str = "detcon_mask",
+        pseudo_mask_dir: Path = None,
     ) -> Dict[Subset, DatumDataset]:
         """Import custom Self-SL dataset for using DetCon.
 
@@ -183,11 +186,13 @@ def _import_dataset(
             test_ann_files (Optional[str]): Path for test annotation file
             unlabeled_data_roots (Optional[str]): Path for unlabeled data.
             unlabeled_file_list (Optional[str]): Path of unlabeled file list
-            pseudo_mask_dir (str): Directory to save pseudo masks. Defaults to "detcon_mask".
+            pseudo_mask_dir (Path): Directory to save pseudo masks. Defaults to None.
 
         Returns:
             DatumaroDataset: Datumaro Dataset
         """
+        if pseudo_mask_dir is None:
+            raise ValueError("pseudo_mask_dir must be set.")
         if train_data_roots is None:
             raise ValueError("train_data_root must be set.")
 
@@ -199,23 +204,20 @@
         self.is_train_phase = True
 
         # Load pseudo masks
-        img_dir = None
         total_labels = []
+        os.makedirs(pseudo_mask_dir, exist_ok=True)
         for item in dataset[Subset.TRAINING]:
             img_path = item.media.path
-            if img_dir is None:
-                # Get image directory
-                img_dir = train_data_roots.split("/")[-1]
-            pseudo_mask_path = img_path.replace(img_dir, pseudo_mask_dir)
-            if pseudo_mask_path.endswith(".jpg"):
-                pseudo_mask_path = pseudo_mask_path.replace(".jpg", ".png")
+            pseudo_mask_path = pseudo_mask_dir / os.path.basename(img_path)
+            if pseudo_mask_path.suffix == ".jpg":
+                pseudo_mask_path = pseudo_mask_path.with_name(f"{pseudo_mask_path.stem}.png")
 
             if not os.path.isfile(pseudo_mask_path):
                 # Create pseudo mask
-                pseudo_mask = self.create_pseudo_masks(item.media.data, pseudo_mask_path)  # type: ignore
+                pseudo_mask = self.create_pseudo_masks(item.media.data, str(pseudo_mask_path))  # type: ignore
             else:
                 # Load created pseudo mask
-                pseudo_mask = cv2.imread(pseudo_mask_path, cv2.IMREAD_GRAYSCALE)
+                pseudo_mask = cv2.imread(str(pseudo_mask_path), cv2.IMREAD_GRAYSCALE)
 
             # Set annotations into each item
             annotations = []
@@ -229,28 +231,27 @@
                 )
             item.annotations = annotations
 
-        pseudo_mask_roots = train_data_roots.replace(img_dir, pseudo_mask_dir)  # type: ignore
-        if not os.path.isfile(os.path.join(pseudo_mask_roots, "dataset_meta.json")):
+        if not os.path.isfile(os.path.join(pseudo_mask_dir, "dataset_meta.json")):
             # Save dataset_meta.json for newly created pseudo masks
             # FIXME: Because background class is ignored when generating polygons, meta is set with len(labels)-1.
             # It must be considered to set the whole labels later.
             # (-> {i: f"target{i+1}" for i in range(max(total_labels)+1)})
             meta = {"label_map": {i + 1: f"target{i+1}" for i in range(max(total_labels))}}
-            with open(os.path.join(pseudo_mask_roots, "dataset_meta.json"), "w", encoding="UTF-8") as f:
+            with open(os.path.join(pseudo_mask_dir, "dataset_meta.json"), "w", encoding="UTF-8") as f:
                 json.dump(meta, f, indent=4)
 
         # Make categories for pseudo masks
-        label_map = parse_meta_file(os.path.join(pseudo_mask_roots, "dataset_meta.json"))
+        label_map = parse_meta_file(os.path.join(pseudo_mask_dir, "dataset_meta.json"))
         dataset[Subset.TRAINING].define_categories(make_categories(label_map))
 
         return dataset
 
-    def create_pseudo_masks(self, img: np.array, pseudo_mask_path: str, mode: str = "FH") -> None:
+    def create_pseudo_masks(self, img: np.ndarray, pseudo_mask_path: str, mode: str = "FH") -> None:
         """Create pseudo masks for self-sl for semantic segmentation using DetCon.
 
         Args:
-            img (np.array) : A sample to create a pseudo mask.
-            pseudo_mask_path (str): The path to save a pseudo mask.
+            img (np.ndarray) : A sample to create a pseudo mask.
+            pseudo_mask_path (Path): The path to save a pseudo mask.
             mode (str): The mode to create a pseudo mask. Defaults to "FH".
 
         Returns:
@@ -261,7 +262,6 @@ def create_pseudo_masks(self, img: np.array, pseudo_mask_path: str, mode: str =
         else:
             raise ValueError((f'{mode} is not supported to create pseudo masks for DetCon. Choose one of ["FH"].'))
 
-        os.makedirs(os.path.dirname(pseudo_mask_path), exist_ok=True)
         cv2.imwrite(pseudo_mask_path, pseudo_mask.astype(np.uint8))
 
         return pseudo_mask
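The key behavioural change above is how a pseudo-mask path is derived: instead of mirroring the image directory, each mask now lives directly in `pseudo_mask_dir`, keyed by the image basename and saved as PNG. A small, self-contained sketch of that derivation follows; the paths are illustrative only.

```python
import os
from pathlib import Path

# Illustrative paths only; in OTX the directory comes from dataset_config["pseudo_mask_dir"].
pseudo_mask_dir = Path("otx-workspace-SEGMENTATION/SELFSUPERVISED/detcon_mask")
img_path = "/data/images/0001.jpg"

# New behaviour: mask path is workspace dir + image basename, with .jpg swapped for .png
# because pseudo masks are written as PNG files.
pseudo_mask_path = pseudo_mask_dir / os.path.basename(img_path)
if pseudo_mask_path.suffix == ".jpg":
    pseudo_mask_path = pseudo_mask_path.with_name(f"{pseudo_mask_path.stem}.png")

print(pseudo_mask_path)  # otx-workspace-SEGMENTATION/SELFSUPERVISED/detcon_mask/0001.png
```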

tests/unit/core/data/adapter/test_init.py

Lines changed: 18 additions & 6 deletions
@@ -11,6 +11,9 @@
     TASK_NAME_TO_TASK_TYPE,
 )
 
+from pathlib import Path
+import shutil
+
 
 @e2e_pytest_unit
 @pytest.mark.parametrize("task_name", TASK_NAME_TO_TASK_TYPE.keys())
@@ -63,19 +66,28 @@ def test_get_dataset_adapter_selfsl_segmentation(task_name, train_type):
     task_type = TASK_NAME_TO_TASK_TYPE[task_name]
     data_root = TASK_NAME_TO_DATA_ROOT[task_name]
 
-    get_dataset_adapter(
-        task_type=task_type,
-        train_type=train_type,
-        train_data_roots=os.path.join(root_path, data_root["train"]),
-    )
+    with pytest.raises(ValueError, match=r"pseudo_mask_dir must be set."):
+        get_dataset_adapter(
+            task_type=task_type,
+            train_type=train_type,
+            train_data_roots=os.path.join(root_path, data_root["train"]),
+        )
 
-    with pytest.raises(ValueError):
         get_dataset_adapter(
             task_type=task_type,
             train_type=train_type,
             test_data_roots=os.path.join(root_path, data_root["test"]),
         )
 
+    tmp_supcon_mask_dir = Path("/tmp/selfsl_supcon_unit_test")
+    get_dataset_adapter(
+        task_type=task_type,
+        train_type=train_type,
+        train_data_roots=os.path.join(root_path, data_root["train"]),
+        pseudo_mask_dir=tmp_supcon_mask_dir,
+    )
+    shutil.rmtree(str(tmp_supcon_mask_dir))
+
 
 # TODO: direct annotation function is only supported in COCO format for now.
 @e2e_pytest_unit

tests/unit/core/data/adapter/test_segmentation_adapter.py

Lines changed: 16 additions & 16 deletions
@@ -4,6 +4,7 @@
 #
 import os
 import shutil
+from pathlib import Path
 from typing import Optional
 
 import numpy as np
@@ -66,28 +67,30 @@ def test_get_otx_dataset(self):
 
 
 class TestSelfSLSegmentationDatasetAdapter:
-    def setup_method(self, method) -> None:
+    def setup_class(self) -> None:
         self.root_path = os.getcwd()
         task = "segmentation"
 
         self.task_type: TaskType = TASK_NAME_TO_TASK_TYPE[task]
         data_root_dict: dict = TASK_NAME_TO_DATA_ROOT[task]
         self.train_data_roots: str = os.path.join(self.root_path, data_root_dict["train"], "images")
 
-        self.pseudo_mask_roots = os.path.abspath(self.train_data_roots.replace("images", "detcon_mask"))
+        self.pseudo_mask_dir = Path(os.path.abspath(self.train_data_roots.replace("images", "detcon_mask")))
+
+    def teardown_class(self) -> None:
+        shutil.rmtree(self.pseudo_mask_dir, ignore_errors=True)
 
     @e2e_pytest_unit
     def test_import_dataset_create_all_masks(self, mocker):
         """Test _import_dataset when creating all masks.
 
         This test is for when all masks are not created and it is required to create masks.
         """
-        shutil.rmtree(self.pseudo_mask_roots, ignore_errors=True)
+        shutil.rmtree(self.pseudo_mask_dir, ignore_errors=True)
         spy_create_pseudo_masks = mocker.spy(SelfSLSegmentationDatasetAdapter, "create_pseudo_masks")
 
         dataset_adapter = SelfSLSegmentationDatasetAdapter(
-            task_type=self.task_type,
-            train_data_roots=self.train_data_roots,
+            task_type=self.task_type, train_data_roots=self.train_data_roots, pseudo_mask_dir=self.pseudo_mask_dir
         )
 
         spy_create_pseudo_masks.assert_called()
@@ -102,20 +105,19 @@ def test_import_dataset_create_some_uncreated_masks(self, mocker, idx_remove: in
         and it is required to either create or just load masks.
         In this test, remove a mask created before and check if `create_pseudo_masks` is called once.
         """
-        shutil.rmtree(self.pseudo_mask_roots, ignore_errors=True)
+        shutil.rmtree(self.pseudo_mask_dir, ignore_errors=True)
         dataset_adapter = SelfSLSegmentationDatasetAdapter(
-            task_type=self.task_type,
-            train_data_roots=self.train_data_roots,
+            task_type=self.task_type, train_data_roots=self.train_data_roots, pseudo_mask_dir=self.pseudo_mask_dir
         )
-        assert os.path.isdir(self.pseudo_mask_roots)
-        assert len(os.listdir(self.pseudo_mask_roots)) == 4
+        assert os.path.isdir(self.pseudo_mask_dir)
+        assert len(os.listdir(self.pseudo_mask_dir)) == 4
 
         # remove a mask
-        os.remove(os.path.join(self.pseudo_mask_roots, f"000{idx_remove}.png"))
+        os.remove(os.path.join(self.pseudo_mask_dir, f"000{idx_remove}.png"))
         spy_create_pseudo_masks = mocker.spy(SelfSLSegmentationDatasetAdapter, "create_pseudo_masks")
 
         _ = dataset_adapter._import_dataset(
-            train_data_roots=self.train_data_roots,
+            train_data_roots=self.train_data_roots, pseudo_mask_dir=self.pseudo_mask_dir
         )
 
         spy_create_pseudo_masks.assert_called()
@@ -127,8 +129,7 @@ def test_import_dataset_just_load_masks(self, mocker):
         spy_create_pseudo_masks = mocker.spy(SelfSLSegmentationDatasetAdapter, "create_pseudo_masks")
 
         _ = SelfSLSegmentationDatasetAdapter(
-            task_type=self.task_type,
-            train_data_roots=self.train_data_roots,
+            task_type=self.task_type, train_data_roots=self.train_data_roots, pseudo_mask_dir=self.pseudo_mask_dir
        )
 
         spy_create_pseudo_masks.assert_not_called()
@@ -148,8 +149,7 @@ def test_create_pseudo_masks(self, mocker):
         mocker.patch("otx.core.data.adapter.segmentation_dataset_adapter.os.makedirs")
         mocker.patch("otx.core.data.adapter.segmentation_dataset_adapter.cv2.imwrite")
         dataset_adapter = SelfSLSegmentationDatasetAdapter(
-            task_type=self.task_type,
-            train_data_roots=self.train_data_roots,
+            task_type=self.task_type, train_data_roots=self.train_data_roots, pseudo_mask_dir=self.pseudo_mask_dir
         )
 
         pseudo_mask = dataset_adapter.create_pseudo_masks(img=np.ones((2, 2)), pseudo_mask_path="")
