
Commit 86aa795

WIP
1 parent e23393f commit 86aa795

12 files changed, +395 -24 lines changed


megatron/model/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
 from .gpt_model import GPTModel, GPTModelPipe
+from .shared_t5_model import SharedT5ModelPipe
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module

megatron/model/gpt_model.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@
 from megatron import get_args
 from megatron import mpu
 from megatron.enums import AttnMaskType
-from .module import MegatronModule, fp32_to_float16
+from .module import MegatronModule, fp32_to_16bit

 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
@@ -213,9 +213,9 @@ def __init__(

         def _to_float16(inputs):
             if args.fp16:
-                return fp32_to_float16(inputs, lambda v: v.half())
+                return fp32_to_16bit(inputs, lambda v: v.half())
             elif args.bf16:
-                return fp32_to_float16(inputs, lambda v: v.bfloat16())
+                return fp32_to_16bit(inputs, lambda v: v.bfloat16())
             else:
                 return inputs

megatron/model/module.py

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ def conversion_helper(val, conversion):
     return rtn


-def fp32_to_float16(val, float16_convertor):
+def fp32_to_16bit(val, float16_convertor):
     """Convert fp32 `val` to fp16/bf16"""
     def half_conversion(val):
         val_typecheck = val
@@ -168,7 +168,7 @@ def float16_convertor(val):

     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
-            inputs = fp32_to_float16(inputs, self.float16_convertor)
+            inputs = fp32_to_16bit(inputs, self.float16_convertor)
         outputs = self.module(*inputs, **kwargs)
         if mpu.is_pipeline_last_stage():
             outputs = float16_to_fp32(outputs)
megatron/model/shared_t5_model.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+import torch
+from deepspeed import PipelineModule
+from deepspeed.runtime.pipe import TiedLayerSpec, LayerSpec
+from torch.nn import LayerNorm
+
+from megatron.enums import AttnMaskType, LayerType
+
+from megatron.model.transformer import ParallelTransformerLayerPipe
+
+from megatron.model.language_model import EmbeddingPipe, parallel_lm_logits
+
+from megatron.model.utils import init_method_normal, scaled_init_method_normal
+
+from megatron import get_args, mpu
+
+from megatron.model.module import MegatronModule, fp32_to_16bit, float16_to_fp32
+
+def cross_entropy(output, labels):
+    labels, loss_mask = labels[0], labels[1]
+
+    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
+
+    expected_number_of_tokens = loss_mask.sum()
+
+    loss_mask = loss_mask.view(-1)
+    loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens
+    return loss
+
+class SharedT5ModelPipe(PipelineModule, MegatronModule):
+    """Share encoder decoder language model."""
+
+    def __init__(
+        self,
+        num_tokentypes=0,
+        parallel_output=True,
+    ):
+        args = get_args()
+        self.parallel_output = parallel_output
+
+        init_method = init_method_normal(args.init_method_std)
+
+        self.specs = []
+
+        def _to_16bit(inputs):
+            if args.fp16:
+                return fp32_to_16bit(inputs, lambda v: v.half())
+            elif args.bf16:
+                return fp32_to_16bit(inputs, lambda v: v.bfloat16())
+            else:
+                return inputs
+
+        self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
+
+        # Embedding layer
+        self.specs.append(TiedLayerSpec('embed',
+                                        EmbeddingPipe,
+                                        args.hidden_size,
+                                        args.padded_vocab_size,
+                                        args.hidden_dropout,
+                                        init_method=init_method,
+                                        num_tokentypes=num_tokentypes,
+                                        tied_weight_attr='word_embeddings_weight'))
+
+        assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s"
+        if args.fp32_residual_connection:
+            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+        else:
+            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+
+        ### ----- Encoder -----
+        for layer_idx in range(args.num_layers):
+            self.specs.append(
+                TiedLayerSpec(
+                    f"block_{layer_idx}",
+                    ParallelTransformerLayerPipe,
+                    init_method=init_method,
+                    # Inputs: (input_tokens, target_tokens,
+                    forward_fn=lambda module, *inputs: ,
+                    output_layer_init_method=scaled_init_method_normal(args.init_method_std,
+                                                                       args.num_layers),
+                    layer_type=LayerType.encoder,
+                    layer_number=layer_idx,
+                    self_attn_mask_type=AttnMaskType.causal,
+                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                ))
+
+        # Final layernorm after encoder layers
+        self.specs.append(
+            TiedLayerSpec(
+                "final_layer_norm",
+                LayerNorm,
+                args.hidden_size,
+                eps=args.layernorm_epsilon
+            ))
+
+        # Decoder
+        for layer_idx in range(args.num_layers):
+            self.specs.append(
+                TiedLayerSpec(
+                    f"block_{layer_idx}",
+                    ParallelTransformerLayerPipe,
+                    init_method=init_method,
+                    output_layer_init_method=scaled_init_method_normal(args.init_method_std,
+                                                                       args.num_layers),
+                    layer_number=layer_idx,
+                    layer_type=LayerType.decoder,
+                    self_attn_mask_type=AttnMaskType.padding,
+                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                )
+            )
+
+        # Final layernorm after decoder layers
+        self.specs.append(
+            TiedLayerSpec(
+                "final_layer_norm",
+                LayerNorm,
+                args.hidden_size,
+                eps=args.layernorm_epsilon
+            ))
+
+        # Undo data format change
+        self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+
+        def _logits_helper(embedding, lm_output):
+            """A wrapper to massage inputs/outputs from pipeline. """
+            return parallel_lm_logits(
+                lm_output,
+                embedding.word_embeddings_weight,
+                self.parallel_output)
+
+        self.specs.append(
+            TiedLayerSpec('embed',
+                          EmbeddingPipe,
+                          args.hidden_size,
+                          args.padded_vocab_size,
+                          args.hidden_dropout,
+                          init_method=init_method,
+                          num_tokentypes=num_tokentypes,
+                          forward_fn=_logits_helper,
+                          tied_weight_attr='word_embeddings_weight')
+        )
+
+        if not hasattr(args, 'attn_mask'):
+            # We drop attention mask from the pipeline
+            self.specs.append(lambda x: x[0])
+
+        # Final layernorm after transformer layers
+        self.specs.append(
+            TiedLayerSpec(
+                "final_layer_norm",
+                LayerNorm,
+                args.hidden_size,
+                eps=args.layernorm_epsilon
+            ))
+
+        # Undo data format change
+        self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+
+        # Convert to fp32 if needed
+        if args.fp16 or args.bf16:
+            self.specs.append(float16_to_fp32)
+
+        if args.checkpoint_activations:
+            interval = args.checkpoint_num_layers
+        else:
+            interval = 0
+
+        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
+        topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
+                                             num_mp=mpu.get_tensor_model_parallel_world_size(),
+                                             num_dp=mpu.get_data_parallel_world_size())
+
+        # here one can extend the regex to include more layers to be counted towards partitioning,
+        # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
+        # and last embedding layers and then partition that transformers+2 layers - so to get a good
+        # balance you may want to use less transformer layers
+        #
+        # caveat emptor: the current implementation of PP fails unless each stage has at least one
+        # transformer layer
+        if args.pp_partition_method is not None:
+            partition_method = args.pp_partition_method
+        else:
+            partition_method = 'type:transformer'
+
+        super().__init__(layers=self.specs,
+                         loss_fn=cross_entropy,
+                         topology=topo,
+                         activation_checkpoint_interval=interval,
+                         partition_method=partition_method)
megatron/text_generation_utils.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
+from megatron.utils import get_attention_masks_and_position_ids, unwrap_model
 from megatron.p2p_communication import recv_forward, send_forward

 # These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
@@ -42,7 +42,7 @@ def get_batch(context_tokens):
     # Move to GPU.
     tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
     # Get the attention mask and position ids.
-    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+    attention_mask, _, position_ids = get_attention_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,

megatron/utils.py

Lines changed: 7 additions & 4 deletions
@@ -151,14 +151,16 @@ def check_adlr_autoresume_termination(iteration, model,
         sys.exit(0)


-def get_ltor_masks_and_position_ids(
+
+def get_attention_masks_and_position_ids(
     data,
     eod_token,
     reset_position_ids,
     reset_attention_mask,
     eod_mask_loss,
     prefix_indices,
     loss_on_targets_only,
+    ltor=True,
 ):
     """
     Build masks and position id for left to right model.
@@ -177,9 +179,10 @@
         att_mask_batch = micro_batch_size
     else:
         att_mask_batch = 1
-    attention_mask = torch.tril(torch.ones(
-        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
-        att_mask_batch, 1, seq_length, seq_length)
+    attention_mask = torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
+    if ltor:
+        attention_mask = torch.tril(attention_mask)
+    attention_mask = attention_mask.view(att_mask_batch, 1, seq_length, seq_length)

     # Loss mask.
     loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
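
To illustrate the new ltor flag, the sketch below builds both mask variants for a toy batch (sizes are arbitrary); ltor=True reproduces the previous left-to-right behaviour, ltor=False skips the tril step and yields a full mask:

import torch

att_mask_batch, seq_length = 1, 4

# ltor=True: lower-triangular (causal) mask, identical to the old behaviour.
causal = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length)))
causal = causal.view(att_mask_batch, 1, seq_length, seq_length)

# ltor=False: every position may attend to every other position.
full = torch.ones((att_mask_batch, seq_length, seq_length))
full = full.view(att_mask_batch, 1, seq_length, seq_length)

print(causal[0, 0])  # ones on and below the diagonal, zeros above
print(full[0, 0])    # all ones
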

pretrain_gpt.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@
 from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.model import GPTModel, GPTModelPipe
 from megatron.training import pretrain
-from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
+from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices
 from megatron.utils import average_losses_across_data_parallel_group

 import deepspeed
@@ -110,7 +110,7 @@ def get_batch(data_iterator):
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,
@@ -141,7 +141,7 @@ def get_batch_pipe(data):
     tokens = tokens_[:, :-1].contiguous()

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,

pretrain_prefix_lm.py

Lines changed: 4 additions & 3 deletions
@@ -25,7 +25,7 @@
 from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.model import GPTModel, GPTModelPipe
 from megatron.training import pretrain
-from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_
+from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_
 from megatron.utils import average_losses_across_data_parallel_group

 import deepspeed
@@ -97,7 +97,7 @@ def get_batch(data_iterator):
     )

     # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,
@@ -131,6 +131,7 @@ def get_batch_pipe(data):
     tokens = tokens_[:, :-1].contiguous()

     # Prefix
+    # TODO @thomasw21 actually since this step is random, we need to make sure that random state are synchronized. Otherwise we need to broadcast after this step.
     prefix_indices = get_prefix_indices(
         tokens,
         tokenizer.eod,
@@ -139,7 +140,7 @@ def get_batch_pipe(data):
     )

     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,
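
The new TODO above notes that get_prefix_indices draws random prefix lengths, so either the RNG state must already match across the ranks that consume the same batch, or the sampled indices have to be broadcast afterwards. A hypothetical sketch of the broadcast option using plain torch.distributed (the helper name is made up for illustration; real code would broadcast over the appropriate Megatron mpu group rather than the default one):

import torch.distributed as dist

def broadcast_prefix_indices(prefix_indices, src_rank=0):
    # Overwrite every rank's randomly sampled prefix indices with src_rank's copy,
    # so all replicas build identical attention/loss masks afterwards.
    obj = [prefix_indices]
    dist.broadcast_object_list(obj, src=src_rank)
    return obj[0]
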
