
Commit 2a4d8b0

Your Name committed:

Fixed FSDP2 QATTrainer: restore modelopt state before loading weights; cleaned up QATTrainer.
Added QAT example tests for the various backends; minor fixes.

1 parent 1ef1d72 commit 2a4d8b0

14 files changed, +296 -314 lines changed

examples/llm_qat/accelerate_config/deepspeed.yaml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

examples/llm_qat/convert_sharded_ckpt.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

examples/llm_qat/launch.sh

Lines changed: 52 additions & 23 deletions
@@ -100,14 +100,18 @@ while [ $# -gt 0 ]; do
     if [[ "$1" != *=* ]]; then shift; fi
     FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
     ;;
-  --use_fsdp2*)
-    if [[ "$1" != *=* ]]; then shift; fi
-    USE_FSDP2="${1#*=}"
-    ;;
   --max_seq_length*)
     if [[ "$1" != *=* ]]; then shift; fi
     MAX_SEQ_LENGTH="${1#*=}"
     ;;
+  --backend*)
+    if [[ "$1" != *=* ]]; then shift; fi
+    BACKEND="${1#*=}"
+    ;;
+  --use_fsdp2*)
+    if [[ "$1" != *=* ]]; then shift; fi
+    USE_FSDP2="${1#*=}"
+    ;;
   *)
     >&2 printf "Error: Invalid argument ${1#*=}\n"
     exit 1
@@ -142,6 +146,7 @@ COMPRESS=${COMPRESS:-"False"}
 DISTILL=${DISTILL:-"False"}
 TEACHER_MODEL=${TEACHER_MODEL:-$MODEL}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
+BACKEND=${BACKEND:-"fsdp1"}

 if [ -z $QUANT_CFG ]; then
   QUANT_ARGS=""
@@ -154,31 +159,56 @@ if [ ! -z $MAX_STEPS ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi

-CONFIG_FILE="fsdp1.yaml"
-FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
-GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
-
+# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
 if [[ "${USE_FSDP2,,}" == "true" ]]; then
-  echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
-  CONFIG_FILE="fsdp2.yaml"
-  GRADIENT_CHECKPOINTING_ARGS=""
+  echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
+  BACKEND="fsdp2"
+fi
+
+# if compress is true, set backend to ddp
+if [[ "${COMPRESS,,}" == "true" ]]; then
+  BACKEND="ddp"
 fi

+# Configure backend-specific settings
+case "${BACKEND,,}" in
+  "fsdp1"|"fsdp")
+    CONFIG_FILE="fsdp1.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  "fsdp2")
+    echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
+    CONFIG_FILE="fsdp2.yaml"
+    FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
+    GRADIENT_CHECKPOINTING_ARGS=""
+    ;;
+  "ddp")
+    CONFIG_FILE="ddp.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  "deepspeed")
+    CONFIG_FILE="deepspeed.yaml"
+    FSDP_ARGS=""
+    GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
+    ;;
+  *)
+    echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
+    exit 1
+    ;;
+esac
+
+# TODO: Remove this after simple distillation is supported
 DISTILLATION_ARGS=""
 if [[ "${DISTILL,,}" == "true" ]]; then
   DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
-  # Distillation does not work with memory efficient loading
-  FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  # Distillation does not work with memory efficient loading for FSDP
+  if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+    FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
+  fi
 fi

-# real quantization does not work with FSDP, only works with FSDP2
-if [[ "${COMPRESS,,}" == "true" && "${USE_FSDP2,,}" != "true" ]]; then
-  echo "Compression is not supported with FSDP. Disabling FSDP and using DDP."
-  FSDP_ARGS=""
-  CONFIG_FILE="ddp.yaml"
-fi
-
-
 CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
   main.py \
   --model_name_or_path $MODEL \
@@ -214,5 +244,4 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \

 start_time=$(date +%s)
 sh -c "$CMD"
-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
-python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"

examples/llm_qat/main.py

Lines changed: 11 additions & 19 deletions
@@ -38,18 +38,15 @@
 from transformers.trainer_utils import get_last_checkpoint
 from utils import (
     get_lora_config,
+    get_metrics_with_perplexity,
     make_supervised_data_module,
     monkey_patch_training_step_to_fix_memory_leak,
 )

 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
-from modelopt.torch.quantization.plugins.transformers_trainer import (
-    QADTrainer,
-    QATTrainer,
-    get_metrics_with_perplexity,
-)
+from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer
 from modelopt.torch.utils import print_rank_0

 # Enable automatic save/load of modelopt state huggingface checkpointing
@@ -263,22 +260,17 @@ def train():

     if training_args.do_train:
         trainer.train(resume_from_checkpoint=checkpoint)
+        print_rank_0("Training completed.")

     if training_args.do_eval:
-        if not training_args.do_train:
-            # trainer.evaluate() will not prepare the model properly, especially for FSDP2,
-            # so we use the ``eval_on_start`` flag to evaluate the model and skip the training.
-            trainer.train(resume_from_checkpoint=checkpoint, eval_only=True)
-        else:
-            metrics = trainer.evaluate()
-            metrics = get_metrics_with_perplexity(metrics)
-            print_rank_0(f"Evaluation results: \n{metrics}")
-
-    if training_args.do_train or quant_args.quant_cfg is not None:
-        print_rank_0("Saving the model...")
-        trainer.save_state()
-        kwargs = {"export_student": True} if training_args.distill else {}
-        trainer.save_model(training_args.output_dir, **kwargs)
+        metrics = trainer.evaluate()
+        metrics = get_metrics_with_perplexity(metrics)
+        print_rank_0(f"Evaluation results: \n{metrics}")
+
+    print_rank_0("Saving the model...")
+    trainer.save_state()
+    kwargs = {"export_student": True} if training_args.distill else {}
+    trainer.save_model(training_args.output_dir, **kwargs)


 if __name__ == "__main__":

examples/llm_qat/simple_qat_train.py

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--quant-cfg",
         type=str,
-        default=mtq.NVFP4_DEFAULT_CFG,
+        default="NVFP4_DEFAULT_CFG",
         choices=mtq.config.choices,
         help="Quantization configuration",
     )
@@ -121,7 +121,7 @@ def calibrate(m: nn.Module):
             m(batch["input_ids"].to(device))

     # Quantize the model
-    model = mtq.quantize(model, args.quant_cfg, calibrate)
+    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)

     # Initialize optimizer
     optimizer = AdamW(model.parameters(), lr=args.lr)
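For context on this change: the argparse default now has to be a string so it can be compared against the string choices, and the name is resolved to the actual config object only at quantize time. A minimal sketch of that pattern, assuming `mtq.config.choices` enumerates config names as strings (which the string default implies) and that modelopt is installed:

import argparse

import modelopt.torch.quantization as mtq

parser = argparse.ArgumentParser()
# The default must be a string so it lines up with the string choices.
parser.add_argument("--quant-cfg", type=str, default="NVFP4_DEFAULT_CFG", choices=mtq.config.choices)
args = parser.parse_args([])

# Resolve the configuration name to the actual config object only when it is needed.
quant_cfg = getattr(mtq, args.quant_cfg)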

examples/llm_qat/utils.py

Lines changed: 6 additions & 0 deletions
@@ -167,3 +167,9 @@ def new_func(original_f_name, trainer, *args, **kwargs):
     setattr(
         trainer, f_name, types.MethodType(partial(new_func, "_original_" + f_name), trainer)
     )
+
+
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics."""
+    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
+    return metrics
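A quick illustration of the helper with a hypothetical metrics dict (not from the source): an `eval_loss` of 2.0 maps to a perplexity of e^2 ≈ 7.39.

import torch


def get_metrics_with_perplexity(metrics):
    """Add perplexity to the metrics."""
    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
    return metrics


# Hypothetical evaluation output, used only to illustrate perplexity = exp(eval_loss).
print(get_metrics_with_perplexity({"eval_loss": 2.0}))
# -> {'perplexity': 7.389..., 'eval_loss': 2.0}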

modelopt/torch/opt/conversion.py

Lines changed: 1 addition & 1 deletion
@@ -380,7 +380,7 @@ def apply_mode(
         return model.init_modellike() if isinstance(model, ModelLikeModule) else model

     # check if the model is in a wrapper
-    model = unwrap_model(model, raise_error=True)
+    model = unwrap_model(model, force_unwrap=True)

     # standardize mode to ModeConfigList
     mode_and_config = get_mode_config(mode)

modelopt/torch/opt/plugins/peft.py

Lines changed: 3 additions & 8 deletions
@@ -57,14 +57,9 @@ def _new_save_pretrained_peft(self, save_directory, *args, **kwargs):
     # So we need to save the quantizer state_dict separately

     # TODO: Move this to modelopt.torch.quantization.plugins.peft
-    from modelopt.torch.quantization.nn import TensorQuantizer
-
-    # We should not call self/model.state_dict() here. HF Trainer calls model.save_pretrained() only from process 0
-    # With FSDP, model.state_dict() will hang if it is not called from all processes
-    quantizer_state_dict = {}
-    for name, module in self.named_modules():
-        if isinstance(module, TensorQuantizer):
-            quantizer_state_dict[get_unwrapped_name(name)] = module.state_dict()
+    from modelopt.torch.quantization.utils import get_quantizer_state_dict
+
+    quantizer_state_dict = get_quantizer_state_dict(self)
     if len(quantizer_state_dict) > 0:
         torch.save(quantizer_state_dict, _get_quantizer_state_save_path(save_directory))
     return outputs
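The deleted inline loop hints at what the new `get_quantizer_state_dict` helper does; below is a minimal sketch reconstructed from that loop. The real implementation in `modelopt.torch.quantization.utils` may differ, and the import path of `get_unwrapped_name` is an assumption.

import torch.nn as nn

from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.utils import get_unwrapped_name  # assumed import path


def get_quantizer_state_dict(model: nn.Module) -> dict:
    """Collect the state_dict of every TensorQuantizer, keyed by its unwrapped module name."""
    quantizer_state_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, TensorQuantizer):
            quantizer_state_dict[get_unwrapped_name(name)] = module.state_dict()
    return quantizer_state_dict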

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 4 additions & 2 deletions
@@ -158,8 +158,10 @@ class QuantLinearConvBase(QuantInputBase):
     def quantize_weight(self):
         """Context in which `self.weight` is quantized."""
         self._enable_weight_quantization = True
-        yield
-        self._enable_weight_quantization = False
+        try:
+            yield
+        finally:
+            self._enable_weight_quantization = False

     @staticmethod
     def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
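Why the try/finally matters: with a generator-based context manager, an exception raised inside the `with` block propagates out of the `yield`, so without `finally` the flag would stay set. A standalone toy sketch (not the modelopt class) illustrating the pattern:

from contextlib import contextmanager


class Toy:
    """Minimal stand-in for a module with a weight-quantization flag."""

    def __init__(self):
        self._enable_weight_quantization = False

    @contextmanager
    def quantize_weight(self):
        self._enable_weight_quantization = True
        try:
            yield
        finally:
            # Runs even if the body of the `with` block raises.
            self._enable_weight_quantization = False


m = Toy()
try:
    with m.quantize_weight():
        raise RuntimeError("forward pass failed")
except RuntimeError:
    pass

assert m._enable_weight_quantization is False  # flag restored despite the error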
