@@ -140,6 +140,8 @@ def get_args():
         action="store_true",
         default=False,
     )
+    parser.add_argument("--tensor_parallelism", type=int, default=1)
+    parser.add_argument("--pipeline_parallelism", type=int, default=1)
     return parser.parse_args()


@@ -243,6 +245,8 @@ def main(args):
     train.trainer.devices = args.train_gpus
     train.trainer.num_nodes = args.train_nodes
     train.trainer.limit_val_batches = 32
+    train.trainer.strategy.tensor_model_parallel_size = args.tensor_parallelism
+    train.trainer.strategy.pipeline_model_parallel_size = args.pipeline_parallelism

     # 5. Export
     export = run.Partial(
@@ -257,29 +261,33 @@ def main(args):
     mmlu_script_path = "examples/nemo_run/common/in_memory_mmlu.py"
     eval_ptq = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", ptq_model_out],
+        args=["--nemo_ckpt", ptq_model_out, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )
     eval_bf16 = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", bf16_ckpt_path],
+        args=["--nemo_ckpt", bf16_ckpt_path, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )
     eval_sft = run.Script(
         mmlu_script_path,
-        args=["--finetuned_ckpt_dir", exp_dir],
+        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )

     if args.use_slurm:
         cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
+        ptq_gpu_executor = create_slurm_executor(
             SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
         )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
         single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
     else:
         cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)

     with run.Experiment(exp_dir, log_level="INFO") as exp:
         if not args.data_path:
@@ -294,45 +302,46 @@ def main(args):
             eval_bf16,
             tail_logs=True,
             name="02_mmlu_bf16",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s1],
         )

         # 2. PTQ model and evaluate PTQ model
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(
+            ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]
+        )
         s3 = exp.add(
             eval_ptq,
             tail_logs=True,
             name="04_mmlu_ptq",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s2],
         )
         # 3. Train PTQ model (QAT or QAD)
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
         train_dep = [s3]
         if not args.data_path:
             train_dep.append(s0)
         s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
+            train,
+            tail_logs=True,
+            name="05_train",
+            executor=train_gpu_executor,
+            dependencies=train_dep,
         )
-
         s5 = exp.add(
             eval_sft,
             tail_logs=True,
             name="06_mmlu_sft",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s4],
         )
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
+        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
         exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=gpu_executor,
+            executor=train_gpu_executor,
             dependencies=[s5],
         )
         exp.run(detach=True)
@@ -356,10 +365,7 @@ def main(args):
         use_local_tunnel=False,
         host="",
         user="",
-        container_mounts=[
-            "/path/to/logs:/path/to/logs",
-            "/path/to/NeMo:/opt/NeMo",
-        ],
+        container_mounts=[],
         job_dir="/path/to/logs",
         identity=None,
     )
@@ -369,7 +375,7 @@ def main(args):
     SEQUENCE_LENGTH = 4096
     MBS = 1
     GBS = 512
-    TRAIN_STEPS = 200
+    TRAIN_STEPS = 400
     VAL_INTERVAL = 50
     # # # # # # # # # # # # # # # # # # # # # #

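For orientation, below is a minimal local-mode sketch (not part of the diff) of the executor split introduced above. It reuses only calls already shown in the script (run.LocalExecutor with launcher="torchrun"); the import alias, GPU counts, and parallelism sizes are placeholder assumptions standing in for the parsed CLI arguments.

import nemo_run as run  # assumed import alias; the script above refers to the package as `run`

# Placeholder values standing in for args.ptq_gpus / args.train_gpus and the new flags.
ptq_gpus, train_gpus = 8, 4
tensor_parallelism, pipeline_parallelism = 2, 1

# For a single-node local run, TP x PP cannot exceed the GPUs given to the training stage.
assert tensor_parallelism * pipeline_parallelism <= train_gpus

# PTQ and the MMLU evaluations share an executor sized for the PTQ GPU count,
# while training (and the final export) uses an executor sized for the training GPU count.
ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=ptq_gpus)
train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=train_gpus)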