
Commit 13ad050

v0.4.2 refactor
1 parent e8d03f9 commit 13ad050

22 files changed: +3492, -1002 lines

mftcoder_accelerate/src/data/multi_task_dataset.py

Lines changed: 60 additions & 55 deletions
@@ -27,12 +27,12 @@ def __init__(
 
         self.name = name
         self.input_dataset = input_dataset
-        self.num_samples = len(self.input_dataset['input_ids'])
+        self.num_samples = len(self.input_dataset["input_ids"])
         self.seq_length = seq_length
 
         self.weighted_loss_mode = weighted_loss_mode
         self.ds_weight = ds_weight
-        self.task_name = data_prefix.split('/')[-1]
+        self.task_name = data_prefix.split("/")[-1]
         self.task_id = TASK2ID[self.task_name]
 
         # Checks
@@ -47,8 +47,7 @@ def __getitem__(self, idx):
         try:
             # Get the shuffled index.
             idx = idx % self.num_samples
-            idx_data = {key: self.input_dataset[key][idx]
-                        for key in self.input_dataset}
+            idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset}
 
             if self.weighted_loss_mode:
                 idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32)
@@ -115,9 +114,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples):
 
         print(
             "> RANK {} elapsed time for building blendable dataset indices: "
-            "{:.2f} (sec)".format(
-                torch.distributed.get_rank(), time.time() - start_time
-            )
+            "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time)
         )
 
     def calc_weights(self):
@@ -166,7 +163,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
     encoder = UniformEncoder(args, args.tokenize_mode)
     encoder.initializer()
 
-    data_prefixes = list(args.data_paths[1:-1].split(','))
+    data_prefixes = list(args.data_paths[1:-1].split(","))
 
     splits = []
     splits_string = args.data_split
@@ -179,7 +176,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
     while len(splits) < 3:
         splits.append(0.0)
     splits = splits[:3]
-    print(f'data splits: {splits}')
+    print(f"data splits: {splits}")
 
     all_train_datasets = []
     all_valid_datasets = []
@@ -200,40 +197,40 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
         cur_dataset_loss_mask = []
         # support multiple jsonl files under task dir
        for file in files:
-            file_name = data_prefixes[dataset_index] + '/' + file
+            file_name = data_prefixes[dataset_index] + "/" + file
             if os.path.isdir(file_name):
                 continue
-            fin = open(file_name, 'r')
-            print(f'[Global Rank {global_rank}] open file {file_name}')
+            fin = open(file_name, "r")
+            print(f"[Global Rank {global_rank}] open file {file_name}")
 
-            if args.padding_mode == 'padding' or args.padding_mode == 'pack':
+            if args.padding_mode == "padding" or args.padding_mode == "pack":
                 for i, line in enumerate(fin):
                     # pre-sharding
                     if shard_data and i % world_size != global_rank:
                         continue
-                    data = json.loads(line.rstrip('\n\r'))
+                    data = json.loads(line.rstrip("\n\r"))
                     features, length = encoder.encode(data, verbose=(i < 1))
                     # features, length = encoder.encode(data)
                     # may have more samples
-                    for idx in range(len(features['input_ids'])):
-                        cur_dataset_input_ids.append(features['input_ids'][idx])
-                        cur_dataset_loss_mask.append(features['loss_mask'][idx])
+                    for idx in range(len(features["input_ids"])):
+                        cur_dataset_input_ids.append(features["input_ids"][idx])
+                        cur_dataset_loss_mask.append(features["loss_mask"][idx])
 
                 fin.close()
             else:
                 i = 0
                 for line in fin:
-                    data = json.loads(line.rstrip('\n\r'))
+                    data = json.loads(line.rstrip("\n\r"))
                     features, length = encoder.encode(data)
                     # a document may encode into no samples at all, or into multiple samples
-                    for idx in range(len(features['input_ids'])):
+                    for idx in range(len(features["input_ids"])):
                         # post-sharding
                         if shard_data and i % world_size != global_rank:
                             i += 1
                             continue
                         i += 1
-                        cur_dataset_input_ids.append(features['input_ids'][idx])
-                        cur_dataset_loss_mask.append(features['loss_mask'][idx])
+                        cur_dataset_input_ids.append(features["input_ids"][idx])
+                        cur_dataset_loss_mask.append(features["loss_mask"][idx])
                 fin.close()
 
         cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32)
