
Commit c97f9bd

update example
1 parent 2cf41c6 commit c97f9bd

File tree: 1 file changed (+29, -23 lines)

examples/nemo_run/qat/nemo_qat_flow.py

Lines changed: 29 additions & 23 deletions
@@ -140,6 +140,8 @@ def get_args():
         action="store_true",
         default=False,
     )
+    parser.add_argument("--tensor_parallelism", type=int, default=1)
+    parser.add_argument("--pipeline_parallelism", type=int, default=1)
     return parser.parse_args()


@@ -243,6 +245,8 @@ def main(args):
     train.trainer.devices = args.train_gpus
     train.trainer.num_nodes = args.train_nodes
     train.trainer.limit_val_batches = 32
+    train.trainer.strategy.tensor_model_parallel_size = args.tensor_parallelism
+    train.trainer.strategy.pipeline_model_parallel_size = args.pipeline_parallelism

     # 5. Export
     export = run.Partial(
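
These two hunks expose tensor and pipeline parallelism as CLI flags and pass them through to the trainer strategy. As a hedged aside (the helper below is hypothetical and not part of this commit), a quick divisibility check at startup can catch misconfigured values early, since the product of the model-parallel sizes must divide the training world size:

def check_parallelism(train_nodes: int, train_gpus: int, tp: int, pp: int) -> None:
    # Hypothetical sanity check: Megatron-style layouts require that
    # tensor_parallelism * pipeline_parallelism divide nodes * gpus_per_node.
    world_size = train_nodes * train_gpus
    if world_size % (tp * pp) != 0:
        raise ValueError(
            f"tensor_parallelism * pipeline_parallelism = {tp * pp} "
            f"does not divide world size {world_size}"
        )

It could be called as check_parallelism(args.train_nodes, args.train_gpus, args.tensor_parallelism, args.pipeline_parallelism) right after get_args(), assuming those attribute names match the parser arguments visible in the diff.
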
@@ -257,29 +261,33 @@ def main(args):
     mmlu_script_path = "examples/nemo_run/common/in_memory_mmlu.py"
     eval_ptq = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", ptq_model_out],
+        args=["--nemo_ckpt", ptq_model_out, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )
     eval_bf16 = run.Script(
         mmlu_script_path,
-        args=["--nemo_ckpt", bf16_ckpt_path],
+        args=["--nemo_ckpt", bf16_ckpt_path, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )
     eval_sft = run.Script(
         mmlu_script_path,
-        args=["--finetuned_ckpt_dir", exp_dir],
+        args=["--finetuned_ckpt_dir", exp_dir, "--tensor_parallelism", f"{args.ptq_gpus}"],
         entrypoint="python",
     )

     if args.use_slurm:
         cpu_executor = create_slurm_executor(SLURM_CONFIG)
-        gpu_executor = create_slurm_executor(
+        ptq_gpu_executor = create_slurm_executor(
             SLURM_CONFIG, num_gpus=args.ptq_gpus, ntasks_per_node=args.ptq_gpus
         )
+        train_gpu_executor = create_slurm_executor(
+            SLURM_CONFIG, num_gpus=args.train_gpus, ntasks_per_node=args.train_gpus
+        )
         single_gpu_executor = create_slurm_executor(SLURM_CONFIG, num_gpus=1, ntasks_per_node=1)
     else:
         cpu_executor = single_gpu_executor = run.LocalExecutor()
-        gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        ptq_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.ptq_gpus)
+        train_gpu_executor = run.LocalExecutor(launcher="torchrun", ntasks_per_node=args.train_gpus)

     with run.Experiment(exp_dir, log_level="INFO") as exp:
         if not args.data_path:
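
This hunk replaces the single shared gpu_executor with a ptq_gpu_executor sized to args.ptq_gpus and a train_gpu_executor sized to args.train_gpus, so later stages no longer have to mutate one executor in place. The same pattern could be factored into a small helper; this is only a sketch reusing names visible in the diff (create_slurm_executor, SLURM_CONFIG, run.LocalExecutor), and make_gpu_executor itself is made up:

def make_gpu_executor(num_gpus: int, use_slurm: bool):
    # One executor per stage, sized to that stage's GPU count.
    if use_slurm:
        return create_slurm_executor(SLURM_CONFIG, num_gpus=num_gpus, ntasks_per_node=num_gpus)
    return run.LocalExecutor(launcher="torchrun", ntasks_per_node=num_gpus)

# e.g. ptq_gpu_executor = make_gpu_executor(args.ptq_gpus, args.use_slurm)
#      train_gpu_executor = make_gpu_executor(args.train_gpus, args.use_slurm)
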
@@ -294,45 +302,46 @@ def main(args):
             eval_bf16,
             tail_logs=True,
             name="02_mmlu_bf16",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s1],
         )

         # 2. PTQ model and evaluate PTQ model
-        s2 = exp.add(ptq, tail_logs=True, name="03_ptq", executor=gpu_executor, dependencies=[s1])
+        s2 = exp.add(
+            ptq, tail_logs=True, name="03_ptq", executor=ptq_gpu_executor, dependencies=[s1]
+        )
         s3 = exp.add(
             eval_ptq,
             tail_logs=True,
             name="04_mmlu_ptq",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s2],
         )
         # 3. Train PTQ model (QAT or QAD)
-        if args.use_slurm:  # Set training arguments
-            gpu_executor.nodes = args.train_nodes
-            gpu_executor.gpus_per_node = gpu_executor.ntasks_per_node = args.train_gpus
-        else:
-            gpu_executor.ntasks_per_node = args.train_gpus
         train_dep = [s3]
         if not args.data_path:
             train_dep.append(s0)
         s4 = exp.add(
-            train, tail_logs=True, name="05_train", executor=gpu_executor, dependencies=train_dep
+            train,
+            tail_logs=True,
+            name="05_train",
+            executor=train_gpu_executor,
+            dependencies=train_dep,
         )
-
         s5 = exp.add(
             eval_sft,
             tail_logs=True,
             name="06_mmlu_sft",
-            executor=single_gpu_executor,
+            executor=ptq_gpu_executor,
             dependencies=[s4],
         )
-        gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
+        # WAR: Export needs access to all GPUs but only 1 task due to bug in NeMo
+        train_gpu_executor.ntasks_per_node = 1  # will throw error if more than 1 task during export
         exp.add(
             export,
             tail_logs=True,
             name="07_export_hf",
-            executor=gpu_executor,
+            executor=train_gpu_executor,
             dependencies=[s5],
         )
         exp.run(detach=True)
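
With the executor split in place, each experiment stage now picks up the executor matching its resource needs instead of reconfiguring one shared executor mid-run. Read directly from the exp.add calls above, the mapping is:

# 02_mmlu_bf16, 03_ptq, 04_mmlu_ptq, 06_mmlu_sft -> ptq_gpu_executor (args.ptq_gpus tasks)
# 05_train                                       -> train_gpu_executor (args.train_gpus tasks)
# 07_export_hf                                   -> train_gpu_executor with ntasks_per_node forced to 1 (NeMo export workaround)
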
@@ -356,10 +365,7 @@ def main(args):
    use_local_tunnel=False,
    host="",
    user="",
-    container_mounts=[
-        "/path/to/logs:/path/to/logs",
-        "/path/to/NeMo:/opt/NeMo",
-    ],
+    container_mounts=[],
    job_dir="/path/to/logs",
    identity=None,
)
@@ -369,7 +375,7 @@ def main(args):
 SEQUENCE_LENGTH = 4096
 MBS = 1
 GBS = 512
-TRAIN_STEPS = 200
+TRAIN_STEPS = 400
 VAL_INTERVAL = 50
 # # # # # # # # # # # # # # # # # # # # # #
