66
77import json
88import os
9+ from pathlib import Path
910from typing import Any , Dict , List , Optional
1011
1112import cv2
@@ -53,6 +54,7 @@ def __init__(
5354 unlabeled_data_roots : Optional [str ] = None ,
5455 unlabeled_file_list : Optional [str ] = None ,
5556 cache_config : Optional [Dict [str , Any ]] = None ,
57+ ** kwargs ,
5658 ):
5759 super ().__init__ (
5860 task_type ,
@@ -65,6 +67,7 @@ def __init__(
6567 unlabeled_data_roots ,
6668 unlabeled_file_list ,
6769 cache_config ,
70+ ** kwargs ,
6871 )
6972 self .updated_label_id : Dict [int , int ] = {}
7073
@@ -166,7 +169,7 @@ def _import_dataset(
166169 test_ann_files : Optional [str ] = None ,
167170 unlabeled_data_roots : Optional [str ] = None ,
168171 unlabeled_file_list : Optional [str ] = None ,
169- pseudo_mask_dir : str = "detcon_mask" ,
172+ pseudo_mask_dir : Path = None ,
170173 ) -> Dict [Subset , DatumDataset ]:
171174 """Import custom Self-SL dataset for using DetCon.
172175
@@ -183,11 +186,13 @@ def _import_dataset(
183186 test_ann_files (Optional[str]): Path for test annotation file
184187 unlabeled_data_roots (Optional[str]): Path for unlabeled data.
185188 unlabeled_file_list (Optional[str]): Path of unlabeled file list
186- pseudo_mask_dir (str ): Directory to save pseudo masks. Defaults to "detcon_mask" .
189+ pseudo_mask_dir (Path ): Directory to save pseudo masks. Defaults to None .
187190
188191 Returns:
189192 DatumaroDataset: Datumaro Dataset
190193 """
194+ if pseudo_mask_dir is None :
195+ raise ValueError ("pseudo_mask_dir must be set." )
191196 if train_data_roots is None :
192197 raise ValueError ("train_data_root must be set." )
193198
@@ -199,23 +204,20 @@ def _import_dataset(
199204 self .is_train_phase = True
200205
201206 # Load pseudo masks
202- img_dir = None
203207 total_labels = []
208+ os .makedirs (pseudo_mask_dir , exist_ok = True )
204209 for item in dataset [Subset .TRAINING ]:
205210 img_path = item .media .path
206- if img_dir is None :
207- # Get image directory
208- img_dir = train_data_roots .split ("/" )[- 1 ]
209- pseudo_mask_path = img_path .replace (img_dir , pseudo_mask_dir )
210- if pseudo_mask_path .endswith (".jpg" ):
211- pseudo_mask_path = pseudo_mask_path .replace (".jpg" , ".png" )
211+ pseudo_mask_path = pseudo_mask_dir / os .path .basename (img_path )
212+ if pseudo_mask_path .suffix == ".jpg" :
213+ pseudo_mask_path = pseudo_mask_path .with_name (f"{ pseudo_mask_path .stem } .png" )
212214
213215 if not os .path .isfile (pseudo_mask_path ):
214216 # Create pseudo mask
215- pseudo_mask = self .create_pseudo_masks (item .media .data , pseudo_mask_path ) # type: ignore
217+ pseudo_mask = self .create_pseudo_masks (item .media .data , str ( pseudo_mask_path ) ) # type: ignore
216218 else :
217219 # Load created pseudo mask
218- pseudo_mask = cv2 .imread (pseudo_mask_path , cv2 .IMREAD_GRAYSCALE )
220+ pseudo_mask = cv2 .imread (str ( pseudo_mask_path ) , cv2 .IMREAD_GRAYSCALE )
219221
220222 # Set annotations into each item
221223 annotations = []
@@ -229,28 +231,27 @@ def _import_dataset(
229231 )
230232 item .annotations = annotations
231233
232- pseudo_mask_roots = train_data_roots .replace (img_dir , pseudo_mask_dir ) # type: ignore
233- if not os .path .isfile (os .path .join (pseudo_mask_roots , "dataset_meta.json" )):
234+ if not os .path .isfile (os .path .join (pseudo_mask_dir , "dataset_meta.json" )):
234235 # Save dataset_meta.json for newly created pseudo masks
235236 # FIXME: Because background class is ignored when generating polygons, meta is set with len(labels)-1.
236237 # It must be considered to set the whole labels later.
237238 # (-> {i: f"target{i+1}" for i in range(max(total_labels)+1)})
238239 meta = {"label_map" : {i + 1 : f"target{ i + 1 } " for i in range (max (total_labels ))}}
239- with open (os .path .join (pseudo_mask_roots , "dataset_meta.json" ), "w" , encoding = "UTF-8" ) as f :
240+ with open (os .path .join (pseudo_mask_dir , "dataset_meta.json" ), "w" , encoding = "UTF-8" ) as f :
240241 json .dump (meta , f , indent = 4 )
241242
242243 # Make categories for pseudo masks
243- label_map = parse_meta_file (os .path .join (pseudo_mask_roots , "dataset_meta.json" ))
244+ label_map = parse_meta_file (os .path .join (pseudo_mask_dir , "dataset_meta.json" ))
244245 dataset [Subset .TRAINING ].define_categories (make_categories (label_map ))
245246
246247 return dataset
247248
248- def create_pseudo_masks (self , img : np .array , pseudo_mask_path : str , mode : str = "FH" ) -> None :
249+ def create_pseudo_masks (self , img : np .ndarray , pseudo_mask_path : str , mode : str = "FH" ) -> None :
249250 """Create pseudo masks for self-sl for semantic segmentation using DetCon.
250251
251252 Args:
252- img (np.array ) : A sample to create a pseudo mask.
253- pseudo_mask_path (str ): The path to save a pseudo mask.
253+ img (np.ndarray ) : A sample to create a pseudo mask.
254+ pseudo_mask_path (Path ): The path to save a pseudo mask.
254255 mode (str): The mode to create a pseudo mask. Defaults to "FH".
255256
256257 Returns:
@@ -261,7 +262,6 @@ def create_pseudo_masks(self, img: np.array, pseudo_mask_path: str, mode: str =
261262 else :
262263 raise ValueError ((f'{ mode } is not supported to create pseudo masks for DetCon. Choose one of ["FH"].' ))
263264
264- os .makedirs (os .path .dirname (pseudo_mask_path ), exist_ok = True )
265265 cv2 .imwrite (pseudo_mask_path , pseudo_mask .astype (np .uint8 ))
266266
267267 return pseudo_mask
0 commit comments