
Commit 1b14ede

Fixed FSDP2 QATTrainer: restore modelopt state before loading weights; cleaned up QATTrainer.
Added QAT example tests for various backends, plus assorted minor fixes and review-comment changes. Signed-off-by: Your Name <[email protected]>
1 parent 76e8ce2 commit 1b14ede

File tree

19 files changed (+325, −412 lines)

examples/llm_qat/README.md

Lines changed: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def forward_loop(model):
 
 
 # Quantize the model in-place; The model should be unwrapped from any distributed wrapper
-# The model may be wrapped in a DataParallel or DistributedDataParallel after `mtq.quantize`
 model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
 
 # Save the modelopt quantizer states
examples/llm_qat/accelerate_config/deepspeed.yaml

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
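For context, launch.sh (updated below) selects this file when the new --backend flag is set to deepspeed. A minimal sketch of both entry points; the model id is a placeholder, not part of this commit:

# Via the launcher (flag names taken from the launch.sh diff below):
./launch.sh --model <hf-model-id> --backend deepspeed --quant_cfg NVFP4_DEFAULT_CFG

# Or directly, assuming the config path launch.sh uses:
accelerate launch --config-file accelerate_config/deepspeed.yaml main.py --model_name_or_path <hf-model-id>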

examples/llm_qat/accelerate_config/fsdp1.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ distributed_type: FSDP
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 fsdp_config:
-  fsdp_activation_checkpointing: false
+  fsdp_activation_checkpointing: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_cpu_ram_efficient_loading: true

examples/llm_qat/convert_sharded_ckpt.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

examples/llm_qat/launch.sh

Lines changed: 73 additions & 108 deletions
@@ -18,96 +18,37 @@ set -eo pipefail
 
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 
+# Helper function to parse a single argument value
+parse_value() {
+  if [[ "$1" != *=* ]]; then shift; fi
+  echo "${1#*=}"
+}
+
 while [ $# -gt 0 ]; do
   case "$1" in
-    --model*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MODEL="${1#*=}"
-      ;;
-    --output_dir*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      OUTPUT_DIR="${1#*=}"
-      ;;
-    --dataset*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DATASET="${1#*=}"
-      ;;
-    --train_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TRAIN_SIZE="${1#*=}"
-      ;;
-    --eval_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      EVAL_SIZE="${1#*=}"
-      ;;
-    --num_epochs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      NUM_EPOCHS="${1#*=}"
-      ;;
-    --max_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MAX_STEPS="${1#*=}"
-      ;;
-    --save_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      SAVE_STEPS="${1#*=}"
-      ;;
-    --accum_steps*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      ACCUM_STEPS="${1#*=}"
-      ;;
-    --lr*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      LR="${1#*=}"
-      ;;
-    --quant_cfg*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      QUANT_CFG="${1#*=}"
-      ;;
-    --compress*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      COMPRESS="${1#*=}"
-      ;;
-    --calib_size*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      CALIB_SIZE="${1#*=}"
-      ;;
-    --train_bs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TRAIN_BS="${1#*=}"
-      ;;
-    --eval_bs*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      EVAL_BS="${1#*=}"
-      ;;
-    --do_train*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DO_TRAIN="${1#*=}"
-      ;;
-    --lora*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      LORA="${1#*=}"
-      ;;
-    --teacher_model*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      TEACHER_MODEL="${1#*=}"
-      ;;
-    --distill*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      DISTILL="${1#*=}"
-      ;;
-    --fsdp_transformer_layer_cls_to_wrap*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
-      ;;
-    --use_fsdp2*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      USE_FSDP2="${1#*=}"
-      ;;
-    --max_seq_length*)
-      if [[ "$1" != *=* ]]; then shift; fi
-      MAX_SEQ_LENGTH="${1#*=}"
-      ;;
+    --model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
     *)
       >&2 printf "Error: Invalid argument ${1#*=}\n"
      exit 1
@@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"}
 DISTILL=${DISTILL:-"False"}
 TEACHER_MODEL=${TEACHER_MODEL:-$MODEL}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
+BACKEND=${BACKEND:-"fsdp1"}
 
 if [ -z $QUANT_CFG ]; then
   QUANT_ARGS=""
@@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
 
-CONFIG_FILE="fsdp1.yaml"
-FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
-GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
-
+# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
 if [[ "${USE_FSDP2,,}" == "true" ]]; then
-  echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
-  CONFIG_FILE="fsdp2.yaml"
-  GRADIENT_CHECKPOINTING_ARGS=""
+  echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
+  BACKEND="fsdp2"
+fi
+
+# if compress is true, set backend to ddp
+if [[ "${COMPRESS,,}" == "true" ]]; then
+  BACKEND="ddp"
 fi
 
+# Configure backend-specific settings
+GRADIENT_CHECKPOINTING_ARGS=""
+case "${BACKEND,,}" in
+  "fsdp1"|"fsdp")
+    CONFIG_FILE="fsdp1.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    ;;
+  "fsdp2")
+    echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
+    CONFIG_FILE="fsdp2.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    ;;
+  "ddp")
+    CONFIG_FILE="ddp.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  "deepspeed")
+    CONFIG_FILE="deepspeed.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  *)
+    echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
+    exit 1
+    ;;
+esac
+
+# TODO: Remove this after simple distillation is supported
 DISTILLATION_ARGS=""
 if [[ "${DISTILL,,}" == "true" ]]; then
   DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
-  # Distillation does not work with memory efficient loading
-  FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  # Distillation does not work with memory efficient loading for FSDP
+  if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+    FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  fi
 fi
 
-# real quantization does not work with FSDP, only works with FSDP2
-if [[ "${COMPRESS,,}" == "true" && "${USE_FSDP2,,}" != "true" ]]; then
-  echo "Compression is not supported with FSDP. Disabling FSDP and using DDP."
-  FSDP_ARGS=""
-  CONFIG_FILE="ddp.yaml"
-fi
-
-
 CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
   main.py \
   --model_name_or_path $MODEL \
@@ -209,10 +175,9 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
   --report_to tensorboard \
   --lora $LORA \
   --compress $COMPRESS \
-  $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
+  $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
 "
 
 start_time=$(date +%s)
 sh -c "$CMD"
-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
-python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"

examples/llm_qat/main.py

Lines changed: 6 additions & 13 deletions
@@ -38,18 +38,15 @@
 from transformers.trainer_utils import get_last_checkpoint
 from utils import (
     get_lora_config,
+    get_metrics_with_perplexity,
     make_supervised_data_module,
     monkey_patch_training_step_to_fix_memory_leak,
 )
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
-from modelopt.torch.quantization.plugins.transformers_trainer import (
-    QADTrainer,
-    QATTrainer,
-    get_metrics_with_perplexity,
-)
+from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer
 from modelopt.torch.utils import print_rank_0
 
 # Enable automatic save/load of modelopt state huggingface checkpointing
@@ -263,16 +260,12 @@ def train():
 
     if training_args.do_train:
         trainer.train(resume_from_checkpoint=checkpoint)
+        print_rank_0("Training completed.")
 
     if training_args.do_eval:
-        if not training_args.do_train:
-            # trainer.evaluate() will not prepare the model properly, especially for FSDP2,
-            # so we use the ``eval_on_start`` flag to evaluate the model and skip the training.
-            trainer.train(resume_from_checkpoint=checkpoint, eval_only=True)
-        else:
-            metrics = trainer.evaluate()
-            metrics = get_metrics_with_perplexity(metrics)
-            print_rank_0(f"Evaluation results: \n{metrics}")
+        metrics = trainer.evaluate()
+        metrics = get_metrics_with_perplexity(metrics)
+        print_rank_0(f"Evaluation results: \n{metrics}")
 
     if training_args.do_train or quant_args.quant_cfg is not None:
         print_rank_0("Saving the model...")

examples/llm_qat/simple_qat_train.py

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--quant-cfg",
         type=str,
-        default=mtq.NVFP4_DEFAULT_CFG,
+        default="NVFP4_DEFAULT_CFG",
         choices=mtq.config.choices,
         help="Quantization configuration",
     )
@@ -121,7 +121,7 @@ def calibrate(m: nn.Module):
         m(batch["input_ids"].to(device))
 
     # Quantize the model
-    model = mtq.quantize(model, args.quant_cfg, calibrate)
+    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
 
     # Initialize optimizer
     optimizer = AdamW(model.parameters(), lr=args.lr)
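Since the default is now the config name rather than the config object, any name in mtq.config.choices can be passed on the command line and is resolved inside the script via getattr(mtq, ...). For example (both config names appear elsewhere in this repo; other script defaults are assumed):

python simple_qat_train.py --quant-cfg NVFP4_DEFAULT_CFG
python simple_qat_train.py --quant-cfg INT8_DEFAULT_CFG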

examples/llm_qat/utils.py

Lines changed: 6 additions & 0 deletions
@@ -167,3 +167,9 @@ def new_func(original_f_name, trainer, *args, **kwargs):
     setattr(
         trainer, f_name, types.MethodType(partial(new_func, "_original_" + f_name), trainer)
     )
+
+
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics."""
+    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
+    return metrics
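The relocated helper computes perplexity as exp(eval_loss). A quick sanity check of the formula from the shell:

# For eval_loss = 2.0, perplexity = exp(2.0) ≈ 7.389
python -c "import math; print(math.exp(2.0))"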

0 commit comments
