diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md index 522c2991..801f7ee2 100644 --- a/examples/llm_qat/README.md +++ b/examples/llm_qat/README.md @@ -82,7 +82,6 @@ def forward_loop(model): # Quantize the model in-place; The model should be unwrapped from any distributed wrapper -# The model may be wrapped in a DataParallel or DistributedDataParallel after `mtq.quantize` model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop) # Save the modelopt quantizer states diff --git a/examples/llm_qat/accelerate_config/deepspeed.yaml b/examples/llm_qat/accelerate_config/deepspeed.yaml new file mode 100644 index 00000000..913bb157 --- /dev/null +++ b/examples/llm_qat/accelerate_config/deepspeed.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: gpu +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/llm_qat/accelerate_config/fsdp1.yaml b/examples/llm_qat/accelerate_config/fsdp1.yaml index fc80dd35..5e0f5e65 100644 --- a/examples/llm_qat/accelerate_config/fsdp1.yaml +++ b/examples/llm_qat/accelerate_config/fsdp1.yaml @@ -4,7 +4,7 @@ distributed_type: FSDP downcast_bf16: 'no' enable_cpu_affinity: false fsdp_config: - fsdp_activation_checkpointing: false + fsdp_activation_checkpointing: true fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_backward_prefetch: BACKWARD_PRE fsdp_cpu_ram_efficient_loading: true diff --git a/examples/llm_qat/convert_sharded_ckpt.py b/examples/llm_qat/convert_sharded_ckpt.py deleted file mode 100644 index aa762709..00000000 --- a/examples/llm_qat/convert_sharded_ckpt.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os - -from transformers import AutoModelForCausalLM - -import modelopt.torch.opt as mto -from modelopt.torch.quantization.plugins.transformers_trainer import ( - convert_sharded_model_to_hf_format, -) - -# Enable ModelOpt checkpointing for HuggingFace models -mto.enable_huggingface_checkpointing() - - -def main(): - parser = argparse.ArgumentParser(description="Convert sharded checkpoint to HuggingFace format") - parser.add_argument( - "--hf_model_path", type=str, required=True, help="Path to the original HuggingFace model" - ) - parser.add_argument( - "--sharded_ckpt_path", - type=str, - required=True, - help="Path to the sharded checkpoint directory", - ) - parser.add_argument( - "--output_path", type=str, default="", help="Output path to save the converted model" - ) - - args = parser.parse_args() - - model = AutoModelForCausalLM.from_pretrained(args.hf_model_path) - if os.path.exists(os.path.join(args.sharded_ckpt_path, "pytorch_model_fsdp_0")): - convert_sharded_model_to_hf_format( - model, args.sharded_ckpt_path, "modelopt_state_train.pth", args.output_path - ) - - -if __name__ == "__main__": - main() diff --git a/examples/llm_qat/launch.sh b/examples/llm_qat/launch.sh index 879db8fd..5d9fc3a7 100755 --- a/examples/llm_qat/launch.sh +++ b/examples/llm_qat/launch.sh @@ -18,96 +18,37 @@ set -eo pipefail export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# Helper function to parse a single argument value +parse_value() { + if [[ "$1" != *=* ]]; then shift; fi + echo "${1#*=}" +} + while [ $# -gt 0 ]; do case "$1" in - --model*) - if [[ "$1" != *=* ]]; then shift; fi - MODEL="${1#*=}" - ;; - --output_dir*) - if [[ "$1" != *=* ]]; then shift; fi - OUTPUT_DIR="${1#*=}" - ;; - --dataset*) - if [[ "$1" != *=* ]]; then shift; fi - DATASET="${1#*=}" - ;; - --train_size*) - if [[ "$1" != *=* ]]; then shift; fi - TRAIN_SIZE="${1#*=}" - ;; - --eval_size*) - if [[ "$1" != *=* ]]; then shift; fi - EVAL_SIZE="${1#*=}" - ;; - --num_epochs*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_EPOCHS="${1#*=}" - ;; - --max_steps*) - if [[ "$1" != *=* ]]; then shift; fi - MAX_STEPS="${1#*=}" - ;; - --save_steps*) - if [[ "$1" != *=* ]]; then shift; fi - SAVE_STEPS="${1#*=}" - ;; - --accum_steps*) - if [[ "$1" != *=* ]]; then shift; fi - ACCUM_STEPS="${1#*=}" - ;; - --lr*) - if [[ "$1" != *=* ]]; then shift; fi - LR="${1#*=}" - ;; - --quant_cfg*) - if [[ "$1" != *=* ]]; then shift; fi - QUANT_CFG="${1#*=}" - ;; - --compress*) - if [[ "$1" != *=* ]]; then shift; fi - COMPRESS="${1#*=}" - ;; - --calib_size*) - if [[ "$1" != *=* ]]; then shift; fi - CALIB_SIZE="${1#*=}" - ;; - --train_bs*) - if [[ "$1" != *=* ]]; then shift; fi - TRAIN_BS="${1#*=}" - ;; - --eval_bs*) - if [[ "$1" != *=* ]]; then shift; fi - EVAL_BS="${1#*=}" - ;; - --do_train*) - if [[ "$1" != *=* ]]; then shift; fi - DO_TRAIN="${1#*=}" - ;; - --lora*) - if [[ "$1" != *=* ]]; then shift; fi - LORA="${1#*=}" - ;; - --teacher_model*) - if [[ "$1" != *=* ]]; then shift; fi - TEACHER_MODEL="${1#*=}" - ;; - --distill*) - if [[ "$1" != *=* ]]; then shift; fi - DISTILL="${1#*=}" - ;; - --fsdp_transformer_layer_cls_to_wrap*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" - ;; - --use_fsdp2*) - if [[ "$1" != *=* ]]; then shift; fi - USE_FSDP2="${1#*=}" - ;; - --max_seq_length*) - if [[ "$1" != *=* ]]; then shift; fi - MAX_SEQ_LENGTH="${1#*=}" - ;; + --model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && 
shift ;; + --dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; + --use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -142,6 +83,7 @@ COMPRESS=${COMPRESS:-"False"} DISTILL=${DISTILL:-"False"} TEACHER_MODEL=${TEACHER_MODEL:-$MODEL} FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} +BACKEND=${BACKEND:-"fsdp1"} if [ -z $QUANT_CFG ]; then QUANT_ARGS="" @@ -154,31 +96,55 @@ if [ ! -z $MAX_STEPS ]; then OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" fi -CONFIG_FILE="fsdp1.yaml" -FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" -GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True" - +# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2 if [[ "${USE_FSDP2,,}" == "true" ]]; then - echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers." - CONFIG_FILE="fsdp2.yaml" - GRADIENT_CHECKPOINTING_ARGS="" + echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead." + BACKEND="fsdp2" +fi + +# if compress is true, set backend to ddp +if [[ "${COMPRESS,,}" == "true" ]]; then + BACKEND="ddp" fi +# Configure backend-specific settings +GRADIENT_CHECKPOINTING_ARGS="" +case "${BACKEND,,}" in + "fsdp1"|"fsdp") + CONFIG_FILE="fsdp1.yaml" + FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" + ;; + "fsdp2") + echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers." + CONFIG_FILE="fsdp2.yaml" + FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" + ;; + "ddp") + CONFIG_FILE="ddp.yaml" + FSDP_ARGS="" + GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True" + ;; + "deepspeed") + CONFIG_FILE="deepspeed.yaml" + FSDP_ARGS="" + GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True" + ;; + *) + echo "Error: Invalid backend '$BACKEND'. 
Supported backends: fsdp1, fsdp2, ddp, deepspeed" + exit 1 + ;; +esac + +# TODO: Remove this after simple distillation is supported DISTILLATION_ARGS="" if [[ "${DISTILL,,}" == "true" ]]; then DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL" - # Distillation does not work with memory efficient loading - FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False" + # Distillation does not work with memory efficient loading for FSDP + if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then + FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False" + fi fi -# real quantization does not work with FSDP, only works with FSDP2 -if [[ "${COMPRESS,,}" == "true" && "${USE_FSDP2,,}" != "true" ]]; then - echo "Compression is not supported with FSDP. Disabling FSDP and using DDP." - FSDP_ARGS="" - CONFIG_FILE="ddp.yaml" -fi - - CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ main.py \ --model_name_or_path $MODEL \ @@ -209,10 +175,9 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ --report_to tensorboard \ --lora $LORA \ --compress $COMPRESS \ - $QUANT_ARGS $OPTIONAL_ARGS $GRADIENT_CHECKPOINTING_ARGS $DISTILLATION_ARGS + $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS " start_time=$(date +%s) sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" -python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR +echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" \ No newline at end of file diff --git a/examples/llm_qat/llama_factory/launch_llamafactory.sh b/examples/llm_qat/llama_factory/launch_llamafactory.sh index 49551802..23e06f26 100644 --- a/examples/llm_qat/llama_factory/launch_llamafactory.sh +++ b/examples/llm_qat/llama_factory/launch_llamafactory.sh @@ -256,4 +256,3 @@ else echo "Modified FSDP args: $FSDP_ARGS" accelerate launch --config_file $ACCELERATE_CONFIG $FSDP_ARGS $SCRIPT_DIR/llama_factory.py $CONFIG_FILE fi -python $SCRIPT_DIR/../convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py index b711b9bb..30f49a6a 100644 --- a/examples/llm_qat/main.py +++ b/examples/llm_qat/main.py @@ -38,6 +38,7 @@ from transformers.trainer_utils import get_last_checkpoint from utils import ( get_lora_config, + get_metrics_with_perplexity, make_supervised_data_module, monkey_patch_training_step_to_fix_memory_leak, ) @@ -45,11 +46,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss -from modelopt.torch.quantization.plugins.transformers_trainer import ( - QADTrainer, - QATTrainer, - get_metrics_with_perplexity, -) +from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer from modelopt.torch.utils import print_rank_0 # Enable automatic save/load of modelopt state huggingface checkpointing @@ -263,16 +260,12 @@ def train(): if training_args.do_train: trainer.train(resume_from_checkpoint=checkpoint) + print_rank_0("Training completed.") if training_args.do_eval: - if not training_args.do_train: - # trainer.evaluate() will not prepare the model properly, especially for FSDP2, - # so we use the ``eval_on_start`` flag to evaluate the model and skip the training. 
- trainer.train(resume_from_checkpoint=checkpoint, eval_only=True) - else: - metrics = trainer.evaluate() - metrics = get_metrics_with_perplexity(metrics) - print_rank_0(f"Evaluation results: \n{metrics}") + metrics = trainer.evaluate() + metrics = get_metrics_with_perplexity(metrics) + print_rank_0(f"Evaluation results: \n{metrics}") if training_args.do_train or quant_args.quant_cfg is not None: print_rank_0("Saving the model...") diff --git a/examples/llm_qat/simple_qat_train.py b/examples/llm_qat/simple_qat_train.py index 36795802..85310278 100644 --- a/examples/llm_qat/simple_qat_train.py +++ b/examples/llm_qat/simple_qat_train.py @@ -74,7 +74,7 @@ def train(model, optimizer, train_dataloader, tokenizer, epochs, output_dir, dev def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="CNN QAT using ModelOpt") + parser = argparse.ArgumentParser(description="QAT Training Script") # Data paths parser.add_argument("--model-path", type=str, required=True, help="Path to the model") parser.add_argument("--train-size", type=int, default=2048, help="Train size") @@ -87,7 +87,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--quant-cfg", type=str, - default=mtq.NVFP4_DEFAULT_CFG, + default="NVFP4_DEFAULT_CFG", choices=mtq.config.choices, help="Quantization configuration", ) @@ -121,7 +121,7 @@ def calibrate(m: nn.Module): m(batch["input_ids"].to(device)) # Quantize the model - model = mtq.quantize(model, args.quant_cfg, calibrate) + model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate) # Initialize optimizer optimizer = AdamW(model.parameters(), lr=args.lr) diff --git a/examples/llm_qat/utils.py b/examples/llm_qat/utils.py index ac4a544a..bb70bdf1 100644 --- a/examples/llm_qat/utils.py +++ b/examples/llm_qat/utils.py @@ -167,3 +167,10 @@ def new_func(original_f_name, trainer, *args, **kwargs): setattr( trainer, f_name, types.MethodType(partial(new_func, "_original_" + f_name), trainer) ) + + +def get_metrics_with_perplexity(metrics): + """Add perplexity to the metrics.""" + if "eval_loss" in metrics: + metrics["perplexity"] = float(torch.exp(torch.tensor(metrics["eval_loss"]))) + return metrics diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 468c36e8..183514f9 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -380,7 +380,7 @@ def apply_mode( return model.init_modellike() if isinstance(model, ModelLikeModule) else model # check if the model is in a wrapper - model = unwrap_model(model, raise_error=True) + model = unwrap_model(model, force_unwrap=True) # standardize mode to ModeConfigList mode_and_config = get_mode_config(mode) @@ -493,10 +493,6 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None: model: Any model. f: Target file location. **kwargs: additional args for ``torch.save()``. - - .. note:: - - If model is a wrapper such as DistributedDataParallel, it will be unwrapped for saving. """ # unwrap model model = unwrap_model(model, warn=True) @@ -545,11 +541,6 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any] Returns: A modified model architecture based on the restored modifications with the unmodified weights as stored in the provided ``model`` argument. - - .. note:: - - Note that wrappers such as DistributedDataParallel are `not` supported during the restore - process. Please wrap the model after the restore process. """ # initialize ModelLikeModule if needed. 
model = model if isinstance(model, nn.Module) else ModelLikeModule(model) @@ -590,13 +581,15 @@ def restore(model: ModelLike, f: str | os.PathLike | BinaryIO, **kwargs) -> nn.M The model with original weights and stored architecture. .. note:: - Note that wrappers such as DistributedDataParallel are `not` supported during the restore process. Please wrap the model after the restore process. """ # initialize ModelLikeModule if needed. model = model if isinstance(model, nn.Module) else ModelLikeModule(model) + # check if the model is in a wrapper; we don't support restoring with wrappers + model = unwrap_model(model, raise_error=True) + # load checkpoint kwargs.setdefault("map_location", "cpu") kwargs.setdefault("weights_only", False) diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index 533b1f05..ac414367 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -1273,7 +1273,8 @@ def config(self, configurable: bool | None = None) -> dict[str, Any]: A dict of ``(parameter_name, choice)`` that specifies an active subnet. """ return { - get_unwrapped_name(name): hp.active for name, hp in self.named_hparams(configurable) + get_unwrapped_name(name, self.model): hp.active + for name, hp in self.named_hparams(configurable) } def select(self, config: dict[str, Any], strict: bool = True) -> None: diff --git a/modelopt/torch/opt/plugins/peft.py b/modelopt/torch/opt/plugins/peft.py index 55855d50..5e5ed0f9 100644 --- a/modelopt/torch/opt/plugins/peft.py +++ b/modelopt/torch/opt/plugins/peft.py @@ -57,14 +57,9 @@ def _new_save_pretrained_peft(self, save_directory, *args, **kwargs): # So we need to save the quantizer state_dict separately # TODO: Move this to modelopt.torch.quantization.plugins.peft - from modelopt.torch.quantization.nn import TensorQuantizer - - # We should not call self/model.state_dict() here. HF Trainer calls model.save_pretrained() only from process 0 - # With FSDP, model.state_dict() will hang if it is not called from all processes - quantizer_state_dict = {} - for name, module in self.named_modules(): - if isinstance(module, TensorQuantizer): - quantizer_state_dict[get_unwrapped_name(name)] = module.state_dict() + from modelopt.torch.quantization.utils import get_quantizer_state_dict + + quantizer_state_dict = get_quantizer_state_dict(self) if len(quantizer_state_dict) > 0: torch.save(quantizer_state_dict, _get_quantizer_state_save_path(save_directory)) return outputs @@ -95,7 +90,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): ) for name, module in self.named_modules(): if isinstance(module, TensorQuantizer): - module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)]) + module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)]) return outputs diff --git a/modelopt/torch/quantization/calib/histogram.py b/modelopt/torch/quantization/calib/histogram.py index d0a13bc2..e27a5471 100644 --- a/modelopt/torch/quantization/calib/histogram.py +++ b/modelopt/torch/quantization/calib/histogram.py @@ -157,8 +157,7 @@ def compute_amax( """ if dist.is_initialized(): warnings.warn( - "This method does not perform any synchronization across DistributedDataParallel" - " (DDP) https://pytorch.org/docs/stable/notes/ddp.html modules. The recommended" + "This method does not perform any synchronization across distributed processes. The recommended" " method is to use the same calibration dataset across all distributed data" " parallel groups so that `amax` is the same for all DDP modules."
) diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 6a457f17..7c2f84b8 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -123,12 +123,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, TensorQuantizer): - name = get_unwrapped_name(name) + name = get_unwrapped_name(name, model) module.set_from_modelopt_state(quantizer_state_dict[name]) for name, module in model.named_modules(): if isinstance(module, QuantModule): - name = get_unwrapped_name(name) + name = get_unwrapped_name(name, model) module.modelopt_post_restore(name) return model @@ -166,7 +166,7 @@ def update_quantize_metadata( def quantizer_state(model: nn.Module) -> dict[str, Any]: """Returns the quantizer state dict describing the quantizer states in the model.""" return { - get_unwrapped_name(n): m.get_modelopt_state() + get_unwrapped_name(n, model): m.get_modelopt_state() for n, m in model.named_modules() if isinstance(m, (TensorQuantizer, SequentialQuantizer)) } diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 93df3651..12aaee3f 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -140,8 +140,10 @@ class QuantLinearConvBase(QuantInputBase): def quantize_weight(self): """Context in which `self.weight` is quantized.""" self._enable_weight_quantization = True - yield - self._enable_weight_quantization = False + try: + yield + finally: + self._enable_weight_quantization = False @staticmethod def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor: diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py index 017d8160..e6a0a2b7 100644 --- a/modelopt/torch/quantization/plugins/transformers_trainer.py +++ b/modelopt/torch/quantization/plugins/transformers_trainer.py @@ -17,12 +17,11 @@ import gc import os -from contextlib import suppress +import types from dataclasses import dataclass, field import torch -import torch.distributed.checkpoint as dist_cp -from accelerate.utils import save_fsdp_model +from tqdm import tqdm import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq @@ -32,10 +31,13 @@ from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.opt.plugins import ModelOptHFTrainer from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.quantization.nn import TensorQuantizer from modelopt.torch.quantization.utils import ( calibrate_with_adapters, disable_lora_quantizers_in_config, + get_quantizer_state_dict, is_quantized, + set_quantizer_state_dict, ) from modelopt.torch.utils import print_rank_0 @@ -98,10 +100,6 @@ class QuantizationArgumentsWithConfig(QuantizationArguments): ) -class EvalOnlyError(Exception): - """Exception to raise when evaluation is only needed.""" - - def check_awq_smoothquant(quant_cfg): # TODO: Remove this once deepspeed for AWQ and SmoothQuant is added """Get the quantization type from the configuration.""" @@ -116,54 +114,6 @@ def check_awq_smoothquant(quant_cfg): return is_awq_smoothquant -def get_metrics_with_perplexity(metrics): - """Add perplexity to the metrics.""" - metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), 
**metrics} - return metrics - - -def convert_sharded_model_to_hf_format( - model, model_path, modelopt_state_name="modelopt_state.pth", output_path=None -): - """Convert a sharded model to HF format. - - Args: - model: The original HF model. - model_path: The path to the sharded model with pytorch_model_fsdp_0 directory. - modelopt_state_name: The name of the modelopt state file. If not provided, the default name - "modelopt_state.pth" will be used. - output_path: The path to save the converted model. If not provided, the model will be saved - to the same directory as the sharded model. - """ - if output_path is None: - output_path = model_path - os.makedirs(output_path, exist_ok=True) - state_dict = {"model": model.state_dict()} - sharded_model_path = os.path.join(model_path, "pytorch_model_fsdp_0") - modelopt_state_path = os.path.join(model_path, modelopt_state_name) - if not os.path.exists(sharded_model_path): - print_rank_0(f"Sharded model path does not exist: {sharded_model_path}") - return model - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader=dist_cp.FileSystemReader(sharded_model_path), - no_dist=True, - ) - model.load_state_dict(state_dict["model"]) - restore_modelopt_state_with_weights(model, modelopt_state_path) - model.save_pretrained(output_path) - return model - - -def restore_modelopt_state_with_weights(model, modelopt_state_path): - """Restore the modelopt weights for fsdp2 models.""" - _modelopt_state = torch.load(modelopt_state_path, weights_only=False) - modelopt_weights = _modelopt_state.pop("modelopt_state_weights", None) - restore_from_modelopt_state(model, _modelopt_state) - if modelopt_weights is not None: - model.load_state_dict(modelopt_weights, strict=False) - - class QATTrainer(ModelOptHFTrainer): """A drop-in replacement of HuggingFace's Trainer for quantization aware training with ModelOpt. @@ -190,15 +140,6 @@ def __init__( else quant_args.quant_cfg ) self.quant_cfg = quant_cfg - self._eval_without_training = False - - self._is_fsdp2 = self.is_fsdp_enabled and ( - getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2 - ) - self.fsdp_state_dict_type = ( - str(self.accelerator.state.fsdp_plugin.state_dict_type) if self.is_fsdp_enabled else "" - ) - self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth") # Add lora adapter before quantizing the model if getattr(self.args, "lora_config", None) is not None and not hasattr( @@ -219,144 +160,159 @@ def __init__( f"QAT DeepSpeed does not currently support AWQ or SmoothQuant: {self.quant_cfg}" ) - # FSDP1 requires pre-restoring the quantized model if the modelopt state exists. 
- if os.path.exists(self._modelopt_state_path) and not self._is_fsdp2: - self._quantize_model() - - def _get_quantize_forward_loop(self, data_loader, use_eval_loop=True): - def forward_loop(_model): - print_rank_0("Calibrating...") - if use_eval_loop: - return self.evaluation_loop( - data_loader, - description="Calibration", - prediction_loss_only=True, - ignore_keys=None, - metric_key_prefix="calibration", - ) - else: - for batch in data_loader: - batch = self._prepare_inputs(batch) - _model(**batch) - print_rank_0("Calibration done!") + self._patch_accelerate_for_fsdp2_fix() - return forward_loop + self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth") + if os.path.exists(self._modelopt_state_path): + self._restore_modelopt_state_with_weights() + elif is_quantized(self.model): + self._save_modelopt_state_with_weights() - def _save_modelopt_state_with_weights(self, model, modelopt_state_path, save_weights=False): + def _save_modelopt_state_with_weights(self): """Save the modelopt weights for fsdp2 models.""" - modelopt_state = mto.modelopt_state(model) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + modelopt_state = mto.modelopt_state(self.model) + # TODO: remove this from ModelOpt HF Trainer flows modelopt_state["modelopt_state_dict"] = [ state for state in modelopt_state["modelopt_state_dict"] if "kd_loss" not in state and "export_student" not in state ] - if save_weights: - state_dict = model.state_dict() - modelopt_weights = {} - for k, v in state_dict.items(): - if "_quantizer" in k: - modelopt_weights[k] = v.cpu() - modelopt_state["modelopt_state_weights"] = modelopt_weights + modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model) if self.args.should_save: - torch.save(modelopt_state, modelopt_state_path) + torch.save(modelopt_state, self._modelopt_state_path) - if torch.distributed.is_initialized(): - torch.distributed.barrier() + print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}") - def _quantize_model(self, use_eval_loop=True): + def _restore_modelopt_state_with_weights(self): + modelopt_state = torch.load(self._modelopt_state_path, weights_only=False) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + restore_from_modelopt_state(self.model, modelopt_state) + if modelopt_weights is not None: + set_quantizer_state_dict(self.model, modelopt_weights) + print_rank_0("Restored modelopt state with weights.") + + def _quantize_model(self): """Quantize the model. 
Restore the quantization state if it exists.""" - model = self.accelerator.unwrap_model(self.model) - if os.path.exists(self._modelopt_state_path): - print_rank_0(f"Restoring modelopt state from {self._modelopt_state_path}...") - restore_modelopt_state_with_weights(self.model, self._modelopt_state_path) - print_rank_0("Restored model from modelopt state.") - else: - dataset = torch.utils.data.Subset( - self.eval_dataset, - list(range(min(self.quant_args.calib_size, len(self.eval_dataset)))), # type: ignore [union-attr] - ) - data_loader = self.get_eval_dataloader(dataset) - forward_loop = self._get_quantize_forward_loop(data_loader, use_eval_loop) - with calibrate_with_adapters(model, self.args): - print_rank_0("Quantizing the model...") - mtq.quantize(model, self.quant_cfg, forward_loop) # type: ignore [arg-type] - print_rank_0("Quantization done!") - - if getattr(self.quant_args, "compress", False): - print_rank_0("Compressing model after calibration") - mtq.compress(model) - - # Force garbage collection to free up memory - gc.collect() - - print_rank_0(f"Saving modelopt state to {self._modelopt_state_path}") - self._save_modelopt_state_with_weights( - model, self._modelopt_state_path, save_weights=True - ) - torch.cuda.empty_cache() - if use_eval_loop: - self.callback_handler.on_evaluate(self, self.state, self.control, metrics=None) + dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset + assert dataset is not None, "Calibration requires either eval or train dataset." + num_samples = min(self.quant_args.calib_size, len(dataset)) # type: ignore [union-attr] + dataset = torch.utils.data.Subset(dataset, list(range(num_samples))) + data_loader = self.get_eval_dataloader(dataset) + + def forward_loop(model): + for batch in tqdm(data_loader, desc="Calibrating", disable=not self.args.should_save): + batch = self._prepare_inputs(batch) + # Important: We should forward pass using the unwrapped model + # mtq.quantize will unwrap the model & pass to the forward_loop + self.model(**batch) + + # TODO: Remove calibrate_with_adapters - this should not be needed + with calibrate_with_adapters(self.model, self.args): + print_rank_0("Quantizing the model...") + mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type] + + if getattr(self.quant_args, "compress", False): + print_rank_0("Compressing model after calibration") + mtq.compress(self.model) + + # Force garbage collection to free up memory + gc.collect() + + self._save_modelopt_state_with_weights() + torch.cuda.empty_cache() if self.accelerator.is_main_process: - mtq.print_quant_summary(model) + mtq.print_quant_summary(self.model) - def _evaluate(self, *args, **kwargs): - """Quantize the model before evaluation. + def training_step(self, *args, **kwargs): + """Training step.""" + if self.quant_cfg is not None and not is_quantized(self.model): + self._quantize_model() + return super().training_step(*args, **kwargs) - Note that we do not force to run the evaluation if the `eval_on_start` is False. 
- """ + def prediction_step(self, *args, **kwargs): + """Prediction step.""" if self.quant_cfg is not None and not is_quantized(self.model): self._quantize_model() - metrics = None - if self._original_evaluate_on_start: - metrics = super()._evaluate(*args, **kwargs) - else: - metrics = super()._evaluate(*args, **kwargs) - # used for eval without training - if self._eval_without_training: - metrics = get_metrics_with_perplexity(metrics) - print_rank_0(f"Evaluation results: \n{metrics}") - raise EvalOnlyError() - return metrics - - def train(self, *args, eval_only=False, **kwargs): - """Train the model with quantization.""" - self._eval_without_training = eval_only - self._original_evaluate_on_start = ( - self.args.eval_on_start if not self._eval_without_training else True + return super().prediction_step(*args, **kwargs) + + def evaluate(self, *args, **kwargs): + """Evaluate the model.""" + if self.args.do_eval and not self.args.do_train and self.accelerator.is_fsdp2: + # [Not related to ModelOpt] HF does not support eval only for FSDP2. + # This is a hack to make it work + dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0) + self.model, _ = self.accelerator.prepare(self.model, dummy_optimizer) + return super().evaluate(*args, **kwargs) + + def train(self, *args, **kwargs): + """Train the model.""" + outputs = super().train(*args, **kwargs) + print_rank_0( + "Training completed. Please save the final model using `Trainer.save_model()` " + "to preserve ModelOpt states." ) - if getattr(self.quant_args, "quant_cfg", None) is not None and not is_quantized(self.model): - self.args.eval_on_start = True - train_result = None - with suppress(EvalOnlyError): - train_result = super().train(*args, **kwargs) - self.args.eval_on_start = self._original_evaluate_on_start - return train_result + return outputs - def save_model( - self, output_dir: str | None = None, _internal_call: bool = False, *args, **kwargs - ): + def save_model(self, *args, **kwargs): """Save the quantized model.""" - dict_type = ( - str(self.accelerator.state.fsdp_plugin.state_dict_type) if self.is_fsdp_enabled else "" - ) - if not _internal_call and self.is_fsdp_enabled and "SHARDED_STATE_DICT" in dict_type: - # The default save_model in Trainer doesn't save checkpoint with SHARDED_STATE_DICT + FSDP. - # We save the model manually at the end of the training in order to convert the last - # checkpoint from distcp to HF compatible format. - if output_dir is None: - output_dir = self.args.output_dir - save_fsdp_model( - self.accelerator.state.fsdp_plugin, - self.accelerator, - self.model, - output_dir, - ) - self.processing_class.save_pretrained(output_dir) - self.model.config.save_pretrained(output_dir) + if ( + (not self.is_in_train) + and self.is_fsdp_enabled + and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT" + ): + print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.") + original_type = self.accelerator.state.fsdp_plugin.state_dict_type + self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + outputs = super().save_model(*args, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)): + print_rank_0( + "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the " + "model. 
See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing" + ) + self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type) else: - super().save_model(output_dir, _internal_call, *args, **kwargs) + outputs = super().save_model(*args, **kwargs) + return outputs + + def _patch_accelerate_for_fsdp2_fix(self): + """Fixes for accelerate prepare. + + Accelerate fsdp2 prepare assumes that all parameters and buffers are sharded. This assumption + is causing issues with quantized models since quantization modules add buffers which are not sharded. + This patch hides the buffers added by quantization modules from the original accelerate prepare. + """ + + def _modelopt_prepare(self, *args, **kwargs): + if not self.is_fsdp2: + return self._original_prepare(*args, **kwargs) + + model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None) + if model is None: + return self._original_prepare(*args, **kwargs) + + tq_og_non_persist_buffers = {} + for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): + tq.to_empty(device=self.device) + tq_og_non_persist_buffers[tq] = tq._non_persistent_buffers_set.copy() + tq._non_persistent_buffers_set.update(tq._buffers.keys()) + + outputs = self._original_prepare(*args, **kwargs) + + for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): + tq._non_persistent_buffers_set.clear() + tq._non_persistent_buffers_set.update(tq_og_non_persist_buffers[tq]) + + return outputs + + self.accelerator._original_prepare = self.accelerator.prepare + self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator) class QADTrainer(QATTrainer, KDTrainer): @@ -385,7 +341,7 @@ def __init__( # And memory efficient loading doesn't work. self.model.cuda() if self.quant_cfg is not None and not is_quantized(self.model): - self._quantize_model(use_eval_loop=False) + self._quantize_model() if getattr(self.args, "lora_config", None) is not None: self.model.add_adapter(self.args.lora_config, adapter_name="adapter") print_rank_0("Lora adapter added.") @@ -416,7 +372,9 @@ def save_model( output_dir: The directory to save the model and ModelOpt states. export_student: Whether to export the student model. """ - if "SHARDED_STATE_DICT" in self.fsdp_state_dict_type and self._is_fsdp2: + if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str( + self.accelerator.state.fsdp_plugin.state_dict_type + ): if export_student: model = self.accelerator.unwrap_model(self.model) model = model.export() diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 09faf58e..6167daf2 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -25,7 +25,7 @@ from torch.distributed.fsdp import FSDPModule from torch.distributed.tensor import Replicate -from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils import get_unwrapped_name, print_rank_0 __all__ = [ "EXPORT_MODE", @@ -441,3 +441,26 @@ def enable_weight_access_and_writeback(module, root_model): with context: yield + + +def get_quantizer_state_dict(model: nn.Module): + """Get the state dict of the quantizers in the model.""" + # We should not call model.state_dict() here.
+ # With FSDP, model.state_dict() will hang if it is not called from all processes + from .nn import TensorQuantizer + + quantizer_state_dict = {} + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer): + quantizer_state_dict[get_unwrapped_name(name, model)] = module.state_dict() + return quantizer_state_dict + + +def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict): + """Set the state dict of the quantizers in the model.""" + from .nn import TensorQuantizer + + for name, module in model.named_modules(): + key = get_unwrapped_name(name, model) + if isinstance(module, TensorQuantizer) and key in quantizer_state_dict: + module.load_state_dict(quantizer_state_dict[key]) diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index 93dbffbd..1940295c 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -70,12 +70,19 @@ def _convert_to_wrapped_module_name(name: str) -> str: ] # NOTE: can be extended dynamically in appropriate plugin files if available (e.g. megatron core) -SUPPORTED_WRAPPERS = { +SUPPORTED_WRAPPERS: dict[type[nn.Module], str] = { nn.parallel.DataParallel: "module", # indicating attribute key to unwrap nn.parallel.DistributedDataParallel: "module", + torch.distributed.fsdp.FullyShardedDataParallel: "module", } -UNSUPPORTED_WRAPPERS = {torch.distributed.fsdp.FullyShardedDataParallel: "module"} +try: + from deepspeed.runtime.engine import DeepSpeedEngine +except: # noqa: E722 + DeepSpeedEngine = None + +if DeepSpeedEngine is not None: + SUPPORTED_WRAPPERS[DeepSpeedEngine] = "module" ModelLike = Union[nn.Module, type[nn.Module], tuple, Callable] # noqa: UP007 ConstructorLike = Callable | tuple @@ -430,11 +437,8 @@ def unwrap_model( """Unwrap a model that is wrapped by supported wrapper module or return original model.""" if force_unwrap: try: - if type(model) in SUPPORTED_WRAPPERS or type(model) in UNSUPPORTED_WRAPPERS: - return getattr( - model, - SUPPORTED_WRAPPERS.get(type(model), UNSUPPORTED_WRAPPERS.get(type(model))), # type: ignore [arg-type] - ) + if type(model) in SUPPORTED_WRAPPERS: + return getattr(model, SUPPORTED_WRAPPERS[type(model)]) except AttributeError: raise ValueError( f"Model of type {type(model)} could not be forcefully unwrapped! Please manually" @@ -447,11 +451,6 @@ def unwrap_model( elif warn: warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...") return getattr(model, SUPPORTED_WRAPPERS[type(model)]) - elif type(model) in UNSUPPORTED_WRAPPERS: - raise ValueError( - f"Automatically unwrapping {type(model)} is not supported at this time! Please manually" - " unwrap the model before passing it in." 
- ) return model @@ -597,14 +596,17 @@ def delete_grad_hook(*_unused): return accum_grad, handle -def get_unwrapped_name(name: str) -> str: +def get_unwrapped_name(name: str, model: nn.Module | None = None) -> str: """Get the cleaned module name (i.e, the name before wrapping with sharded modules).""" # The distributed sharded wrappers such as FSDP wraps the child modules as well # So unwrapping just the parent module is not enough # Instead of unwrapping the child modules and changing the model, we can just clean the name # _convert_to_wrapped_module_name is a Pytorch utility function to do this + if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or ( + DeepSpeedEngine is not None and isinstance(model, DeepSpeedEngine) + ): + name = name.removeprefix("module.") - # TODO: Implement support for DeepSpeed Zero wrapped modules name = _convert_to_wrapped_module_name(name) name = name.removesuffix(".") return name diff --git a/tests/_test_utils/examples/run_command.py b/tests/_test_utils/examples/run_command.py index 8e6bbb64..cf31ce38 100644 --- a/tests/_test_utils/examples/run_command.py +++ b/tests/_test_utils/examples/run_command.py @@ -32,9 +32,15 @@ def _extend_cmd_parts(cmd_parts: list[str], **kwargs): return cmd_parts -def run_example_command(cmd_parts: list[str], example_path: str): +def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False): print(f"[{example_path}] Running command: {cmd_parts}") - subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, check=True) + env = os.environ.copy() + + if setup_free_port: + free_port = get_free_port() + env["MASTER_PORT"] = str(free_port) + + subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True) def run_command_in_background(cmd_parts, example_path, stdout=None, stderr=None, text=True): diff --git a/tests/examples/llm_qat/test_llm_qat.py b/tests/examples/llm_qat/test_llm_qat.py index 34982438..c9fef976 100644 --- a/tests/examples/llm_qat/test_llm_qat.py +++ b/tests/examples/llm_qat/test_llm_qat.py @@ -33,10 +33,16 @@ def _run_command(extra_cmd_args: list[str]): *extra_cmd_args, ], "llm_qat", + setup_free_port=True, ) - -def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path): +@pytest.mark.parametrize("backend", [ + "fsdp1", + "fsdp2", + "deepspeed", + "ddp", +]) +def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path, backend): ptq_output_dir = tmp_path / "ptq" qat_output_dir = tmp_path / "qat" @@ -47,6 +53,7 @@ def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path): "--do_train", "False", "--quant_cfg", "INT4_WEIGHT_INT8_ACTIVATIONS", "--output_dir", ptq_output_dir, + "--backend", backend, ] ) @@ -56,9 +63,27 @@ def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path): "--model", ptq_output_dir, "--do_train", "True", "--output_dir", qat_output_dir, + "--backend", backend, ] ) +@pytest.mark.parametrize("backend", [ + "fsdp1", + "fsdp2", + "deepspeed", + "ddp", +]) +def test_llama_qat_int4w_int8a_direct_qat(tiny_llama_path, tmp_path, backend): + # Run PTQ + QAT together + _run_command( + [ + "--model", tiny_llama_path, + "--do_train", "True", + "--quant_cfg", "INT4_WEIGHT_INT8_ACTIVATIONS", + "--output_dir", tmp_path, + "--backend", backend, + ] + ) def test_llama_lora_qat_nvfp4(tiny_llama_path, tmp_path): _run_command(