@@ -249,54 +246,48 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
         train_ratio = splits[0] / 100.0
         train_num = int(math.ceil(train_ratio * cur_dataset_sample_num))
         # split train/valid
-        cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num:]
-        cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num:]
+        cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:]
+        cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:]
         local_train_num += train_num
-        local_valid_num += (cur_dataset_sample_num - train_num)
-
-        cur_train_dataset = {
-            'input_ids': cur_train_input_ids,
-            'loss_mask': cur_train_loss_mask
-        }
-        cur_valid_dataset = {
-            'input_ids': cur_valid_input_ids,
-            'loss_mask': cur_valid_loss_mask
-        }
+        local_valid_num += cur_dataset_sample_num - train_num
+
+        cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask}
+        cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask}
         print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}")
         print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}")
 
         cur_train_ds = GPT2FromRawDataset(
-            'train',
+            "train",
             data_prefixes[dataset_index],
             cur_train_dataset,
             args.seq_length,
             weighted_loss_mode=args.weighted_loss_mode,
-            ds_weight=splits[0]
+            ds_weight=splits[0],
        )
         cur_valid_ds = GPT2FromRawDataset(
-            'valid',
+            "valid",
            data_prefixes[dataset_index],
            cur_valid_dataset,
            args.seq_length,
            weighted_loss_mode=args.weighted_loss_mode,
-            ds_weight=splits[1]
+            ds_weight=splits[1],
        )
-
+
         all_train_datasets.append(cur_train_ds)
         all_valid_datasets.append(cur_valid_ds)
         all_train_datasets_length.append(len(cur_train_ds))
         all_valid_datasets_length.append(len(cur_valid_ds))
-
-        print(f'[Global Rank {global_rank}]num tokens: {num_tokens}')
-        print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}')
+
+        print(f"[Global Rank {global_rank}]num tokens: {num_tokens}")
+        print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}")
 
     num_tokens = []
     ds_fn = partial(ds_weights_by_num_docs_sft)
     train_loss_weights, valid_loss_weights = (
         ds_fn(all_train_datasets_length),
         ds_fn(all_valid_datasets_length),
     )
-
+
     print(f"> train loss weights in rank {global_rank}: {train_loss_weights}")
     print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}")
 
@@ -306,51 +297,63 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0,
     factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length)
     factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights)
     print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}")
-
+
     train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length]
     valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length]
     print(f"> train sample weights in rank {global_rank}: {train_sample_weights}")
     print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}")
 
     # recompute global_train_num and global_valid_num
-
+
     torch.distributed.barrier()
     device = f"cuda:{local_rank}"
-
+
     global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32)
     global_train_num_samples_tensor = global_train_num_samples_tensor.to(device)
     torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
     global_train_num = global_train_num_samples_tensor.item()
-
+
     global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32)
     global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device)
     torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM)
     global_valid_num = global_valid_num_samples_tensor.item()
     print(f"> global train num in rank {global_rank}: {global_train_num}")
     print(f"> global valid num in rank {global_rank}: {global_valid_num}")
-
+
     torch.distributed.barrier()
 
     for i in range(len(all_train_datasets)):
-        print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+        print(
+            f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+        )
     blending_train_dataset = None
     if all_train_datasets:
         args.do_train = True
         for i in range(len(all_train_datasets)):
             all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor)
-            print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
-        blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num)
+            print(
+                f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}"
+            )
+        blending_train_dataset = GPT2BlendableDataset(
+            all_train_datasets, train_sample_weights, global_train_num, local_train_num
+        )
 
-    for i in range(len(all_train_datasets)):
-        print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
+    for i in range(len(all_valid_datasets)):
+        print(
+            f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+        )
     blending_valid_dataset = None
    if all_valid_datasets:
        args.do_valid = True
        for i in range(len(all_valid_datasets)):
            all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor)
-            print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}')
-        blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num)
-
+            print(
+                f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}"
+            )
+        blending_valid_dataset = GPT2BlendableDataset(
+            all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num
+        )
+
     return blending_train_dataset, blending_valid_dataset
 
 
@@ -359,11 +362,13 @@ def compile_helper():
     is invoked on a single process."""
     import os
     import subprocess
+
     path = os.path.abspath(os.path.dirname(__file__))
     ret = subprocess.run(["make", "-C", path])
     if ret.returncode != 0:
         print("Making C++ dataset helpers module failed, exiting.")
         import sys
+
         sys.exit(1)
     else:
         print("Making C++ dataset helpers module successfully.")
