1313except ImportError as e :
1414 print ("Please install Hugging Face datasets package `pip install datasets`." )
1515 exit (1 )
16+ from .class_map import load_class_map
1617from .reader import Reader
1718
1819
19- def get_class_labels (info ):
20+ def get_class_labels (info , label_key = 'label' ):
2021 if 'label' not in info .features :
2122 return {}
22- class_label = info .features ['label' ]
23+ class_label = info .features [label_key ]
2324 class_to_idx = {n : class_label .str2int (n ) for n in class_label .names }
2425 return class_to_idx
2526
@@ -32,6 +33,7 @@ def __init__(
3233 name ,
3334 split = 'train' ,
3435 class_map = None ,
36+ label_key = 'label' ,
3537 download = False ,
3638 ):
3739 """
@@ -43,12 +45,17 @@ def __init__(
4345 name , # 'name' maps to path arg in hf datasets
4446 split = split ,
4547 cache_dir = self .root , # timm doesn't expect hidden cache dir for datasets, specify a path
46- #use_auth_token=True,
4748 )
4849 # leave decode for caller, plus we want easy access to original path names...
4950 self .dataset = self .dataset .cast_column ('image' , datasets .Image (decode = False ))
5051
51- self .class_to_idx = get_class_labels (self .dataset .info )
52+ self .label_key = label_key
53+ self .remap_class = False
54+ if class_map :
55+ self .class_to_idx = load_class_map (class_map )
56+ self .remap_class = True
57+ else :
58+ self .class_to_idx = get_class_labels (self .dataset .info , self .label_key )
5259 self .split_info = self .dataset .info .splits [split ]
5360 self .num_samples = self .split_info .num_examples
5461
@@ -60,7 +67,10 @@ def __getitem__(self, index):
6067 else :
6168 assert 'path' in image and image ['path' ]
6269 image = open (image ['path' ], 'rb' )
63- return image , item ['label' ]
70+ label = item [self .label_key ]
71+ if self .remap_class :
72+ label = self .class_to_idx [label ]
73+ return image , label
6474
6575 def __len__ (self ):
6676 return len (self .dataset )
0 commit comments