Skip to content

Commit b2e98d8

Browse files
authored
fix DatasetTuple identifier bug (#1941)
1 parent 44569cc commit b2e98d8

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

paddlenlp/datasets/dataset.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,26 @@ def load_from_ppnlp(path, *args, **kwargs):
6363

6464
class DatasetTuple:
6565
def __init__(self, splits):
66-
self.tuple_cls = namedtuple('datasets', splits)
66+
self.identifier_map, identifiers = self._gen_identifier_map(splits)
67+
self.tuple_cls = namedtuple('datasets', identifiers)
6768
self.tuple = self.tuple_cls(* [None for _ in splits])
6869

6970
def __getitem__(self, key):
7071
if isinstance(key, (int, slice)):
7172
return self.tuple[key]
7273
if isinstance(key, str):
73-
return getattr(self.tuple, key)
74-
75-
def __repr__(self):
76-
return self.tuple.__repr__()
74+
return getattr(self.tuple, self.identifier_map[key])
7775

7876
def __setitem__(self, key, value):
79-
self.tuple = self.tuple._replace(**{key: value})
77+
self.tuple = self.tuple._replace(**{self.identifier_map[key]: value})
78+
79+
def _gen_identifier_map(self, splits):
80+
identifier_map = {}
81+
identifiers = []
82+
for i in range(len(splits)):
83+
identifiers.append('splits_' + str(i))
84+
identifier_map[splits[i]] = 'splits_' + str(i)
85+
return identifier_map, identifiers
8086

8187
def __len__(self):
8288
return len(self.tuple)

0 commit comments

Comments
 (0)