41
41
class DatasetTuple :
42
42
def __init__ (self , splits ):
43
43
self .tuple_cls = namedtuple ('datasets' , splits )
44
- self .tuple = self .tuple_cls (* [None for _ in splits ])
44
+ self .tuple = self .tuple_cls (* [None for _ in splits ])
45
45
46
46
def __getitem__ (self , key ):
47
47
if isinstance (key , (int , slice )):
@@ -645,6 +645,29 @@ def read(self, filename, split='train'):
645
645
label_list = self .get_labels ()
646
646
vocab_info = self .get_vocab ()
647
647
648
def _create_dict(labels):
    """Build label-to-index lookup table(s) from a collection of labels.

    Args:
        labels: Either a flat sequence of labels, or a sequence whose
            elements are themselves lists/tuples of labels (one inner
            sequence per label field, for multi-label data).

    Returns:
        For a flat sequence, a dict mapping each label to its position.
        When the first element of ``labels`` is a list or tuple, a list
        of such dicts, one per inner sequence and in the same order.
        An empty ``labels`` yields an empty dict.
    """
    # Guard the peek at labels[0]: an empty label collection has no first
    # element and would otherwise raise IndexError.
    if len(labels) > 0 and isinstance(labels[0], (list, tuple)):
        # For multiple labels in the form of list.
        return [{label: i
                 for i, label in enumerate(sub_labels)}
                for sub_labels in labels]
    return {label: i for i, label in enumerate(labels)}
662
+
663
def _convert_label_to_id(labels, label_dict):
    """Translate a label, or a sequence of labels, into integer ids.

    Args:
        labels: A single label, or a list/tuple of labels.
        label_dict: Mapping from label to id.

    Returns:
        The id for a single label. For a list, the same list object with
        every element replaced by its id (converted in place). For a
        tuple, a new tuple of ids.

    Raises:
        KeyError: If a label is not present in ``label_dict``.
    """
    if isinstance(labels, list):
        # Slice assignment keeps the caller's list object, so any other
        # reference to it sees the converted ids as well.
        labels[:] = [label_dict[label] for label in labels]
        return labels
    if isinstance(labels, tuple):
        # Tuples are immutable, so item assignment would raise TypeError;
        # return a converted copy instead.
        return tuple(label_dict[label] for label in labels)
    return label_dict[labels]
670
+
648
671
if self .lazy :
649
672
650
673
def generate_examples ():
@@ -664,16 +687,15 @@ def generate_examples():
664
687
665
688
# Convert class label to label ids.
666
689
if label_list is not None and example .get (label_col , None ):
667
- label_dict = {}
668
- for i , label in enumerate (label_list ):
669
- label_dict [label ] = i
670
- if isinstance (example [label_col ], list ) or isinstance (
671
- example [label_col ], tuple ):
672
- for label_idx in range (len (example [label_col ])):
673
- example [label_col ][label_idx ] = label_dict [
674
- example [label_col ][label_idx ]]
690
+ label_dict = _create_dict (label_list )
691
+ # For multiple labels in the form of list.
692
+ if isinstance (label_dict , list ):
693
+ for idx , sub_dict in enumerate (label_dict ):
694
+ example [label_col ][idx ] = _convert_label_to_id (
695
+ example [label_col ][idx ], sub_dict )
675
696
else :
676
- example [label_col ] = label_dict [example [label_col ]]
697
+ example [label_col ] = _convert_label_to_id (
698
+ example [label_col ], label_dict )
677
699
678
700
yield example
679
701
else :
@@ -709,18 +731,16 @@ def generate_examples():
709
731
710
732
# Convert class label to label ids.
711
733
if label_list is not None and examples [0 ].get (label_col , None ):
712
- label_dict = {}
713
- for i , label in enumerate (label_list ):
714
- label_dict [label ] = i
734
+ label_dict = _create_dict (label_list )
715
735
for idx in range (len (examples )):
716
- if isinstance ( examples [ idx ][ label_col ], list ) or isinstance (
717
- examples [ idx ][ label_col ], tuple ):
718
- for label_idx in range ( len ( examples [ idx ][ label_col ]) ):
719
- examples [idx ][label_col ][label_idx ] = label_dict [
720
- examples [idx ][label_col ][label_idx ]]
736
+ # For multiple labels in the form of list.
737
+ if isinstance ( label_dict , list ):
738
+ for i , sub_dict in enumerate ( label_dict ):
739
+ examples [idx ][label_col ][i ] = _convert_label_to_id (
740
+ examples [idx ][label_col ][i ], sub_dict )
721
741
else :
722
- examples [idx ][label_col ] = label_dict [ examples [ idx ][
723
- label_col ]]
742
+ examples [idx ][label_col ] = _convert_label_to_id (
743
+ examples [ idx ][ label_col ], label_dict )
724
744
725
745
return MapDataset (
726
746
examples , label_list = label_list , vocab_info = vocab_info )
0 commit comments