@@ -98,6 +98,20 @@ def get_args(argv):
9898 type = int ,
9999 help = "Pipeline parallelism size" ,
100100 )
101+ parser .add_argument (
102+ "-nlfps" ,
103+ "--num_layers_in_first_pipeline_stage" ,
104+ default = None ,
105+ type = int ,
106+ help = "Number of layers in the first pipeline stage" ,
107+ )
108+ parser .add_argument (
109+ "-nllps" ,
110+ "--num_layers_in_last_pipeline_stage" ,
111+ default = None ,
112+ type = int ,
113+ help = "Number of layers in the last pipeline stage" ,
114+ )
101115 parser .add_argument (
102116 "-cps" ,
103117 "--context_parallel_size" ,
@@ -112,6 +126,20 @@ def get_args(argv):
112126 type = int ,
113127 help = "Distributes MoE Experts across sub data parallel dimension." ,
114128 )
129+ parser .add_argument (
130+ "-eps" ,
131+ "--account_for_embedding_in_pipeline_split" ,
132+ default = False ,
133+ action = "store_true" ,
134+ help = "Account for embedding in the pipeline split" ,
135+ )
136+ parser .add_argument (
137+ "-lps" ,
138+ "--account_for_loss_in_pipeline_split" ,
139+ default = False ,
140+ action = "store_true" ,
141+ help = "Account for loss in the pipeline split" ,
142+ )
115143 parser .add_argument (
116144 "-mbs" ,
117145 "--max_batch_size" ,
@@ -203,6 +231,17 @@ def nemo_deploy(argv):
203231 if args .nemo_checkpoint is None :
204232 raise ValueError ("In-Framework deployment requires a checkpoint folder." )
205233
234+ model_config_kwargs = {
235+ "account_for_embedding_in_pipeline_split" : args .account_for_embedding_in_pipeline_split ,
236+ "account_for_loss_in_pipeline_split" : args .account_for_loss_in_pipeline_split ,
237+ }
238+
239+ if args .num_layers_in_first_pipeline_stage is not None :
240+ model_config_kwargs ["num_layers_in_first_pipeline_stage" ] = args .num_layers_in_first_pipeline_stage
241+
242+ if args .num_layers_in_last_pipeline_stage is not None :
243+ model_config_kwargs ["num_layers_in_last_pipeline_stage" ] = args .num_layers_in_last_pipeline_stage
244+
206245 model = MegatronLLMDeployableNemo2 (
207246 num_devices = args .num_gpus ,
208247 num_nodes = args .num_nodes ,
@@ -219,6 +258,7 @@ def nemo_deploy(argv):
219258 model_type = args .model_type ,
220259 model_format = args .model_format ,
221260 micro_batch_size = args .micro_batch_size ,
261+ ** model_config_kwargs ,
222262 )
223263
224264 if torch .distributed .is_initialized ():
0 commit comments