 class InTokens:
-    required_input_keys = {"input_ids", "labels"}
-    required_output_keys = {"input_ids", "labels", "attention_mask"}
+    required_input_keys = ["input_ids", "labels"]
+    required_output_keys = ["input_ids", "labels", "attention_mask"]

     # Only the following keys are supported for InTokens; any other keys will be ignored.
-    supported_input_keys = {"input_ids", "labels", "attention_mask", "position_ids"}
+    supported_input_keys = ["input_ids", "labels", "attention_mask", "position_ids"]

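The switch above from set literals to lists is worth a note: both forms support the `in` membership checks used later, but only a list iterates in a fixed order, which keeps the key order of the batched feature dict stable across runs (string hashing, and hence set iteration order, is randomized per process). A minimal sketch of the difference; the key names mirror the diff, everything else is illustrative:

# Membership checks behave identically for sets and lists.
required_as_list = ["input_ids", "labels", "attention_mask"]
assert "labels" in required_as_list

# But only the list guarantees a stable order when building a dict from it,
# since set iteration order can change between processes.
batched = {key: [] for key in required_as_list}
print(list(batched))  # always ['input_ids', 'labels', 'attention_mask']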
     @classmethod
     def _pad_batch_records(cls, batch_records):
         # TODO: support pad_to_max_length for Pipeline parallel
         # Only consider supported input keys
-        input_keys = set(batch_records[0].keys()).intersection(cls.supported_input_keys)
+        input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
+
         # Check required_keys
         for key in cls.required_input_keys:
             if key not in input_keys:
                 raise ValueError(f"feature `{key}` is required for InTokensDataset")
+        # Output features must include all required output keys
+        for key in cls.required_output_keys:
+            if key not in input_keys:
+                input_keys.append(key)

-        output_keys = input_keys.union(cls.required_output_keys)
-        batched_features = {key: [] for key in output_keys}
+        batched_features = {key: [] for key in input_keys}
         for record in batch_records:
             batched_features["input_ids"].extend(record["input_ids"])
             batched_features["labels"].extend(record["labels"])
             seq_length = len(record["input_ids"])
             # If attention_mask is not given, assume a causal mask
-            attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype="bool")))
+            attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
             batched_features["attention_mask"].append(attention_mask)
             # TODO: adapt to chatglm position_2d
             # NOTE: position_ids is optional and not required by every model
@@ -49,14 +53,18 @@ def _pad_batch_records(cls, batch_records):
         block_attention_mask = block_diag(*batched_features["attention_mask"])
         # convert to 3-D [batch_size(1), seq_length, seq_length]
         batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0)
+        # batched_features["input_ids"] = np.array(batched_features["input_ids"], dtype=np.int64)
+        # batched_features["labels"] = np.array(batched_features["labels"], dtype=np.int64)
+        # if "position_ids" in record:
+        #     batched_features["position_ids"] = np.array(batched_features["position_ids"], dtype=np.int64)
         return batched_features

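The block-diagonal mask built above is the heart of InTokens packing: several short examples share one long sequence, and block_diag zero-fills everything off the per-example blocks, so tokens never attend across example boundaries. A runnable sketch under the same assumptions as the diff (scipy.linalg.block_diag, NumPy), with toy lengths:

import numpy as np
from scipy.linalg import block_diag

# Two packed examples of lengths 2 and 3, each with a causal (lower-triangular) mask.
masks = [np.tril(np.ones([n, n], dtype=bool)) for n in (2, 3)]

combined = block_diag(*masks)
print(combined.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [0 0 1 0 0]
#  [0 0 1 1 0]
#  [0 0 1 1 1]]

# _pad_batch_records then adds a leading batch axis: (1, seq_len, seq_len).
print(np.expand_dims(combined, axis=0).shape)  # (1, 5, 5)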
 class InTokensMapDataset(InTokens, Dataset):
     def __init__(self, data, tokenizer, max_length):
         self.tokenizer = tokenizer
         self.max_length = max_length
-        self.data = self._create_intokens_data(data)
+        self.new_data = self._create_intokens_data(data)

     def _create_intokens_data(self, data):
         batch_records, max_len = [], 0
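A plausible (hedged) motivation for renaming self.data to self.new_data in this hunk, which the diff itself does not state: map-style dataset base classes often keep the raw examples in an attribute of their own, so storing the packed result under a distinct name avoids clobbering that state. A toy illustration with hypothetical class names:

class BaseDataset:
    def __init__(self, data):
        self.data = data  # raw examples kept by the base class

class PackedDataset(BaseDataset):
    def __init__(self, data):
        super().__init__(data)
        # A distinct attribute keeps both the raw and packed views available.
        self.new_data = [d.upper() for d in data]  # stand-in for real packing

ds = PackedDataset(["a", "b"])
print(ds.data, ds.new_data)  # ['a', 'b'] ['A', 'B']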
@@ -88,10 +96,10 @@ def _create_intokens_data(self, data):
         return total_data

     def __getitem__(self, idx):
-        return self.data[idx]
+        return self.new_data[idx]

     def __len__(self):
-        return len(self.data)
+        return len(self.new_data)


 class InTokensIterableDataset(InTokens, IterableDataset):