Commit 555f1ba
WIP
1 parent 98a4c3b commit 555f1ba

1 file changed: 22 additions, 35 deletions

pretrain_shared_t5.py
@@ -1,19 +1,15 @@
 import torch
-from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
-from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
+from megatron.data.mlm_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.model import SharedT5ModelPipe
 from megatron.training import pretrain
-from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron.utils import get_attention_masks_and_position_ids
 
 import deepspeed
 from deepspeed.runtime.utils import see_memory_usage
-import os
 
 try:
     from torch.distributed.elastic.multiprocessing.errors import record
@@ -39,24 +35,6 @@ def model_provider(pre_process=True, post_process=True):
     # TODO @thomasw21: fix this for PP > 1 (the issue is that you're passing two values that require grad)
     assert mpu.get_pipeline_model_parallel_world_size() != 1, "PP > 1 is not supported yet"
 
-    # TODO: actually I'm fairly confident that you don't need the causal mask here as it's handled with `AttnMaskType`
-    # # Precompute the attention mask and store it in args. This avoids having to
-    # # pipeline it as an activation during training. The mask is constant, and thus
-    # # we can reuse it.
-    # attention_mask = torch.tril(torch.ones(
-    #     (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view(
-    #         1, 1, args.seq_length, args.seq_length)
-    #
-    # # Convert attention mask to binary:
-    # attention_mask = (attention_mask < 0.5)
-    # if args.fp16:
-    #     attention_mask = attention_mask.half()
-    # elif args.bf16:
-    #     attention_mask = attention_mask.bfloat16()
-    #
-    # # must be bool or the training crashes expecting bool, but getting Half
-    # args.attn_mask = attention_mask.to(torch.bool)
-
     model = SharedT5ModelPipe(
         num_tokentypes=0,
         parallel_output=True
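
The block removed above was already commented out: it precomputed a causal attention mask once and stored it on `args` so the mask would not have to be pipelined as an activation. As the deleted TODO notes, the causal structure is handled via `AttnMaskType` instead, which makes the precomputation redundant. For reference, here is a standalone reconstruction of what that block computed (CPU and a small sequence length purely for illustration; the original used `torch.cuda.current_device()` and `args.seq_length`):

import torch

seq_length = 8  # illustration only; the deleted code used args.seq_length

# Lower-triangular matrix of ones: position i may attend to positions j <= i.
attention_mask = torch.tril(
    torch.ones((1, seq_length, seq_length))
).view(1, 1, seq_length, seq_length)

# Invert to the masking convention the deleted code used:
# True marks positions that must be masked out.
attention_mask = attention_mask < 0.5

# The deleted comment stressed the mask must be bool, or training
# crashes "expecting bool, but getting Half".
assert attention_mask.dtype == torch.bool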
@@ -72,12 +50,10 @@ def model_provider(pre_process=True, post_process=True):
 
 def get_batch_pipe(data):
     """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
-    raise NotImplementedError("Waiting for MLM data loader to work")
     args = get_args()
     tokenizer = get_tokenizer()
 
     # Items and their type.
-    # TODO @thomasw21
     keys = ["input_tokens", "target_tokens"]
     datatype = torch.int64
 
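The body of `get_batch_pipe` past these context lines is not part of this diff. With the early `raise` gone and the keys set to `input_tokens` and `target_tokens`, the usual Megatron pattern would unpack the batch roughly as follows; this is a sketch under that assumption, not the file's actual code:

# Hypothetical continuation of get_batch_pipe, following the standard
# Megatron recipe of broadcasting the batch within the tensor-parallel group.
data_b = mpu.broadcast_data(keys, data, datatype)
input_tokens = data_b["input_tokens"].long()
target_tokens = data_b["target_tokens"].long()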
@@ -116,7 +92,7 @@ def get_batch_pipe(data):
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
-    raise NotImplementedError("Waiting for MLM data loader")
+
     args = get_args()
     train_ds, valid_ds, test_ds = None, None, None
 
@@ -129,9 +105,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             data_impl=args.data_impl,
             splits_string=args.split,
             train_valid_test_num_samples=train_val_test_num_samples,
-            seq_length=args.seq_length,
+            sequence_length=args.seq_length,
+            noise_density=args.noise_density,
+            mean_noise_span_length=args.mean_noise_span_length,
             seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
+            skip_warmup=(not args.mmap_warmup)
+        )
     # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
     elif args.train_weighted_split_paths:
         assigned_train_valid_test = []
@@ -151,12 +130,20 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                               eval(f"args.{s}_weighted_split_splits"),
                               eval(f"args.{s}_weighted_split_names"))
             for paths, weights, splits, name in data_groups:
-                d = build_dataset_group(name, paths, weights, splits,
-                                        args.data_impl,
-                                        train_val_test_num_samples,
-                                        args.seq_length, args.seed,
-                                        (not args.mmap_warmup),
-                                        train_valid_test=s)
+                d = build_dataset_group(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length,
+                    noise_density=args.noise_density,
+                    mean_noise_span_length=args.mean_noise_span_length,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
                 eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")
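
Both call sites now thread through `noise_density` and `mean_noise_span_length`, the two knobs of T5-style span corruption. Under the T5 recipe they fix how a raw span of tokens splits into encoder inputs and decoder targets; the helper below reconstructs that arithmetic as an illustration and is an assumption about what `megatron.data.mlm_dataset` does internally, not its actual code:

def compute_input_and_target_lengths(tokens_length, noise_density, mean_noise_span_length):
    # How many tokens get corrupted, and how many contiguous spans they form.
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_noise_spans = max(int(round(num_noise_tokens / mean_noise_span_length)), 1)
    num_nonnoise_tokens = tokens_length - num_noise_tokens
    # Inputs keep the non-noise tokens plus one sentinel per span, plus EOS;
    # targets hold the noise tokens plus one sentinel per span, plus EOS.
    inputs_length = num_nonnoise_tokens + num_noise_spans + 1
    targets_length = num_noise_tokens + num_noise_spans + 1
    return inputs_length, targets_length

# With the T5 defaults, a raw span of 568 tokens yields 512 input tokens
# and 114 target tokens:
print(compute_input_and_target_lengths(568, 0.15, 3.0))  # (512, 114)

This also suggests why the first call site passes `sequence_length=args.seq_length`: the dataset presumably picks the raw span length so that, after corruption, the inputs fill the model's sequence length.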
