Skip to content

Commit ba3ea1c

Browse files
committed
[cblue] support converting labels of multi-tasks
1 parent 205e500 commit ba3ea1c

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

paddlenlp/datasets/dataset.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
class DatasetTuple:
4242
def __init__(self, splits):
4343
self.tuple_cls = namedtuple('datasets', splits)
44-
self.tuple = self.tuple_cls(* [None for _ in splits])
44+
self.tuple = self.tuple_cls(*[None for _ in splits])
4545

4646
def __getitem__(self, key):
4747
if isinstance(key, (int, slice)):
@@ -645,6 +645,29 @@ def read(self, filename, split='train'):
645645
label_list = self.get_labels()
646646
vocab_info = self.get_vocab()
647647

648+
def _create_dict(labels):
649+
# For multiple labels in the form of list.
650+
if isinstance(labels[0], list) or isinstance(labels[0], tuple):
651+
label_dict = []
652+
for sub_labels in labels:
653+
sub_dict = {}
654+
for i, label in enumerate(sub_labels):
655+
sub_dict[label] = i
656+
label_dict.append(sub_dict)
657+
else:
658+
label_dict = {}
659+
for i, label in enumerate(labels):
660+
label_dict[label] = i
661+
return label_dict
662+
663+
def _convert_label_to_id(labels, label_dict):
664+
if isinstance(labels, list) or isinstance(labels, tuple):
665+
for label_idx in range(len(labels)):
666+
labels[label_idx] = label_dict[labels[label_idx]]
667+
else:
668+
labels = label_dict[labels]
669+
return labels
670+
648671
if self.lazy:
649672

650673
def generate_examples():
@@ -664,16 +687,15 @@ def generate_examples():
664687

665688
# Convert class label to label ids.
666689
if label_list is not None and example.get(label_col, None):
667-
label_dict = {}
668-
for i, label in enumerate(label_list):
669-
label_dict[label] = i
670-
if isinstance(example[label_col], list) or isinstance(
671-
example[label_col], tuple):
672-
for label_idx in range(len(example[label_col])):
673-
example[label_col][label_idx] = label_dict[
674-
example[label_col][label_idx]]
690+
label_dict = _create_dict(label_list)
691+
# For multiple labels in the form of list.
692+
if isinstance(label_dict, list):
693+
for idx, sub_dict in enumerate(label_dict):
694+
example[label_col][idx] = _convert_label_to_id(
695+
example[label_col][idx], sub_dict)
675696
else:
676-
example[label_col] = label_dict[example[label_col]]
697+
example[label_col] = _convert_label_to_id(
698+
example[label_col], label_dict)
677699

678700
yield example
679701
else:
@@ -709,18 +731,16 @@ def generate_examples():
709731

710732
# Convert class label to label ids.
711733
if label_list is not None and examples[0].get(label_col, None):
712-
label_dict = {}
713-
for i, label in enumerate(label_list):
714-
label_dict[label] = i
734+
label_dict = _create_dict(label_list)
715735
for idx in range(len(examples)):
716-
if isinstance(examples[idx][label_col], list) or isinstance(
717-
examples[idx][label_col], tuple):
718-
for label_idx in range(len(examples[idx][label_col])):
719-
examples[idx][label_col][label_idx] = label_dict[
720-
examples[idx][label_col][label_idx]]
736+
# For multiple labels in the form of list.
737+
if isinstance(label_dict, list):
738+
for i, sub_dict in enumerate(label_dict):
739+
examples[idx][label_col][i] = _convert_label_to_id(
740+
examples[idx][label_col][i], sub_dict)
721741
else:
722-
examples[idx][label_col] = label_dict[examples[idx][
723-
label_col]]
742+
examples[idx][label_col] = _convert_label_to_id(
743+
examples[idx][label_col], label_dict)
724744

725745
return MapDataset(
726746
examples, label_list=label_list, vocab_info=vocab_info)

0 commit comments

Comments
 (0)