
Commit 3524732

[1/N] QATTrainer training workflow fixes and cleanup; added backend-specific unit tests (#318)
Signed-off-by: realAsma <[email protected]>
1 parent: 8d0e40f

20 files changed: +345 −410 lines

examples/llm_qat/README.md (0 additions, 1 deletion)

@@ -82,7 +82,6 @@ def forward_loop(model):
 
 
 # Quantize the model in-place; The model should be unwrapped from any distributed wrapper
-# The model may be wrapped in a DataParallel or DistributedDataParallel after `mtq.quantize`
 model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
 
 # Save the modelopt quantizer states
examples/llm_qat/accelerate_config/deepspeed.yaml (new file, 23 additions)

+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
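
This config mirrors the existing accelerate configs in the same directory and is consumed by the new deepspeed branch in launch.sh below. A minimal sketch of selecting it (the model path is a placeholder):

# Route a QAT run through DeepSpeed ZeRO-3 using this config
./launch.sh --model <hf_model_or_path> --backend deepspeed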

examples/llm_qat/accelerate_config/fsdp1.yaml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ distributed_type: FSDP
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 fsdp_config:
-  fsdp_activation_checkpointing: false
+  fsdp_activation_checkpointing: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_cpu_ram_efficient_loading: true
examples/llm_qat/convert_sharded_ckpt.py (0 additions, 55 deletions)

This file was deleted.

examples/llm_qat/launch.sh (73 additions, 108 deletions)

@@ -18,96 +18,37 @@ set -eo pipefail
 
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 
+# Helper function to parse a single argument value
+parse_value() {
+  if [[ "$1" != *=* ]]; then shift; fi
+  echo "${1#*=}"
+}
+
 while [ $# -gt 0 ]; do
   case "$1" in
-    --model*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MODEL="${1#*=}"
-      ;;
-    --output_dir*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      OUTPUT_DIR="${1#*=}"
-      ;;
-    --dataset*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DATASET="${1#*=}"
-      ;;
-    --train_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TRAIN_SIZE="${1#*=}"
-      ;;
-    --eval_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      EVAL_SIZE="${1#*=}"
-      ;;
-    --num_epochs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      NUM_EPOCHS="${1#*=}"
-      ;;
-    --max_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MAX_STEPS="${1#*=}"
-      ;;
-    --save_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      SAVE_STEPS="${1#*=}"
-      ;;
-    --accum_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      ACCUM_STEPS="${1#*=}"
-      ;;
-    --lr*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      LR="${1#*=}"
-      ;;
-    --quant_cfg*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      QUANT_CFG="${1#*=}"
-      ;;
-    --compress*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      COMPRESS="${1#*=}"
-      ;;
-    --calib_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      CALIB_SIZE="${1#*=}"
-      ;;
-    --train_bs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TRAIN_BS="${1#*=}"
-      ;;
-    --eval_bs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      EVAL_BS="${1#*=}"
-      ;;
-    --do_train*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DO_TRAIN="${1#*=}"
-      ;;
-    --lora*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      LORA="${1#*=}"
-      ;;
-    --teacher_model*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TEACHER_MODEL="${1#*=}"
-      ;;
-    --distill*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DISTILL="${1#*=}"
-      ;;
-    --fsdp_transformer_layer_cls_to_wrap*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
-      ;;
-    --use_fsdp2*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      USE_FSDP2="${1#*=}"
-      ;;
-    --max_seq_length*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MAX_SEQ_LENGTH="${1#*=}"
-      ;;
+    --model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
     *)
      >&2 printf "Error: Invalid argument ${1#*=}\n"
      exit 1
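
A note on the helper: it accepts both the "--flag=value" and "--flag value" forms. When "$1" carries no "=", the value is the next positional argument, so parse_value shifts once before "${1#*=}" strips the (now absent) "--flag=" prefix. A minimal standalone sketch of the same pattern, with an illustrative flag and value:

#!/usr/bin/env bash
# Demo of the parse_value pattern used in launch.sh above
parse_value() {
  if [[ "$1" != *=* ]]; then shift; fi  # "--flag value" form: value is in $2
  echo "${1#*=}"                        # strip up to "="; no-op when no "=" remains
}

set -- --model=llama
parse_value "$@"   # prints "llama"

set -- --model llama
parse_value "$@"   # prints "llama"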
@@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"}
 DISTILL=${DISTILL:-"False"}
 TEACHER_MODEL=${TEACHER_MODEL:-$MODEL}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
+BACKEND=${BACKEND:-"fsdp1"}
 
 if [ -z $QUANT_CFG ]; then
   QUANT_ARGS=""

@@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
 
-CONFIG_FILE="fsdp1.yaml"
-FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
-GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
-
+# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
 if [[ "${USE_FSDP2,,}" == "true" ]]; then
-  echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
-  CONFIG_FILE="fsdp2.yaml"
-  GRADIENT_CHECKPOINTING_ARGS=""
+  echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
+  BACKEND="fsdp2"
+fi
+
+# if compress is true, set backend to ddp
+if [[ "${COMPRESS,,}" == "true" ]]; then
+  BACKEND="ddp"
 fi
 
+# Configure backend-specific settings
+GRADIENT_CHECKPOINTING_ARGS=""
+case "${BACKEND,,}" in
+  "fsdp1"|"fsdp")
+    CONFIG_FILE="fsdp1.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    ;;
+  "fsdp2")
+    echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
+    CONFIG_FILE="fsdp2.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    ;;
+  "ddp")
+    CONFIG_FILE="ddp.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  "deepspeed")
+    CONFIG_FILE="deepspeed.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  *)
+    echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
+    exit 1
+    ;;
+esac
+
+# TODO: Remove this after simple distillation is supported
 DISTILLATION_ARGS=""
 if [[ "${DISTILL,,}" == "true" ]]; then
   DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
-  # Distillation does not work with memory efficient loading
-  FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  # Distillation does not work with memory efficient loading for FSDP
+  if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+    FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  fi
 fi
 
-# real quantization does not work with FSDP, only works with FSDP2
-if [[ "${COMPRESS,,}" == "true" && "${USE_FSDP2,,}" != "true" ]]; then
-  echo "Compression is not supported with FSDP. Disabling FSDP and using DDP."
-  FSDP_ARGS=""
-  CONFIG_FILE="ddp.yaml"
-fi
-
 CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
   main.py \
   --model_name_or_path $MODEL \

@@ -209,10 +175,9 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
   --report_to tensorboard \
   --lora $LORA \
   --compress $COMPRESS \
-  $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
+  $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
 "
 
 start_time=$(date +%s)
 sh -c "$CMD"
-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
-python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"

examples/llm_qat/llama_factory/launch_llamafactory.sh (0 additions, 1 deletion)

@@ -256,4 +256,3 @@ else
   echo "Modified FSDP args: $FSDP_ARGS"
   accelerate launch --config_file $ACCELERATE_CONFIG $FSDP_ARGS $SCRIPT_DIR/llama_factory.py $CONFIG_FILE
 fi
-python $SCRIPT_DIR/../convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR

examples/llm_qat/main.py (6 additions, 13 deletions)

@@ -38,18 +38,15 @@
 from transformers.trainer_utils import get_last_checkpoint
 from utils import (
     get_lora_config,
+    get_metrics_with_perplexity,
     make_supervised_data_module,
     monkey_patch_training_step_to_fix_memory_leak,
 )
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
-from modelopt.torch.quantization.plugins.transformers_trainer import (
-    QADTrainer,
-    QATTrainer,
-    get_metrics_with_perplexity,
-)
+from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer
 from modelopt.torch.utils import print_rank_0
 
 # Enable automatic save/load of modelopt state huggingface checkpointing

@@ -263,16 +260,12 @@ def train():
 
     if training_args.do_train:
         trainer.train(resume_from_checkpoint=checkpoint)
+        print_rank_0("Training completed.")
 
     if training_args.do_eval:
-        if not training_args.do_train:
-            # trainer.evaluate() will not prepare the model properly, especially for FSDP2,
-            # so we use the ``eval_on_start`` flag to evaluate the model and skip the training.
-            trainer.train(resume_from_checkpoint=checkpoint, eval_only=True)
-        else:
-            metrics = trainer.evaluate()
-            metrics = get_metrics_with_perplexity(metrics)
-            print_rank_0(f"Evaluation results: \n{metrics}")
+        metrics = trainer.evaluate()
+        metrics = get_metrics_with_perplexity(metrics)
+        print_rank_0(f"Evaluation results: \n{metrics}")
 
     if training_args.do_train or quant_args.quant_cfg is not None:
         print_rank_0("Saving the model...")

examples/llm_qat/simple_qat_train.py (3 additions, 3 deletions)

@@ -74,7 +74,7 @@ def train(model, optimizer, train_dataloader, tokenizer, epochs, output_dir, dev
 
 
 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="CNN QAT using ModelOpt")
+    parser = argparse.ArgumentParser(description="QAT Training Script")
     # Data paths
     parser.add_argument("--model-path", type=str, required=True, help="Path to the model")
    parser.add_argument("--train-size", type=int, default=2048, help="Train size")

@@ -87,7 +87,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--quant-cfg",
         type=str,
-        default=mtq.NVFP4_DEFAULT_CFG,
+        default="NVFP4_DEFAULT_CFG",
         choices=mtq.config.choices,
         help="Quantization configuration",
     )

@@ -121,7 +121,7 @@ def calibrate(m: nn.Module):
         m(batch["input_ids"].to(device))
 
     # Quantize the model
-    model = mtq.quantize(model, args.quant_cfg, calibrate)
+    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
 
     # Initialize optimizer
     optimizer = AdamW(model.parameters(), lr=args.lr)
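
The --quant-cfg fix makes the default consistent with choices=mtq.config.choices, which are config names rather than config dicts, deferring resolution to quantize time via getattr. Example usage (the model path is a placeholder):

# Pass the quantization config by name; the script resolves it with getattr(mtq, name)
python simple_qat_train.py --model-path <hf_model_or_path> --quant-cfg NVFP4_DEFAULT_CFG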
