@@ -85,6 +85,19 @@ class CatalogDataset:
8585 filter_empty = False ,
8686 ),
8787 ),
88+ CatalogDataset (
89+ name = "coco_2017_cls" ,
90+ task = Task .CLASSIFICATION ,
91+ train_split = CatalogSplit (
92+ image_root = "coco/train2017" ,
93+ json_file = "coco/annotations/instances_train2017.json" ,
94+ ),
95+ val_split = CatalogSplit (
96+ image_root = "coco/val2017" ,
97+ json_file = "coco/annotations/instances_val2017.json" ,
98+ filter_empty = False ,
99+ ),
100+ ),
88101 CatalogDataset (
89102 name = "coco_2017_instance" ,
90103 task = Task .INSTANCE_SEGMENTATION ,
@@ -131,6 +144,7 @@ def _load_dataset_split(
131144 split_name : str ,
132145 split : CatalogSplit ,
133146 task : Task ,
147+ split_type : DatasetSplitType ,
134148 root = DATASETS_DIR ,
135149) -> DictDataset :
136150 """
@@ -152,7 +166,7 @@ def get_path(root, path):
152166 task = task ,
153167 )
154168
155- if task in [Task .DETECTION , Task .INSTANCE_SEGMENTATION , Task .KEYPOINT ]:
169+ if task in [Task .DETECTION , Task .INSTANCE_SEGMENTATION , Task .KEYPOINT , Task . CLASSIFICATION ]:
156170 dataset_dict = load_coco_json (json_file_path , image_root_path , metadata , task = task )
157171 if split .filter_empty :
158172 dataset_dict = filter_images_with_only_crowd_annotations (dataset_dicts = dataset_dict )
@@ -171,7 +185,7 @@ def get_path(root, path):
171185 raise ValueError (f"Unknown task { task } " )
172186
173187 metadata .count = len (dataset_dict )
174- return DictDataset (dataset_dict , task = task , metadata = metadata )
188+ return DictDataset (dataset_dict , task = task , metadata = metadata , split_type = split_type )
175189
176190
177191def get_dataset_split (name : str , split : DatasetSplitType , datasets_root = DATASETS_DIR ) -> DictDataset :
@@ -192,4 +206,4 @@ def get_dataset_split(name: str, split: DatasetSplitType, datasets_root=DATASETS
192206 else :
193207 raise ValueError (f"Unknown split { split } " )
194208
195- return _load_dataset_split (split_name , entry , ds .task , datasets_root )
209+ return _load_dataset_split (split_name = split_name , split = entry , task = ds .task , split_type = split , root = datasets_root )
0 commit comments