
Commit 00f5f88

make partition_method configurable (#256)
* make partition_method configurable
* fix
1 parent 49d21af commit 00f5f88

File tree

2 files changed: +20 -6 lines changed

megatron/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -565,6 +565,8 @@ def _add_training_args(parser):
                        help='Iteration ranges to skip. The values are one or more dash-separated ranges. e.g., 101-200 251-300.')
     group.add_argument('--abort-on-unmet-fused-kernel-constraints', action='store_true',
                        help="If set to True, the program will abort if the constraints for loading a fused kernel aren't met")
+    group.add_argument('--pp-partition-method', type=str, default=None,
+                       help="Use to override the pipeline stages partitioning method. e.g., 'type:transformer|embedding'")

     return parser

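
As a quick illustration of what the new flag does once parsed (a minimal, self-contained argparse sketch, not the actual Megatron argument plumbing):

# Minimal sketch: parse the new flag and fall back to the default
# partitioning spec when it is not given on the command line.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--pp-partition-method', type=str, default=None,
                   help="Use to override the pipeline stages partitioning method. "
                        "e.g., 'type:transformer|embedding'")

args = parser.parse_args(['--pp-partition-method', 'type:transformer|embedding'])
if args.pp_partition_method is not None:
    partition_method = args.pp_partition_method
else:
    partition_method = 'type:transformer'
print(partition_method)  # -> type:transformer|embedding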

megatron/model/gpt_model.py

Lines changed: 18 additions & 6 deletions
@@ -115,7 +115,7 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,

                 # attention_mask has size [1, 1, seqlen, seqlen]
                 attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()
-
+
         lm_output = self.language_model(
             input_ids,
             position_ids,
@@ -230,7 +230,7 @@ def _to_float16(inputs):
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))
-
+
         if args.fp32_residual_connection:
             if hasattr(args, 'attn_mask'):
                 self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
@@ -253,8 +253,8 @@ def _to_float16(inputs):
                           layer_number=layer_idx,
                           # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
                           self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))
-
-
+
+
         if not hasattr(args, 'attn_mask'):
             # We drop attention mask from the pipeline
             self.specs.append(lambda x: x[0])
@@ -295,14 +295,26 @@ def _logits_helper(embedding, lm_output):
             interval = args.checkpoint_num_layers
         else:
             interval = 0
-
+
         from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
         topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
                                              num_mp=mpu.get_tensor_model_parallel_world_size(),
                                              num_dp=mpu.get_data_parallel_world_size())

+        # here one can extend the regex to include more layers to be counted towards partitioning,
+        # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
+        # and last embedding layers and then partition that transformers+2 layers - so to get a good
+        # balance you may want to use less transformer layers
+        #
+        # caveat emptor: the current implementation of PP fails unless each stage has at least one
+        # transformer layer
+        if args.pp_partition_method is not None:
+            partition_method = args.pp_partition_method
+        else:
+            partition_method = 'type:transformer'
+
         super().__init__(layers=self.specs,
                          loss_fn=get_cross_entropy(is_prefix=prefix_lm),
                          topology=topo,
                          activation_checkpoint_interval=interval,
-                         partition_method='type:transformer')
+                         partition_method=partition_method)
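
For context on the comment above: DeepSpeed's PipelineModule accepts a partition_method of the form 'type:<regex>' and balances pipeline stages by the number of layers whose class name matches that regex. The snippet below is a rough, self-contained illustration of that counting (the layer class names mirror the specs built in GPTModelPipe; the matching code is a simplified stand-in, not DeepSpeed's actual implementation, and the 24-layer depth is just an assumption):

# Illustrative only: approximate how a 'type:<regex>' partition method decides
# which layers count towards the per-stage balance.
import re

layer_class_names = (
    ['EmbeddingPipe'] +
    ['ParallelTransformerLayerPipe'] * 24 +  # assumption: a 24-layer model
    ['EmbeddingPipe']                        # tied output embedding
)

def counted_layers(partition_method, names):
    assert partition_method.startswith('type:')
    pattern = re.compile(partition_method[len('type:'):], re.IGNORECASE)
    return [name for name in names if pattern.search(name)]

print(len(counted_layers('type:transformer', layer_class_names)))            # 24
print(len(counted_layers('type:transformer|embedding', layer_class_names)))  # 26

With 'type:transformer|embedding' the two embedding layers are counted as well, which is why the comment suggests using fewer transformer layers per stage to keep the split balanced, while every stage must still end up with at least one transformer layer.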
