1 change: 0 additions & 1 deletion examples/llm_qat/README.md
@@ -82,7 +82,6 @@ def forward_loop(model):


# Quantize the model in-place; The model should be unwrapped from any distributed wrapper
# The model may be wrapped in a DataParallel or DistributedDataParallel after `mtq.quantize`
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# Save the modelopt quantizer states
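
As a quick sketch of the ordering the remaining comment describes (assuming `forward_loop`, `local_rank`, and the process group are already set up; not part of this diff):

import torch
import modelopt.torch.quantization as mtq

# Quantize/calibrate the plain (unwrapped) model in-place first ...
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
# ... and only wrap it for distributed training afterwards.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])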
23 changes: 23 additions & 0 deletions examples/llm_qat/accelerate_config/deepspeed.yaml
@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: gpu
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
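
As a usage sketch for this new config (model name and flag values are illustrative, not taken from this PR), launch.sh selects it via the new --backend flag:

./launch.sh --model meta-llama/Llama-3.1-8B-Instruct --backend=deepspeed --quant_cfg NVFP4_DEFAULT_CFG --do_train True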
2 changes: 1 addition & 1 deletion examples/llm_qat/accelerate_config/fsdp1.yaml
@@ -4,7 +4,7 @@ distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
55 changes: 0 additions & 55 deletions examples/llm_qat/convert_sharded_ckpt.py

This file was deleted.

181 changes: 73 additions & 108 deletions examples/llm_qat/launch.sh
@@ -18,96 +18,37 @@ set -eo pipefail

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Helper function to parse a single argument value
parse_value() {
if [[ "$1" != *=* ]]; then shift; fi
echo "${1#*=}"
}

while [ $# -gt 0 ]; do
case "$1" in
--model*)
if [[ "$1" != *=* ]]; then shift; fi
MODEL="${1#*=}"
;;
--output_dir*)
if [[ "$1" != *=* ]]; then shift; fi
OUTPUT_DIR="${1#*=}"
;;
--dataset*)
if [[ "$1" != *=* ]]; then shift; fi
DATASET="${1#*=}"
;;
--train_size*)
if [[ "$1" != *=* ]]; then shift; fi
TRAIN_SIZE="${1#*=}"
;;
--eval_size*)
if [[ "$1" != *=* ]]; then shift; fi
EVAL_SIZE="${1#*=}"
;;
--num_epochs*)
if [[ "$1" != *=* ]]; then shift; fi
NUM_EPOCHS="${1#*=}"
;;
--max_steps*)
if [[ "$1" != *=* ]]; then shift; fi
MAX_STEPS="${1#*=}"
;;
--save_steps*)
if [[ "$1" != *=* ]]; then shift; fi
SAVE_STEPS="${1#*=}"
;;
--accum_steps*)
if [[ "$1" != *=* ]]; then shift; fi
ACCUM_STEPS="${1#*=}"
;;
--lr*)
if [[ "$1" != *=* ]]; then shift; fi
LR="${1#*=}"
;;
--quant_cfg*)
if [[ "$1" != *=* ]]; then shift; fi
QUANT_CFG="${1#*=}"
;;
--compress*)
if [[ "$1" != *=* ]]; then shift; fi
COMPRESS="${1#*=}"
;;
--calib_size*)
if [[ "$1" != *=* ]]; then shift; fi
CALIB_SIZE="${1#*=}"
;;
--train_bs*)
if [[ "$1" != *=* ]]; then shift; fi
TRAIN_BS="${1#*=}"
;;
--eval_bs*)
if [[ "$1" != *=* ]]; then shift; fi
EVAL_BS="${1#*=}"
;;
--do_train*)
if [[ "$1" != *=* ]]; then shift; fi
DO_TRAIN="${1#*=}"
;;
--lora*)
if [[ "$1" != *=* ]]; then shift; fi
LORA="${1#*=}"
;;
--teacher_model*)
if [[ "$1" != *=* ]]; then shift; fi
TEACHER_MODEL="${1#*=}"
;;
--distill*)
if [[ "$1" != *=* ]]; then shift; fi
DISTILL="${1#*=}"
;;
--fsdp_transformer_layer_cls_to_wrap*)
if [[ "$1" != *=* ]]; then shift; fi
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
;;
--use_fsdp2*)
if [[ "$1" != *=* ]]; then shift; fi
USE_FSDP2="${1#*=}"
;;
--max_seq_length*)
if [[ "$1" != *=* ]]; then shift; fi
MAX_SEQ_LENGTH="${1#*=}"
;;
--model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
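
For reference, the parse_value helper plus the trailing `[[ "$1" != *=* ]] && shift` means both argument styles below resolve to the same values (paths are placeholders, not from this PR):

./launch.sh --model <hf_model_path> --train_bs 4 --backend deepspeed
./launch.sh --model=<hf_model_path> --train_bs=4 --backend=deepspeed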
@@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"}
DISTILL=${DISTILL:-"False"}
TEACHER_MODEL=${TEACHER_MODEL:-$MODEL}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
BACKEND=${BACKEND:-"fsdp1"}

if [ -z $QUANT_CFG ]; then
QUANT_ARGS=""
@@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then
OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
fi

CONFIG_FILE="fsdp1.yaml"
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"

# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
if [[ "${USE_FSDP2,,}" == "true" ]]; then
echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
CONFIG_FILE="fsdp2.yaml"
GRADIENT_CHECKPOINTING_ARGS=""
echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
BACKEND="fsdp2"
fi

# if compress is true, set backend to ddp
if [[ "${COMPRESS,,}" == "true" ]]; then
BACKEND="ddp"
fi

# Configure backend-specific settings
GRADIENT_CHECKPOINTING_ARGS=""
case "${BACKEND,,}" in
"fsdp1"|"fsdp")
CONFIG_FILE="fsdp1.yaml"
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
"fsdp2")
echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
CONFIG_FILE="fsdp2.yaml"
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
"ddp")
CONFIG_FILE="ddp.yaml"
FSDP_ARGS=""
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
;;
"deepspeed")
CONFIG_FILE="deepspeed.yaml"
FSDP_ARGS=""
GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True"
;;
*)
echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
exit 1
;;
esac

# TODO: Remove this after simple distillation is supported
DISTILLATION_ARGS=""
if [[ "${DISTILL,,}" == "true" ]]; then
DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
# Distillation does not work with memory efficient loading
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
# Distillation does not work with memory efficient loading for FSDP
if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
fi
fi
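
A hedged example of a distillation (QAD) run that exercises this branch; model names are placeholders:

./launch.sh --model <student_model> --teacher_model <teacher_model> --distill True --quant_cfg NVFP4_DEFAULT_CFG --do_train True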

# real quantization does not work with FSDP, only works with FSDP2
if [[ "${COMPRESS,,}" == "true" && "${USE_FSDP2,,}" != "true" ]]; then
echo "Compression is not supported with FSDP. Disabling FSDP and using DDP."
FSDP_ARGS=""
CONFIG_FILE="ddp.yaml"
fi


CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
main.py \
--model_name_or_path $MODEL \
@@ -209,10 +175,9 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
--report_to tensorboard \
--lora $LORA \
--compress $COMPRESS \
$QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS
$GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
"

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
1 change: 0 additions & 1 deletion examples/llm_qat/llama_factory/launch_llamafactory.sh
@@ -256,4 +256,3 @@ else
echo "Modified FSDP args: $FSDP_ARGS"
accelerate launch --config_file $ACCELERATE_CONFIG $FSDP_ARGS $SCRIPT_DIR/llama_factory.py $CONFIG_FILE
fi
python $SCRIPT_DIR/../convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR
19 changes: 6 additions & 13 deletions examples/llm_qat/main.py
@@ -38,18 +38,15 @@
from transformers.trainer_utils import get_last_checkpoint
from utils import (
get_lora_config,
get_metrics_with_perplexity,
make_supervised_data_module,
monkey_patch_training_step_to_fix_memory_leak,
)

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
from modelopt.torch.quantization.plugins.transformers_trainer import (
QADTrainer,
QATTrainer,
get_metrics_with_perplexity,
)
from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer
from modelopt.torch.utils import print_rank_0

# Enable automatic save/load of modelopt state huggingface checkpointing
@@ -263,16 +260,12 @@ def train():

if training_args.do_train:
trainer.train(resume_from_checkpoint=checkpoint)
print_rank_0("Training completed.")

if training_args.do_eval:
if not training_args.do_train:
# trainer.evaluate() will not prepare the model properly, especially for FSDP2,
# so we use the ``eval_on_start`` flag to evaluate the model and skip the training.
trainer.train(resume_from_checkpoint=checkpoint, eval_only=True)
else:
metrics = trainer.evaluate()
metrics = get_metrics_with_perplexity(metrics)
print_rank_0(f"Evaluation results: \n{metrics}")
metrics = trainer.evaluate()
metrics = get_metrics_with_perplexity(metrics)
print_rank_0(f"Evaluation results: \n{metrics}")

if training_args.do_train or quant_args.quant_cfg is not None:
print_rank_0("Saving the model...")
6 changes: 3 additions & 3 deletions examples/llm_qat/simple_qat_train.py
@@ -74,7 +74,7 @@ def train(model, optimizer, train_dataloader, tokenizer, epochs, output_dir, dev


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="CNN QAT using ModelOpt")
parser = argparse.ArgumentParser(description="QAT Training Script")
# Data paths
parser.add_argument("--model-path", type=str, required=True, help="Path to the model")
parser.add_argument("--train-size", type=int, default=2048, help="Train size")
@@ -87,7 +87,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--quant-cfg",
type=str,
default=mtq.NVFP4_DEFAULT_CFG,
default="NVFP4_DEFAULT_CFG",
choices=mtq.config.choices,
help="Quantization configuration",
)
@@ -121,7 +121,7 @@ def calibrate(m: nn.Module):
m(batch["input_ids"].to(device))

# Quantize the model
model = mtq.quantize(model, args.quant_cfg, calibrate)
model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)

# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=args.lr)
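
A small sketch of what the string-based --quant-cfg flow amounts to (assuming `args` comes from parse_args above and `calibrate` is defined as in this file):

import modelopt.torch.quantization as mtq

# The CLI now passes the config by name (one of mtq.config.choices), e.g. "NVFP4_DEFAULT_CFG";
# getattr resolves that name to the actual config dict before quantization.
quant_cfg = getattr(mtq, args.quant_cfg)
model = mtq.quantize(model, quant_cfg, calibrate)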