import torch
from deepspeed.runtime.pipe import PipelineModule, TiedLayerSpec, LayerSpec
from torch.nn import LayerNorm

from megatron import get_args, mpu
from megatron.enums import AttnMaskType, LayerType
from megatron.model.language_model import EmbeddingPipe, parallel_lm_logits
from megatron.model.module import MegatronModule, fp32_to_16bit, float16_to_fp32
from megatron.model.transformer import ParallelTransformerLayerPipe
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def cross_entropy(output, labels):
    """Vocab-parallel cross entropy over the logits, averaged over the unmasked tokens."""
    labels, loss_mask = labels[0], labels[1]

    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

    expected_number_of_tokens = loss_mask.sum()

    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens
    return loss


class SharedT5ModelPipe(PipelineModule, MegatronModule):
    """Shared encoder-decoder language model."""

    def __init__(
        self,
        num_tokentypes=0,
        parallel_output=True,
    ):
        args = get_args()
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)

        self.specs = []

        def _to_16bit(inputs):
            if args.fp16:
                return fp32_to_16bit(inputs, lambda v: v.half())
            elif args.bf16:
                return fp32_to_16bit(inputs, lambda v: v.bfloat16())
            else:
                return inputs

        self.specs.append(lambda input_tuples: tuple(_to_16bit(inputs) for inputs in input_tuples))

        # Embedding layer
        self.specs.append(TiedLayerSpec('embed',
                                        EmbeddingPipe,
                                        args.hidden_size,
                                        args.padded_vocab_size,
                                        args.hidden_dropout,
                                        init_method=init_method,
                                        num_tokentypes=num_tokentypes,
                                        tied_weight_attr='word_embeddings_weight'))
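        # Note: this 'embed' tie key is reused at the end of the pipeline for the LM head,
        # so the input embedding and the output projection share weights.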

        assert hasattr(args, 'attn_mask'), "The DeepSpeed integration is expected to set `args.attn_mask`"
        # Move from [b, s, h] to the transformer's [s, b, h] layout
        if args.fp32_residual_connection:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
        else:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        ### ----- Encoder -----
        for layer_idx in range(args.num_layers):
            self.specs.append(
                TiedLayerSpec(
                    f"block_{layer_idx}",
                    ParallelTransformerLayerPipe,
                    init_method=init_method,
                    # Inputs: (input_tokens, target_tokens, ...). As a placeholder, the tied
                    # block is simply applied to whatever the previous stage produced.
                    forward_fn=lambda module, *inputs: module(*inputs),
                    output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                       args.num_layers),
                    layer_type=LayerType.encoder,
                    layer_number=layer_idx,
                    self_attn_mask_type=AttnMaskType.causal,
                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
                ))

        # Final layernorm after the encoder layers
        self.specs.append(
            TiedLayerSpec(
                "final_layer_norm",
                LayerNorm,
                args.hidden_size,
                eps=args.layernorm_epsilon
            ))

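        # The decoder below reuses the same f"block_{layer_idx}" tie keys as the encoder,
        # so encoder and decoder transformer blocks share their parameters (the "shared"
        # part of SharedT5ModelPipe).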
        ### ----- Decoder -----
        for layer_idx in range(args.num_layers):
            self.specs.append(
                TiedLayerSpec(
                    f"block_{layer_idx}",
                    ParallelTransformerLayerPipe,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                       args.num_layers),
                    layer_number=layer_idx,
                    layer_type=LayerType.decoder,
                    self_attn_mask_type=AttnMaskType.padding,
                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
                )
            )

        # Final layernorm after the decoder layers
        self.specs.append(
            TiedLayerSpec(
                "final_layer_norm",
                LayerNorm,
                args.hidden_size,
                eps=args.layernorm_epsilon
            ))

        # Undo the data format change: [s, b, h] -> [b, s, h]
        self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        def _logits_helper(embedding, lm_output):
            """A wrapper to massage inputs/outputs from the pipeline."""
            return parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output)

        self.specs.append(
            TiedLayerSpec('embed',
                          EmbeddingPipe,
                          args.hidden_size,
                          args.padded_vocab_size,
                          args.hidden_dropout,
                          init_method=init_method,
                          num_tokentypes=num_tokentypes,
                          forward_fn=_logits_helper,
                          tied_weight_attr='word_embeddings_weight')
        )

        if not hasattr(args, 'attn_mask'):
            # We drop the attention mask from the pipeline
            self.specs.append(lambda x: x[0])


        # Convert to fp32 if needed
        if args.fp16 or args.bf16:
            self.specs.append(float16_to_fp32)

        if args.checkpoint_activations:
            interval = args.checkpoint_num_layers
        else:
            interval = 0

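        # The 3D topology describes how ranks are laid out along the pipeline (pp),
        # tensor-model (mp) and data (dp) parallel axes; DeepSpeed uses it to place
        # the pipeline stages built up in self.specs.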
        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
        topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
                                             num_mp=mpu.get_tensor_model_parallel_world_size(),
                                             num_dp=mpu.get_data_parallel_world_size())

        # Here one can extend the regex to include more layers to be counted towards partitioning,
        # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the
        # first and last embedding layers and then partition that transformers+2 layers - so to
        # get a good balance you may want to use fewer transformer layers.
        #
        # Caveat emptor: the current implementation of PP fails unless each stage has at least
        # one transformer layer.
        if args.pp_partition_method is not None:
            partition_method = args.pp_partition_method
        else:
            partition_method = 'type:transformer'
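        # For example, to also count the two tied embedding layers when balancing stages
        # (assuming the CLI flag behind args.pp_partition_method is `--pp-partition-method`):
        #   --pp-partition-method 'type:transformer|embedding'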

        super().__init__(layers=self.specs,
                         loss_fn=cross_entropy,
                         topology=topo,
                         activation_checkpoint_interval=interval,
                         partition_method=partition_method)
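

# A minimal usage sketch, assuming Megatron's argument/mpu/distributed state has already been
# initialized by the training script, and where `ds_config` and `data_iterator` are hypothetical
# names provided elsewhere:
#
#   import deepspeed
#   model = SharedT5ModelPipe(num_tokentypes=0, parallel_output=True)
#   engine, _, _, _ = deepspeed.initialize(model=model,
#                                          model_parameters=model.parameters(),
#                                          config=ds_config)
#   loss = engine.train_batch(data_iter=data_iterator)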