Skip to content

Commit d1bfa9a

Browse files
committed
Support HF datasets and TFDS w/ a sub-path by fixing split, fix #1598 ... add class mapping support to HF datasets in case class label isn't in info.
1 parent 35fb00c commit d1bfa9a

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

timm/data/dataset_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def create_dataset(
151151
elif name.startswith('hfds/'):
152152
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
153153
# There will be an IterableDataset variant too, TBD
154-
ds = ImageDataset(root, reader=name, split=split, **kwargs)
154+
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
155155
elif name.startswith('tfds/'):
156156
ds = IterableImageDataset(
157157
root,

timm/data/readers/reader_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def create_reader(name, root, split='train', **kwargs):
88
name = name.lower()
9-
name = name.split('/', 2)
9+
name = name.split('/', 1)
1010
prefix = ''
1111
if len(name) > 1:
1212
prefix = name[0]

timm/data/readers/reader_hfds.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
except ImportError as e:
1414
print("Please install Hugging Face datasets package `pip install datasets`.")
1515
exit(1)
16+
from .class_map import load_class_map
1617
from .reader import Reader
1718

1819

def get_class_labels(info, label_key='label'):
    """Build a class-name -> integer-index mapping from dataset info.

    Args:
        info: Object exposing a ``features`` mapping (feature name -> feature
            object). The feature stored under ``label_key`` must provide a
            ``names`` sequence and a ``str2int(name)`` method.
        label_key: Name of the feature that holds the class label.

    Returns:
        Dict mapping each class name to its integer index, or an empty dict
        when ``label_key`` is not present in ``info.features``.
    """
    # Guard on the *requested* key (not a hard-coded 'label') so that a custom
    # label_key is honoured and a missing key yields {} instead of a KeyError.
    if label_key not in info.features:
        return {}
    class_label = info.features[label_key]
    class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
    return class_to_idx
2526

@@ -32,6 +33,7 @@ def __init__(
3233
name,
3334
split='train',
3435
class_map=None,
36+
label_key='label',
3537
download=False,
3638
):
3739
"""
@@ -43,12 +45,17 @@ def __init__(
4345
name, # 'name' maps to path arg in hf datasets
4446
split=split,
4547
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
46-
#use_auth_token=True,
4748
)
4849
# leave decode for caller, plus we want easy access to original path names...
4950
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
5051

51-
self.class_to_idx = get_class_labels(self.dataset.info)
52+
self.label_key = label_key
53+
self.remap_class = False
54+
if class_map:
55+
self.class_to_idx = load_class_map(class_map)
56+
self.remap_class = True
57+
else:
58+
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
5259
self.split_info = self.dataset.info.splits[split]
5360
self.num_samples = self.split_info.num_examples
5461

@@ -60,7 +67,10 @@ def __getitem__(self, index):
6067
else:
6168
assert 'path' in image and image['path']
6269
image = open(image['path'], 'rb')
63-
return image, item['label']
70+
label = item[self.label_key]
71+
if self.remap_class:
72+
label = self.class_to_idx[label]
73+
return image, label
6474

6575
def __len__(self):
6676
return len(self.dataset)

0 commit comments

Comments
 (0)