@@ -115,7 +115,7 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,
 
                 # attention_mask has size [1, 1, seqlen, seqlen]
                 attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
-
+
         lm_output = self.language_model(
             input_ids,
             position_ids,
@@ -230,7 +230,7 @@ def _to_float16(inputs):
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))
-
+
         if args.fp32_residual_connection:
            if hasattr(args, 'attn_mask'):
                self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
@@ -253,8 +253,8 @@ def _to_float16(inputs):
                     layer_number=layer_idx,
                     # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
                     self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))
-
-
+
+
         if not hasattr(args, 'attn_mask'):
            # We drop attention mask from the pipeline
            self.specs.append(lambda x: x[0])
@@ -295,14 +295,26 @@ def _logits_helper(embedding, lm_output):
             interval = args.checkpoint_num_layers
         else:
             interval = 0
-
+
         from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
         topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
                                              num_mp=mpu.get_tensor_model_parallel_world_size(),
                                              num_dp=mpu.get_data_parallel_world_size())
 
+        # here one can extend the regex to include more layers to be counted towards partitioning,
+        # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
+        # and last embedding layers and then partition that transformers+2 layers - so to get a good
+        # balance you may want to use less transformer layers
+        #
+        # caveat emptor: the current implementation of PP fails unless each stage has at least one
+        # transformer layer
+        if args.pp_partition_method is not None:
+            partition_method = args.pp_partition_method
+        else:
+            partition_method = 'type:transformer'
+
         super().__init__(layers=self.specs,
                          loss_fn=get_cross_entropy(is_prefix=prefix_lm),
                          topology=topo,
                          activation_checkpoint_interval=interval,
-                         partition_method='type:transformer')
+                         partition_method=partition_method)
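
The block added in the last hunk chooses which pipeline-parallel partitioning spec is handed to DeepSpeed: if `args.pp_partition_method` is unset, the previous hard-coded `'type:transformer'` is kept. As the in-diff comment explains, the regex after `type:` decides which layer specs count towards the per-stage balance. The short sketch below only illustrates that counting idea for the two values mentioned in the comment; it is not DeepSpeed's actual partitioning code, and the layer class names and the case-insensitive match are assumptions made for illustration.

import re

# Hypothetical layer stack for a 24-layer model: embedding, transformer blocks, tied embedding.
layer_class_names = ["EmbeddingPipe"] + ["ParallelTransformerLayerPipe"] * 24 + ["EmbeddingPipe"]

for method in ("type:transformer", "type:transformer|embedding"):
    # 'type:<regex>' -> count only the layer specs whose class name matches <regex>
    regex = re.compile(method.split(":", 1)[1], re.IGNORECASE)  # case-insensitive match is an assumption
    counted = sum(bool(regex.search(name)) for name in layer_class_names)
    print(f"{method:30s} counts {counted} layers for balancing")  # 24 vs. 26 (transformers + 2 embeddings)

With the fallback in place, runs that do not set a partition method keep the old `'type:transformer'` behaviour; the command-line option that populates `args.pp_partition_method` is presumably defined elsewhere in this commit and is not shown in these hunks.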