@@ -27,12 +27,12 @@ def __init__(
        self.name = name
        self.input_dataset = input_dataset
-        self.num_samples = len(self.input_dataset['input_ids'])
+        self.num_samples = len(self.input_dataset["input_ids"])
        self.seq_length = seq_length

        self.weighted_loss_mode = weighted_loss_mode
        self.ds_weight = ds_weight
-        self.task_name = data_prefix.split('/')[-1]
+        self.task_name = data_prefix.split("/")[-1]
        self.task_id = TASK2ID[self.task_name]

        # Checks
@@ -47,8 +47,7 @@ def __getitem__(self, idx):
        try:
            # Get the shuffled index.
            idx = idx % self.num_samples
-            idx_data = {key: self.input_dataset[key][idx]
-                        for key in self.input_dataset}
+            idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset}

            if self.weighted_loss_mode:
                idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32)
@@ -115,9 +114,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples):
        print(
            "> RANK {} elapsed time for building blendable dataset indices: "
-            "{:.2f} (sec)".format(
-                torch.distributed.get_rank(), time.time() - start_time
-            )
+            "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time)
        )

    def calc_weights(self):
@@ -166,7 +163,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
    encoder = UniformEncoder(args, args.tokenize_mode)
    encoder.initializer()

-    data_prefixes = list(args.data_paths[1:-1].split(','))
+    data_prefixes = list(args.data_paths[1:-1].split(","))

    splits = []
    splits_string = args.data_split
@@ -179,7 +176,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
    while len(splits) < 3:
        splits.append(0.0)
    splits = splits[:3]
-    print(f'data splits: {splits}')
+    print(f"data splits: {splits}")

    all_train_datasets = []
    all_valid_datasets = []
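
Context note: `args.data_paths` is a bracketed, comma-separated list of task directories and `args.data_split` is a percentage string that gets padded to three entries. The exact parsing of `splits_string` is outside this hunk, so the comma-split into floats below is an assumption, and all values are made up; this is only a standalone sketch of the argument handling.

# Sketch of the argument parsing above (example values are hypothetical)
data_paths = "[/data/task_a,/data/task_b]"
data_split = "95,5"

data_prefixes = list(data_paths[1:-1].split(","))   # ['/data/task_a', '/data/task_b']

splits = [float(s) for s in data_split.split(",")]  # assumed parsing; not shown in this hunk
while len(splits) < 3:                              # pad to train/valid/test
    splits.append(0.0)
splits = splits[:3]                                 # [95.0, 5.0, 0.0]
print(f"data splits: {splits}")
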
@@ -200,40 +197,40 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
        cur_dataset_loss_mask = []
        # support multiple jsonl files under task dir
        for file in files:
-            file_name = data_prefixes[dataset_index] + '/' + file
+            file_name = data_prefixes[dataset_index] + "/" + file
            if os.path.isdir(file_name):
                continue
-            fin = open(file_name, 'r')
-            print(f'[Global Rank {global_rank}] open file {file_name}')
+            fin = open(file_name, "r")
+            print(f"[Global Rank {global_rank}] open file {file_name}")

-            if args.padding_mode == 'padding' or args.padding_mode == 'pack':
+            if args.padding_mode == "padding" or args.padding_mode == "pack":
                for i, line in enumerate(fin):
                    # pre-sharding
                    if shard_data and i % world_size != global_rank:
                        continue
-                    data = json.loads(line.rstrip('\n\r'))
+                    data = json.loads(line.rstrip("\n\r"))
                    features, length = encoder.encode(data, verbose=(i < 1))
                    # features, length = encoder.encode(data)
                    # may have more samples
-                    for idx in range(len(features['input_ids'])):
-                        cur_dataset_input_ids.append(features['input_ids'][idx])
-                        cur_dataset_loss_mask.append(features['loss_mask'][idx])
+                    for idx in range(len(features["input_ids"])):
+                        cur_dataset_input_ids.append(features["input_ids"][idx])
+                        cur_dataset_loss_mask.append(features["loss_mask"][idx])

                fin.close()
            else:
                i = 0
                for line in fin:
-                    data = json.loads(line.rstrip('\n\r'))
+                    data = json.loads(line.rstrip("\n\r"))
                    features, length = encoder.encode(data)
                    # a document may encode into zero samples or into multiple samples
-                    for idx in range(len(features['input_ids'])):
+                    for idx in range(len(features["input_ids"])):
                        # post-sharding
                        if shard_data and i % world_size != global_rank:
                            i += 1
                            continue
                        i += 1
-                        cur_dataset_input_ids.append(features['input_ids'][idx])
-                        cur_dataset_loss_mask.append(features['loss_mask'][idx])
+                        cur_dataset_input_ids.append(features["input_ids"][idx])
+                        cur_dataset_loss_mask.append(features["loss_mask"][idx])
                fin.close()

        cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32)
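
Context note: the two branches above differ only in where round-robin sharding happens. Padding/pack mode drops raw jsonl lines before encoding ("pre-sharding"); the other branch drops encoded samples ("post-sharding"). A minimal standalone illustration of that selection rule, with made-up rank values:

# Round-robin sharding sketch (world_size and global_rank are hypothetical)
world_size, global_rank = 4, 1

raw_lines = [f"line-{i}" for i in range(10)]
# Pre-sharding: keep every world_size-th raw line for this rank, then encode only those.
pre_shard = [line for i, line in enumerate(raw_lines) if i % world_size == global_rank]
# -> ['line-1', 'line-5', 'line-9']

encoded_samples = [f"sample-{i}" for i in range(10)]  # stand-in for encoder output over all lines
# Post-sharding: encode everything, then keep every world_size-th resulting sample.
post_shard = [s for i, s in enumerate(encoded_samples) if i % world_size == global_rank]
# -> ['sample-1', 'sample-5', 'sample-9']
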
@@ -249,54 +246,48 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
        train_ratio = splits[0] / 100.0
        train_num = int(math.ceil(train_ratio * cur_dataset_sample_num))
        # split train/valid
-        cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num:]
-        cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num:]
+        cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:]
+        cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:]
        local_train_num += train_num
-        local_valid_num += (cur_dataset_sample_num - train_num)
-
-        cur_train_dataset = {
-            'input_ids': cur_train_input_ids,
-            'loss_mask': cur_train_loss_mask
-        }
-        cur_valid_dataset = {
-            'input_ids': cur_valid_input_ids,
-            'loss_mask': cur_valid_loss_mask
-        }
+        local_valid_num += cur_dataset_sample_num - train_num
+
+        cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask}
+        cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask}
        print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}")
        print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}")

        cur_train_ds = GPT2FromRawDataset(
-            'train',
+            "train",
            data_prefixes[dataset_index],
            cur_train_dataset,
            args.seq_length,
            weighted_loss_mode=args.weighted_loss_mode,
-            ds_weight=splits[0]
+            ds_weight=splits[0],
        )
        cur_valid_ds = GPT2FromRawDataset(
-            'valid',
+            "valid",
            data_prefixes[dataset_index],
            cur_valid_dataset,
            args.seq_length,
            weighted_loss_mode=args.weighted_loss_mode,
-            ds_weight=splits[1]
+            ds_weight=splits[1],
        )
-
+
        all_train_datasets.append(cur_train_ds)
        all_valid_datasets.append(cur_valid_ds)
        all_train_datasets_length.append(len(cur_train_ds))
        all_valid_datasets_length.append(len(cur_valid_ds))
-
-    print(f'[Global Rank {global_rank}]num tokens: {num_tokens}')
-    print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}')
+
+    print(f"[Global Rank {global_rank}]num tokens: {num_tokens}")
+    print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}")

    num_tokens = []
    ds_fn = partial(ds_weights_by_num_docs_sft)
    train_loss_weights, valid_loss_weights = (
        ds_fn(all_train_datasets_length),
        ds_fn(all_valid_datasets_length),
    )
-
+
    print(f"> train loss weights in rank {global_rank}: {train_loss_weights}")
    print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}")

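Context note: to make the train/valid split in this hunk concrete, `splits[0]` is a percentage, so each dataset contributes its first `ceil(splits[0] / 100 * N)` samples to train and the remainder to valid. A tiny worked example with made-up numbers:

import math

splits = [95.0, 5.0, 0.0]        # hypothetical data_split
cur_dataset_sample_num = 1003    # hypothetical per-dataset sample count

train_ratio = splits[0] / 100.0                                    # 0.95
train_num = int(math.ceil(train_ratio * cur_dataset_sample_num))   # 953
valid_num = cur_dataset_sample_num - train_num                     # 50
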
@@ -306,51 +297,63 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
    factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length)
    factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights)
    print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}")
-
+
    train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length]
    valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length]
    print(f"> train sample weights in rank {global_rank}: {train_sample_weights}")
    print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}")

    # recompute global_train_num and global_valid_num
-
+
    torch.distributed.barrier()
    device = f"cuda:{local_rank}"
-
+
    global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32)
    global_train_num_samples_tensor = global_train_num_samples_tensor.to(device)
    torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
    global_train_num = global_train_num_samples_tensor.item()
-
+
    global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32)
    global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device)
    torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
    global_valid_num = global_valid_num_samples_tensor.item()
    print(f"> global train num in rank {global_rank}: {global_train_num}")
    print(f"> global valid num in rank {global_rank}: {global_valid_num}")
-
+
    torch.distributed.barrier()

    for i in range(len(all_train_datasets)):
-        print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+        print(
+            f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+        )
    blending_train_dataset = None
    if all_train_datasets:
        args.do_train = True
        for i in range(len(all_train_datasets)):
            all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor)
-            print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
-        blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num)
+            print(
+                f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+            )
+        blending_train_dataset = GPT2BlendableDataset(
+            all_train_datasets, train_sample_weights, global_train_num, local_train_num
+        )

-    for i in range(len(all_train_datasets)):
-        print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+    for i in range(len(all_valid_datasets)):
+        print(
+            f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+        )
    blending_valid_dataset = None
    if all_valid_datasets:
        args.do_valid = True
        for i in range(len(all_valid_datasets)):
            all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor)
-            print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
-        blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num)
-
+            print(
+                f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+            )
+        blending_valid_dataset = GPT2BlendableDataset(
+            all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num
+        )
+
    return blending_train_dataset, blending_valid_dataset

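Context note: the `factor` above rescales the per-dataset loss weights. It is the effective-token ratio `sum(num_tokens) / (sum(total_sample_cnt) * seq_length)` divided by the mean reciprocal loss weight, and each dataset then gets `ds_weight = loss_weight / factor`. `num_tokens` and `total_sample_cnt` are accumulated elsewhere in the file, so the numbers below are made up purely to show the arithmetic:

seq_length = 4096                   # hypothetical
num_tokens = [800_000, 200_000]     # hypothetical effective token counts per dataset
total_sample_cnt = [300, 100]       # hypothetical sample counts per dataset
train_loss_weights = [0.5, 2.0]     # hypothetical output of ds_weights_by_num_docs_sft

factor = sum(num_tokens) / (sum(total_sample_cnt) * seq_length)                  # ~0.610
factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights)   # / 1.25 -> ~0.488

updated_ds_weights = [w / factor for w in train_loss_weights]                    # ~[1.02, 4.10]
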
@@ -359,11 +362,13 @@ def compile_helper():
    is invoked on a single process."""
    import os
    import subprocess
+
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(["make", "-C", path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
+
        sys.exit(1)
    else:
        print("Making C++ dataset helpers module successfully.")