
Commit 262b23a

fix intokens (#6700)
1 parent ac6f9b7 commit 262b23a

File tree

1 file changed: +18 −10 lines


paddlenlp/datasets/intokens_dataset.py

Lines changed: 18 additions & 10 deletions
@@ -18,29 +18,33 @@
 
 
 class InTokens:
-    required_input_keys = {"input_ids", "labels"}
-    required_output_keys = {"input_ids", "labels", "attention_mask"}
+    required_input_keys = ["input_ids", "labels"]
+    required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for InTokens. Keys outside of the set will be ignored.
-    supported_input_keys = {"input_ids", "labels", "attention_mask", "position_ids"}
+    supported_input_keys = ["input_ids", "labels", "attention_mask", "position_ids"]
 
     @classmethod
     def _pad_batch_records(cls, batch_records):
         # TODO: support pad_to_max_length for Pipeline parallel
         # Only consider supported input keys
-        input_keys = set(batch_records[0].keys()).intersection(cls.supported_input_keys)
+        input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
+
         # Check required_keys
         for key in cls.required_input_keys:
             if key not in input_keys:
                 raise ValueError(f"feature `{key}` is required for InTokensDataset")
+        # Output features must include all required output keys
+        for key in cls.required_output_keys:
+            if key not in input_keys:
+                input_keys.append(key)
 
-        output_keys = input_keys.union(cls.required_output_keys)
-        batched_features = {key: [] for key in output_keys}
+        batched_features = {key: [] for key in input_keys}
         for record in batch_records:
             batched_features["input_ids"].extend(record["input_ids"])
             batched_features["labels"].extend(record["labels"])
             seq_length = len(record["input_ids"])
             # If attention_mask is not given, assume it's causal mask
-            attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype="bool")))
+            attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
             batched_features["attention_mask"].append(attention_mask)
             # TODO: to adapt to chatglm position_2d
             # NOTE: position_ids is optional and not required by every model
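
One practical effect of the hunk above, sketched below for context (this sketch is not part of the commit): replacing the set intersection with a list comprehension makes the selected key order deterministic, because iterating a set of strings can differ between interpreter runs, while the list preserves the record's own key order and appends any missing required output keys in a fixed order. The `record` dict and its values below are hypothetical; the key lists mirror the class attributes in the diff.

# Sketch of the ordered key selection done in the patched _pad_batch_records.
supported_input_keys = ["input_ids", "labels", "attention_mask", "position_ids"]
required_output_keys = ["input_ids", "labels", "attention_mask"]

record = {"input_ids": [1, 2, 3], "labels": [1, 2, 3], "token_type_ids": [0, 0, 0]}

# Keep only supported keys, preserving the record's own ordering.
input_keys = [key for key in record.keys() if key in supported_input_keys]
# Append any required output key the record did not provide.
for key in required_output_keys:
    if key not in input_keys:
        input_keys.append(key)

print(input_keys)  # ['input_ids', 'labels', 'attention_mask'] -- stable across runs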
@@ -49,14 +53,18 @@ def _pad_batch_records(cls, batch_records):
         block_attention_mask = block_diag(*batched_features["attention_mask"])
         # convert to 3-D [batch_size(1), seq_length, seq_length]
         batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0)
+        # batched_features["input_ids"] = np.array(batched_features["input_ids"], dtype=np.int64)
+        # batched_features["labels"] = np.array(batched_features["labels"], dtype=np.int64)
+        # if "position_ids" in record:
+        #     batched_features["position_ids"] = np.array(batched_features["position_ids"], dtype=np.int64)
         return batched_features
 
 
 class InTokensMapDataset(InTokens, Dataset):
     def __init__(self, data, tokenizer, max_length):
         self.tokenizer = tokenizer
         self.max_length = max_length
-        self.data = self._create_intokens_data(data)
+        self.new_data = self._create_intokens_data(data)
 
     def _create_intokens_data(self, data):
         batch_records, max_len = [], 0
@@ -88,10 +96,10 @@ def _create_intokens_data(self, data):
         return total_data
 
     def __getitem__(self, idx):
-        return self.data[idx]
+        return self.new_data[idx]
 
     def __len__(self):
-        return len(self.data)
+        return len(self.new_data)
 
 
 class InTokensIterableDataset(InTokens, IterableDataset):
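
For context, a minimal standalone sketch (not from this commit) of the packing that `_pad_batch_records` performs: per-record causal masks are combined with `scipy.linalg.block_diag` so that tokens attend only within their original record, and the combined mask then gets a leading batch dimension, as in the diff above. The two record lengths below are made up for illustration.

import numpy as np
from scipy.linalg import block_diag

# Two hypothetical records of length 3 and 2.
lengths = [3, 2]
# Per-record causal (lower-triangular) masks, the default built by _pad_batch_records.
masks = [np.tril(np.ones([n, n], dtype=bool)) for n in lengths]

# Block-diagonal combination: tokens of one record cannot attend to the other.
packed = block_diag(*masks)
# Add the batch dimension: [1, total_len, total_len], matching the diff above.
packed = np.expand_dims(packed, axis=0)

print(packed.shape)  # (1, 5, 5)
print(packed[0].astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]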
