diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9e824d775..82a3381e1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -43,8 +43,9 @@ examples/llm_eval @NVIDIA/modelopt-examples-llm_ptq-codeowners examples/llm_ptq @NVIDIA/modelopt-examples-llm_ptq-codeowners examples/llm_qat @NVIDIA/modelopt-examples-llm_qat-codeowners examples/llm_sparsity @NVIDIA/modelopt-torch-sparsity-codeowners +examples/megatron-lm @NVIDIA/modelopt-examples-megatron-codeowners examples/model_hub @NVIDIA/modelopt-examples-model_hub-codeowners -examples/nemo_run @NVIDIA/modelopt-examples-nemo_run-codeowners +examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 35e0b34cb..23e571620 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,6 +26,7 @@ Model Optimizer Changelog (Linux) - Add support for ``mamba_num_heads``, ``mamba_head_dim``, ``hidden_size`` and ``num_layers`` pruning for Megatron Core Mamba or Hybrid Transformer Mamba models in ``mcore_minitron`` (previously ``mcore_gpt_minitron``) mode. - Add example for QAT/QAD training with `LLaMA Factory `_. See ``examples/llm_qat/llama_factory`` for more details. - Upgrade TensorRT-LLM dependency to 1.0.0rc6. +- Add unified HuggingFace model export support for quantized NVFP4 GPT-OSS models. 0.33 (2025-07-14) ^^^^^^^^^^^^^^^^^ diff --git a/examples/chained_optimizations/bert_prune_distill_quantize.py b/examples/chained_optimizations/bert_prune_distill_quantize.py index 5d9eb4f69..3f8604546 100644 --- a/examples/chained_optimizations/bert_prune_distill_quantize.py +++ b/examples/chained_optimizations/bert_prune_distill_quantize.py @@ -1107,6 +1107,7 @@ def main(input_args: list[str] | None = None) -> None: format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, + force=True, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py index 7743d7218..4db12e9c2 100644 --- a/examples/diffusers/quantization/diffusion_trt.py +++ b/examples/diffusers/quantization/diffusion_trt.py @@ -22,7 +22,7 @@ remove_nesting, update_dynamic_axes, ) -from quantize import create_pipeline +from quantize import ModelType, PipelineManager import modelopt.torch.opt as mto from modelopt.torch._deploy._runtime import RuntimeRegistry @@ -31,6 +31,20 @@ from modelopt.torch._deploy.device_model import DeviceModel from modelopt.torch._deploy.utils import get_onnx_bytes_and_metadata +MODEL_ID = { + "sdxl-1.0": ModelType.SDXL_BASE, + "sdxl-turbo": ModelType.SDXL_TURBO, + "sd3-medium": ModelType.SD3_MEDIUM, + "flux-dev": ModelType.FLUX_DEV, + "flux-schnell": ModelType.FLUX_SCHNELL, +} + +dtype_map = { + "Half": torch.float16, + "BFloat16": torch.bfloat16, + "Float": torch.float32, +} + def generate_image(pipe, prompt, image_name): seed = 42 @@ -91,7 +105,7 @@ def main(): image_name = args.save_image_as if args.save_image_as else f"{args.model}.png" - pipe = create_pipeline(args.model, args.model_dtype, args.override_model_path) + pipe = PipelineManager.create_pipeline_from(MODEL_ID[args.model], dtype_map[args.model_dtype]) # Save the backbone of the pipeline and move it to the GPU add_embedding = None diff --git 
a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 7f64222cb..31a688172 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -306,6 +306,37 @@ def __init__(self, config: ModelConfig, logger: logging.Logger): self.pipe: DiffusionPipeline | None = None self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling + @staticmethod + def create_pipeline_from( + model_type: ModelType, torch_dtype: torch.dtype = torch.bfloat16 + ) -> DiffusionPipeline: + """ + Create and return an appropriate pipeline based on configuration. + + Returns: + Configured diffusion pipeline + + Raises: + ValueError: If model type is unsupported + """ + try: + model_id = MODEL_REGISTRY[model_type] + if model_type == ModelType.SD3_MEDIUM: + pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: + pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) + else: + # SDXL models + pipe = DiffusionPipeline.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + ) + pipe.set_progress_bar_config(disable=True) + return pipe + except Exception as e: + raise e + def create_pipeline(self) -> DiffusionPipeline: """ Create and return an appropriate pipeline based on configuration. diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md index 486130157..80484be2e 100644 --- a/examples/gpt-oss/README.md +++ b/examples/gpt-oss/README.md @@ -49,6 +49,8 @@ model = mtq.quantize(model, config, forward_loop) train(model, train_loader, optimizer, scheduler, ...) ``` +For an end to end example showcasing the above workflow, checkout [qat-finetune-transformers.ipynb](/examples/gpt-oss/qat-finetune-transformers.ipynb). + If you are training Huggingface models with trainer classes from Huggingface such as [SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) performing QAT is even easier - simply replace the trainer with its equivalent, `QATSFTTrainer` from ModelOpt and specify additional quantization arguments to it. `QATSFTTrainer` will perform the necessary quantization steps in the backend and train the model just like the original `SFTTrainer`. A real end-to-end example for this is in `sft.py` in this folder. 
To perform QAT with full parameter SFT on GPT-OSS 20B model, run: diff --git a/examples/gpt-oss/convert_oai_mxfp4_weight_only.py b/examples/gpt-oss/convert_oai_mxfp4_weight_only.py index 6c3fce16b..bebb91486 100644 --- a/examples/gpt-oss/convert_oai_mxfp4_weight_only.py +++ b/examples/gpt-oss/convert_oai_mxfp4_weight_only.py @@ -23,11 +23,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config from utils import get_original_huggingface_quant_method -import modelopt.torch.opt as mto from modelopt.torch.quantization.qtensor import MXFP4QTensor -mto.enable_huggingface_checkpointing() - def _to_oai_mxfp4_weight_only(model, block_size=32): new_state_dict = {} @@ -36,15 +33,20 @@ def _to_oai_mxfp4_weight_only(model, block_size=32): # Only convert experts weights, skip bias and other modules if "experts" in name and "bias" not in name: param = param.transpose(-1, -2).contiguous() - quantized, scales = MXFP4QTensor.quantize(param, block_size=block_size) - - shape = quantized._quantized_data.shape + quantized_tensors = [] + scales_tensors = [] + for expert in param: + quantized, scales = MXFP4QTensor.quantize(expert, block_size=block_size) + quantized_tensors.append(quantized._quantized_data) + scales_tensors.append(scales) + quantized = torch.stack(quantized_tensors) + scales = torch.stack(scales_tensors) + + shape = quantized.shape # Add converted weights and scales to state_dict new_state_dict.update( { - f"{name}_blocks": quantized._quantized_data.view( - shape[0], shape[1], -1, block_size // 2 - ).cpu(), + f"{name}_blocks": quantized.view(shape[0], shape[1], -1, block_size // 2).cpu(), f"{name}_scales": scales.view(shape[0], shape[1], -1).cpu(), } ) @@ -134,6 +136,8 @@ def create_parser(): if args.lora_path: model = PeftModel.from_pretrained(model, args.lora_path) model = model.merge_and_unload() # Merge LoRA-QAT adapter weights to base model + torch.cuda.empty_cache() + gc.collect() # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) diff --git a/examples/gpt-oss/qat-finetune-transformers.ipynb b/examples/gpt-oss/qat-finetune-transformers.ipynb index da570370e..695ed39f6 100644 --- a/examples/gpt-oss/qat-finetune-transformers.ipynb +++ b/examples/gpt-oss/qat-finetune-transformers.ipynb @@ -5,13 +5,348 @@ "id": "13d9d5ea", "metadata": {}, "source": [ - "# To Do: QAT simple workflow without the Trainer" + "# Quantization-Aware Fine-Tuning for GPT-OSS\n", + "\n", + "This notebook demonstrates a complete workflow for fine-tuning language models with Quantization-Aware Training (QAT) using modelopt and SFTTrainer for gpt-oss models.\n", + "\n", + "## Overview\n", + "\n", + "The workflow includes:\n", + "\n", + "• Model and tokenizer loading\n", + "\n", + "• Dataset preparation\n", + "\n", + "• Training configuration setup\n", + "\n", + "• Model quantization\n", + "\n", + "• Quantization aware training\n", + "\n", + "• Model saving and checkpointing" + ] + }, + { + "cell_type": "markdown", + "id": "30a838e9", + "metadata": {}, + "source": [ + "**Setup Environment**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "533b4e64", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade transformers trl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a4a12d9", + "metadata": {}, + "outputs": [], + "source": [ + "import modelopt.torch.opt as mto\n", + "\n", + "# Enable automatic save/load of modelopt state huggingface checkpointing\n", + "# modelopt state will be saved 
automatically to \"modelopt_state.pth\"\n", + "mto.enable_huggingface_checkpointing()" + ] + }, + { + "cell_type": "markdown", + "id": "01d85cd4", + "metadata": {}, + "source": [ + "**Model Configuration**\n", + "\n", + "Configure the model parameters including the model path, attention implementation, and data type. Set up the model configuration and prepare the model loading arguments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a38a1233", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoConfig, Mxfp4Config\n", + "from trl import ModelConfig\n", + "\n", + "model_args = ModelConfig(\n", + " model_name_or_path=\"openai/gpt-oss-20b\",\n", + " attn_implementation=\"eager\",\n", + " torch_dtype=\"bfloat16\",\n", + ")\n", + "model_kwargs = {\n", + " \"revision\": model_args.model_revision,\n", + " \"trust_remote_code\": model_args.trust_remote_code,\n", + " \"attn_implementation\": model_args.attn_implementation,\n", + " \"torch_dtype\": model_args.torch_dtype,\n", + " \"use_cache\": False,\n", + " \"device_map\": \"auto\",\n", + "}\n", + "\n", + "# Dequantize if the model is in MXFP4 format\n", + "config = AutoConfig.from_pretrained(model_args.model_name_or_path)\n", + "if (\n", + " getattr(config, \"quantization_config\", {})\n", + " and config.quantization_config.get(\"quant_method\", None) == \"mxfp4\"\n", + "):\n", + " model_kwargs[\"quantization_config\"] = Mxfp4Config(dequantize=True)" + ] + }, + { + "cell_type": "markdown", + "id": "c773c8f3", + "metadata": {}, + "source": [ + "**Load the Model and Tokenizer**\n", + "\n", + "Load the pre-trained model and tokenizer with the specified configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "747f068b", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " model_args.model_name_or_path,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "43d143e0", + "metadata": {}, + "source": [ + "**Dataset Configuration**\n", + "\n", + "Set up the dataset parameters for training and evaluation. This includes specifying the dataset name, train/test splits, and test size ratio." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3577c9c", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import ScriptArguments\n", + "\n", + "script_args = ScriptArguments(\n", + " dataset_name=\"HuggingFaceH4/Multilingual-Thinking\",\n", + " dataset_train_split=\"train\",\n", + " dataset_test_split=\"test\",\n", + ")\n", + "test_size = 0.1" + ] + }, + { + "cell_type": "markdown", + "id": "dataset-loading-markdown", + "metadata": {}, + "source": [ + "**Load and Prepare Dataset**\n", + "\n", + "Load the dataset and split it into training and evaluation sets. The dataset is split with the specified test size ratio and random seed for reproducibility." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da17e0da", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(script_args.dataset_name)\n", + "# split the dataset into train and test\n", + "dataset = dataset[script_args.dataset_train_split].train_test_split(test_size=test_size, seed=42)\n", + "train_dataset = dataset[script_args.dataset_train_split]\n", + "eval_dataset = dataset[script_args.dataset_test_split]" + ] + }, + { + "cell_type": "markdown", + "id": "training-config-markdown", + "metadata": {}, + "source": [ + "**Training Configuration**\n", + "\n", + "Configure the training parameters including epochs, batch sizes, learning rate, gradient accumulation, and evaluation strategy. This sets up the SFT configuration for supervised fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "260e1104", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " output_dir=\"gpt-oss-20b-multilingual-reasoner\",\n", + " num_train_epochs=0.1,\n", + " learning_rate=2e-5,\n", + " per_device_train_batch_size=1,\n", + " per_device_eval_batch_size=1,\n", + " gradient_accumulation_steps=2,\n", + " max_length=4096,\n", + " warmup_ratio=0.03,\n", + " eval_strategy=\"steps\",\n", + " eval_on_start=True,\n", + " logging_steps=10,\n", + " save_steps=50,\n", + " eval_steps=10,\n", + " save_total_limit=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "trainer-init-markdown", + "metadata": {}, + "source": [ + "**Initialize Trainer**\n", + "\n", + "Set up the SFT trainer with the model, dataset, and training configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7948da7", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=dataset[script_args.dataset_train_split],\n", + " eval_dataset=dataset[script_args.dataset_test_split],\n", + " processing_class=tokenizer,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "quantization-setup-markdown", + "metadata": {}, + "source": [ + "**Quantization aware Training**\n", + "\n", + "Configure the quantization parameters and prepare the calibration dataset. This step sets up the quantization configuration, creates a calibration subset from the evaluation dataset, and defines a forward loop function for model calibration. The calibration process helps determine optimal quantization scales for the model weights and activations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fef71bc7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import modelopt.torch.quantization as mtq\n", + "\n", + "# MXFP4_MLP_WEIGHT_ONLY_CFG doesn't need calibration, but other quantization configurations may require it.\n", + "quantization_config = mtq.MXFP4_MLP_WEIGHT_ONLY_CFG\n", + "calib_size = 128\n", + "\n", + "dataset = torch.utils.data.Subset(\n", + " trainer.eval_dataset, list(range(min(len(trainer.eval_dataset), calib_size)))\n", + ")\n", + "data_loader = trainer.get_eval_dataloader(dataset)\n", + "\n", + "\n", + "def forward_loop(model):\n", + " for data in data_loader:\n", + " model(**data)" + ] + }, + { + "cell_type": "markdown", + "id": "quantization-execution-markdown", + "metadata": {}, + "source": [ + "Apply quantization to the model using the prepared configuration and calibration data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8270fe84", + "metadata": {}, + "outputs": [], + "source": [ + "mtq.quantize(model, quantization_config, forward_loop)" + ] + }, + { + "cell_type": "markdown", + "id": "training-execution-markdown", + "metadata": {}, + "source": [ + "Start the quantization-aware training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73dc36bc", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "model-saving-markdown", + "metadata": {}, + "source": [ + "**Model Saving and Checkpointing**\n", + "\n", + "Save the trained and quantized model with HuggingFace checkpointing enabled to store the modelopt state automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "874863aa", + "metadata": {}, + "outputs": [], + "source": [ + "model.save_pretrained(training_args.output_dir)" ] } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "name": "python", + "version": "3.10.16" } }, "nbformat": 4, diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7026bb78d..aa26457af 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -67,6 +67,7 @@ "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG, "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, + "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, } KV_QUANT_CFG_CHOICES = { @@ -88,7 +89,16 @@ def auto_quantize( if args.export_fmt == "hf": assert all( qformat - in ["fp8", "int4_awq", "nvfp4", "nvfp4_awq", "w4a8_awq", "fp8_pb_wo", "w4a8_mxfp4_fp8"] + in [ + "fp8", + "int4_awq", + "nvfp4", + "nvfp4_awq", + "w4a8_awq", + "fp8_pb_wo", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + ] for qformat in qformat_list ), ( "One or more quantization formats provided are not supported for unified checkpoint export" @@ -223,6 +233,7 @@ def main(args): "w4a8_awq", "fp8_pb_wo", "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Quantization format {args.qformat} not supported for HF export path" @@ -288,6 +299,7 @@ def main(args): args.dataset = ["cnn_dailymail"] warnings.warn("No dataset specified. Defaulting to cnn_dailymail.") tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + default_padding_side = tokenizer.padding_side # Left padding usually provides better calibration result. 
tokenizer.padding_side = "left" @@ -386,6 +398,9 @@ def main(args): assert processor is not None and isinstance(processor, MllamaImageProcessor), ( "The MllamaImageProcessor must be set." ) + assert len(args.calib_size) == 1, ( + "mllama only supports one dataset for calibration, can extend this in the future" + ) calib_dataloader = get_vlm_dataset_dataloader( dataset_name=args.dataset[0] if args.dataset else "scienceqa", processor=processor, @@ -396,11 +411,14 @@ def main(args): assert processor is not None and isinstance(processor, WhisperProcessor), ( "The AutoProcessor must be set." ) + assert len(args.calib_size) == 1, ( + "whisper only supports one dataset for calibration, can extend this in the future" + ) calib_dataloader, first_text = get_speech_dataset_dataloader( dataset_name=args.dataset[0] if args.dataset else "peoples_speech", processor=processor, batch_size=args.batch_size, - num_samples=args.calib_size, + num_samples=args.calib_size[0], device=device, dtype=model.dtype, ) @@ -466,6 +484,8 @@ def main(args): ) print(f"Error details: {e}") raise + if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": + print("Applying nvfp4 quantization (MoE only) for gpt-oss") # quantize the model model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only) @@ -533,6 +553,18 @@ def output_decode(generated_ids, input_shape): export_path = args.export_path + if hasattr(model, "language_model"): + # Save original model config and the preprocessor config to the export path for VLMs. + from transformers import AutoConfig, AutoProcessor + + AutoConfig.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code + ).save_pretrained(export_path) + + AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code + ).save_pretrained(export_path) + if model_type == "mllama": full_model_config = model.config model = model.language_model @@ -608,7 +640,8 @@ def output_decode(generated_ids, input_shape): "--calib_size", help=( "Number of samples for calibration. If a comma separated list of values is provided, " - "each value will be used as the calibration size for the corresponding dataset." + "each value will be used as the calibration size for the corresponding dataset. " + "This argument will be parsed and converted as a list of ints." ), type=str, default="512", diff --git a/examples/llm_qat/llama_factory/README.md b/examples/llm_qat/llama_factory/README.md index 5f32bcb5b..b72c3dbfd 100644 --- a/examples/llm_qat/llama_factory/README.md +++ b/examples/llm_qat/llama_factory/README.md @@ -92,3 +92,17 @@ modelopt: > **_NOTE:_** `compress: true` enables weight compression and will by default use [ddp.yaml](../accelerate_config/ddp.yaml). > **_NOTE:_** When training without [cli](#training-using-cli), avoid using deepspeed option in the YAML configuration file. + +## Deployment + +The final QAT/QAD model after training is similar in architecture to that of PTQ model. It simply has updated weights as compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM) or to TensorRT just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment. + +To run QAT/QAD model with TRTLLM, run: + +```sh +cd ../../llm_ptq + +./scripts/huggingface_example.sh --model --quant nvfp4 +``` + +See more details on deployment of quantized model [here](../../llm_ptq/README.md). 
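As a complement to the shell workflow above, the sketch below shows one way to produce a unified Hugging Face-style export from a QAT/QAD checkpoint programmatically. It is a minimal, hedged example: it assumes the `export_hf_checkpoint(model, export_dir=...)` and `mto.enable_huggingface_checkpointing()` APIs imported elsewhere in this patch, and the checkpoint paths are hypothetical placeholders.

```python
# Hedged sketch (not part of the patch): export a QAT/QAD-trained model to a
# unified Hugging Face-style checkpoint before handing it to the llm_ptq
# deployment scripts. Paths below are hypothetical.
from transformers import AutoModelForCausalLM

import modelopt.torch.opt as mto
from modelopt.torch.export import export_hf_checkpoint

# Restore the modelopt (quantizer) state saved alongside the QAT checkpoint.
mto.enable_huggingface_checkpointing()

qat_ckpt = "path/to/qat_output_dir"  # hypothetical QAT/QAD training output
model = AutoModelForCausalLM.from_pretrained(qat_ckpt, torch_dtype="auto")

# Assumed API (imported elsewhere in this patch); writes a deployment-ready checkpoint.
export_hf_checkpoint(model, export_dir=f"{qat_ckpt}-hf-export")
```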
diff --git a/examples/megatron-lm/ADVANCED.md b/examples/megatron-lm/ADVANCED.md new file mode 100644 index 000000000..625c6d358 --- /dev/null +++ b/examples/megatron-lm/ADVANCED.md @@ -0,0 +1,50 @@ +
+ +# Megatron-LM Integration Advanced Usage + +[Kimi-K2-Instruct Slurm Examples](#slurm-examples) | +[Advanced Configuration](#advanced-configuration) | +[Checkpoint Resume](#checkpoint-resume) | +[Megatron-LM Integration](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) + +
+ +## Slurm Examples + +For models that require multi-node execution, our scripts in the Megatron-LM examples also support `slurm` with an sbatch wrapper. +Start with the example `slurm/sbatch.sh` with some minor modifications, or use your existing `sbatch` +script. + +
+ +**⭐ BF16 Kimi-K2-Instruct EAGLE3 Training:** + +Unlike the local environment, we only allow passing variables through a shell script (default: `./env_setup_template.sh`). +Command-line variable passthrough is not supported. `config/moonshotai/kimi_k2_instruct.sh` is a config that has been tested +with 8 nodes of DGX H100 (TP=8, ETP=1, EP=64; 64 H100 GPUs in total). Update `HF_MODEL_CKPT` to the exact +checkpoint path in the container to start: + +```sh +export USER_FSW= +export CONTAINER_IMAGE= +export SANDBOX_ENV_SETUP=./config/moonshotai/kimi_k2_instruct.sh +sbatch --nodes=8 slurm/sbatch.sh "/workspace/Megatron-LM/examples/post_training/modelopt/eagle3.sh moonshotai/Kimi-K2-Instruct" +``` + +To export the trained EAGLE3 model, switch to `kimi_k2_instruct_export.sh`. +**We only support pipeline-parallel (PP) export.** In this case, 2 nodes are used (PP=16). + +```sh +export USER_FSW= +export CONTAINER_IMAGE= +export SANDBOX_ENV_SETUP=./config/moonshotai/kimi_k2_instruct_export.sh +sbatch --nodes=2 slurm/sbatch.sh "/workspace/Megatron-LM/examples/post_training/modelopt/export.sh moonshotai/Kimi-K2-Instruct" +``` + +## Advanced Configuration + +WIP + +## Checkpoint Resume + +WIP diff --git a/examples/megatron-lm/Dockerfile b/examples/megatron-lm/Dockerfile new file mode 100644 index 000000000..68745fb5e --- /dev/null +++ b/examples/megatron-lm/Dockerfile @@ -0,0 +1,20 @@ +FROM nvcr.io/nvidia/pytorch:25.04-py3 + +ARG PIP_CONSTRAINT= + +RUN pip install jsonlines omegaconf pulp +RUN pip install datasets transformers +RUN pip install tiktoken blobfile +RUN pip install flask flask_restful fire nltk + +WORKDIR /workspace + +RUN git clone https://github.com/NVIDIA/Megatron-LM.git +RUN pip install -e Megatron-LM/ +RUN chmod -R 777 Megatron-LM/ + +RUN git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git +RUN pip install -e TensorRT-Model-Optimizer +RUN chmod -R 777 TensorRT-Model-Optimizer + +WORKDIR /workspace/nmm-sandbox diff --git a/examples/megatron-lm/README.md b/examples/megatron-lm/README.md new file mode 100644 index 000000000..515e7a6c5 --- /dev/null +++ b/examples/megatron-lm/README.md @@ -0,0 +1,143 @@ +
+ +# Megatron-LM Integrated Examples + +[Local Examples](#getting-started-in-a-local-environment) | +[Configuration](#learn-more-about-configuration) | +[Advanced Topics](ADVANCED.md) | +[Megatron-LM Integration](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) + +
+ +**Major Features:** + +- Start from a Hugging Face pretrained model checkpoint with on-the-fly conversion. +- Support all kinds of model parallelism (TP, EP, ETP, PP). +- Export unified checkpoints ready for TensorRT-LLM, vLLM, and SGLang deployment. + +**Support Matrix: {Model}x{Features}** + +| Model | Quantization | EAGLE3 | Q-LoRA | Distillation | +| ------------------------------------------------------ | ------------ | ---------- | ------ | ------------ | +| `moonshotai/Kimi-K2-Instruct` | ✅ | **Online** | | | +| `Qwen/Qwen3-{30B-A3B, 235B-A22B}` | **WAR** | **Online** | | | +| `Qwen/Qwen3-{0.6B, 8B}` | ✅ | **Online** | | | +| `deepseek-ai/DeepSeek-R1` | ✅ | **Online** | | | +| `meta-llama/Llama-{3.1-8B, 3.1-405B, 3.2-1B}-Instruct` | ✅ | **Online** | | | + +## Getting Started in a Local Environment + +Given that only `megatron.core` can be pip-installed, the examples are containerized with +[Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and +[TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) +pre-installed. **Use the following command to build the container.** + +```sh +docker build --no-cache --network=host --rm -t nvidia-modelopt-megatron:latest . +``` + +> **📙 NOTE:** If you plan to use `slurm` for multi-node execution, push the image to a container registry. + +For local execution, a **READ/WRITE scratch space** needs to be mounted. Mount additional volumes for +checkpoints, datasets, and other artifacts. + +```sh +USER_FSW= bash interactive.sh +``` + +> **📙 NOTE:** The current directory will be mounted to `$(pwd):/workspace/nmm-sandbox` and the scratch +> space will be mounted to `/workspace/scratch`. + +
+ +**⭐ FP8 Post-Training Quantization (PTQ):** + +Provide the pretrained checkpoint path through variable `${HF_MODEL_CKPT}`: + +```sh +\ + TP=1 \ + HF_MODEL_CKPT= \ + MLM_MODEL_SAVE=/tmp/Llama-3.2-1B-Instruct-FP8 \ + bash megatron-lm/examples/post_training/modelopt/quantize.sh meta-llama/Llama-3.2-1B-Instruct fp8 + +\ + PP=1 \ + HF_MODEL_CKPT= \ + MLM_MODEL_LOAD=/tmp/Llama-3.2-1B-Instruct-FP8 \ + EXPORT_DIR=/tmp/Llama-3.2-1B-Instruct-Export \ + bash megatron-lm/examples/post_training/modelopt/export.sh meta-llama/Llama-3.2-1B-Instruct + +``` + +You can find a resumable Megatron-LM checkpoint for quantization-aware training or simulated evaluation +(`/tmp/Llama-3.2-1B-Instruct-FP8`) and a Hugging Face-like exported checkpoint for +deployment (`/tmp/Llama-3.2-1B-Instruct-Export`). + +
+ +**⭐ Online BF16 EAGLE3 Training:** + +Online EAGLE3 training keeps both the target (frozen) and draft models in memory, where the `hidden_states` +required for training are generated on the fly. + +```sh +\ + TP=1 \ + HF_MODEL_CKPT= \ + MLM_MODEL_SAVE=/tmp/Llama-3.2-1B-Eagle3 \ + bash megatron-lm/examples/post_training/modelopt/eagle3.sh meta-llama/Llama-3.2-1B-Instruct + +\ + PP=1 \ + HF_MODEL_CKPT= \ + MLM_MODEL_LOAD=/tmp/Llama-3.2-1B-Eagle3 \ + EXPORT_DIR=/tmp/Llama-3.2-1B-Eagle3-Export \ + bash megatron-lm/examples/post_training/modelopt/export.sh meta-llama/Llama-3.2-1B-Instruct +``` + +Periodically, **acceptance length (AL)** is evaluated on MT-Bench prompts. You can find a resumable +Megatron-LM checkpoint (`/tmp/Llama-3.2-1B-Eagle3`) and a Hugging Face-like exported checkpoint +for deployment (`/tmp/Llama-3.2-1B-Eagle3-Export`). + +See [ADVANCED.md](ADVANCED.md) for a multi-GPU, multi-node training example for `moonshotai/Kimi-K2-Instruct`. + +
+ +**⭐ Offline BF16 EAGLE3 Training:** + +Coming soon ... + +## Learn More About Configuration + +For simplicity, we use `shell` scripts and variables as arguments. Each script has at least one positional +argument, `[pretrained_model_card]`. Some scripts may require more, such as `[qformat]`, which is needed for +quantization. + +```sh +\ + HF_MODEL_CKPT=[pretrained_checkpoint] \ + bash megatron-lm/examples/post_training/modelopt/quantize.sh [pretrained_model_card] [qformat] +``` + +> **❗ IMPORTANT:** `pretrained_model_card` **CANNOT** be a path to a local pretrained checkpoint. +> It is used to get the corresponding Megatron-LM `${MODEL_ARGS}`. For example, +> `meta-llama/Llama-3.1-8B-Instruct` or `deepseek-ai/DeepSeek-R1` are both supported. +> \ +> Provide the pretrained checkpoint through variable `${HF_MODEL_CKPT}` on the command line or +> in `env_setup_template.sh`. More variables (e.g. `${TP}`, `${EP}`, ...) can be provided through the +> command line, but we recommend passing all variables in another `shell` script. + +When `${HF_MODEL_CKPT}` is not set through the command line, `./env_setup_template.sh` can be used +to pass all variables instead. If you have your own script, use `${SANDBOX_ENV_SETUP}`. + +```sh +\ + SANDBOX_ENV_SETUP= \ + bash megatron-lm/examples/post_training/modelopt/quantize.sh [pretrained_model_card] [qformat] +``` + +If you use our `slurm` script, then you **MUST USE** `${SANDBOX_ENV_SETUP}` (default: `./env_setup_template.sh`). +Other variables are not passed through `sbatch` and `srun` automatically. + +See [ADVANCED.md](ADVANCED.md) to learn all the configurable variables. diff --git a/examples/megatron-lm/config/moonshotai/kimi_k2_instruct.sh b/examples/megatron-lm/config/moonshotai/kimi_k2_instruct.sh new file mode 100644 index 000000000..a3902bc81 --- /dev/null +++ b/examples/megatron-lm/config/moonshotai/kimi_k2_instruct.sh @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +HF_MODEL_CKPT=/workspace/scratch/moonshotai/Kimi-K2-Instruct +TP=8 +ETP=1 +EP=64 diff --git a/examples/megatron-lm/config/moonshotai/kimi_k2_instruct_export.sh b/examples/megatron-lm/config/moonshotai/kimi_k2_instruct_export.sh new file mode 100644 index 000000000..838b8d643 --- /dev/null +++ b/examples/megatron-lm/config/moonshotai/kimi_k2_instruct_export.sh @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +HF_MODEL_CKPT=/workspace/scratch/moonshotai/Kimi-K2-Instruct + +MLM_EXTRA_ARGS=" \ + --decoder-first-pipeline-num-layers 3 \ + --decoder-last-pipeline-num-layers 2 \ + --init-model-with-meta-device \ + --use-cpu-initialization \ + +" + +# Layer distribution over PP: 3, [4] * 14, 2. +PP=16 diff --git a/examples/megatron-lm/env_setup_template.sh b/examples/megatron-lm/env_setup_template.sh new file mode 100644 index 000000000..8fa4c8f2e --- /dev/null +++ b/examples/megatron-lm/env_setup_template.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +HF_MODEL_CKPT=/workspace/scratch/meta-llama/Llama-3.2-1B-Instruct +TP=1 +ETP=1 +EP=1 +PP=1 diff --git a/examples/megatron-lm/interactive.sh b/examples/megatron-lm/interactive.sh new file mode 100644 index 000000000..8a7b0813a --- /dev/null +++ b/examples/megatron-lm/interactive.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ -n "${USER_FSW}" ]; then + echo "USER_FSW is set to ${USER_FSW}" +else + USER_FSW=/tmp +fi + +docker run --gpus all --init -it --rm --network host --ipc=host \ + --user $(id -u):$(id -g) \ + -v $PWD:/workspace/nmm-sandbox \ + -v ${USER_FSW}:/workspace/scratch \ + -v /home/chenhany/projects/nmm-sandbox/modelopt:/workspace/TensorRT-Model-Optimizer \ + nvidia-modelopt-megatron:latest bash diff --git a/examples/megatron-lm/slurm/sbatch.sh b/examples/megatron-lm/slurm/sbatch.sh new file mode 100644 index 000000000..705748d14 --- /dev/null +++ b/examples/megatron-lm/slurm/sbatch.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# CHANGE THE FOLLOWING TO YOUR ACCOUNT AND CHANGE THE JOB NAME TO COMPLY WITH THE +# USAGE. SWITCH TO `-p luna -t 04:00:00` IF YOU HAVE BEEN GRANTED CAPACITY FROM +# THE BIWEEKLY CAPACITY MEETING. IF YOU DON'T KNOW WHO IS THE PIC OF YOUR CSRG PPP +# MANAGEMET, GO WITH `-p backfill -t 00:25:00`. + +#SBATCH -A coreai_dlalgo_llm +#SBATCH -p batch +#SBATCH --job-name=coreai_dlalgo_modelopt-modelopt.mlm.examples +#SBATCH --nodes=1 --ntasks-per-node=8 --gpus-per-node=8 +#SBATCH -t 04:00:00 +#SBATCH --exclusive --mem=0 --overcommit + +# Bash coloring +RED='\033[0;31m' +YELLOW='\033[0;33m' +GREEN='\033[0;32m' +BLUE='\033[0;34m' +PURPLE='\033[0;35m' +WHITE='\033[0;37m' + +# Predefined logging +MLM_ERROR="${RED}ERROR: ${WHITE}" +MLM_WARNING="${YELLOW}WARNING:${WHITE}" + +# CHANGE THE FOLLOWING TO YOUR DATA, MEGATRON, and CHECKPOINT DIR +if [[ -z ${USER_FSW} ]]; then + printf "${MLM_ERROR} Varible USER_FSW (read/write scratch space) must be set!\n" + exit 1 +fi + +if [ -z ${SANDBOX_DIR} ]; then + SANDBOX_DIR="$(pwd)" + printf "${MLM_WARNING} Variable SANDBOX_DIR not set! (default: ${SANDBOX_DIR})\n" +fi + +if [ -z ${SANDBOX_ENV_SETUP} ]; then + SANDBOX_ENV_SETUP=./env_setup_template.sh + printf "${MLM_WARNING} Variable SANDBOX_ENV_SETUP not set! (default: ${SANDBOX_ENV_SETUP})\n" +fi + +if [ -z ${CONTAINER_IMAGE} ]; then + CONTAINER_IMAGE="nvidia-modelopt-megatron:latest" + printf "${MLM_WARNING} Variable CONTAINER_IMAGE not set! (default: ${CONTAINER_IMAGE})\n" +fi + +if [ -z ${LAUNCH_SCRIPT} ]; then + LAUNCH_SCRIPT="python" + printf "${MLM_WARNING} Variable LAUNCH_SCRIPT not set! (default: ${LAUNCH_SCRIPT})\n" +fi + +# DO NOT MODIFY THE VALUES BELOW UNLESS YOU KNOW WHAT YOU ARE DOING!!! 
+DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + +CONTAINER_MOUNT="${SANDBOX_DIR}:/workspace/nmm-sandbox,${USER_FSW}:/workspace/scratch" + +srun -l \ + --mpi=pmix \ + --output=%x_%j_$DATETIME.log \ + --container-image ${CONTAINER_IMAGE} \ + --container-workdir "/workspace/nmm-sandbox" \ + --container-mounts ${CONTAINER_MOUNT} \ + --export "HF_MODEL_CKPT=${HF_MODEL_CKPT},SANDBOX_ENV_SETUP=${SANDBOX_ENV_SETUP},LAUNCH_SCRIPT=${LAUNCH_SCRIPT}" \ + bash ${1} + +set +x diff --git a/examples/onnx_ptq/download_example_onnx.py b/examples/onnx_ptq/download_example_onnx.py index 62c9b7dd0..4c70ab7cb 100644 --- a/examples/onnx_ptq/download_example_onnx.py +++ b/examples/onnx_ptq/download_example_onnx.py @@ -23,9 +23,7 @@ from modelopt.torch._deploy.utils import get_onnx_bytes -def export_to_onnx( - model, input_shape, onnx_save_path, device, weights_dtype="float32", use_autocast=False -): +def export_to_onnx(model, input_shape, onnx_save_path, device, weights_dtype="fp32"): """Export the torch model to ONNX format.""" # Create input tensor with same precision as model's first parameter input_dtype = model.parameters().__next__().dtype @@ -35,7 +33,6 @@ def export_to_onnx( model=model, dummy_input=(input_tensor,), weights_dtype=weights_dtype, - use_autocast=use_autocast, ) # Write ONNX model to disk @@ -81,7 +78,7 @@ def export_to_onnx( input_shape = (args.batch_size,) + data_config["input_size"] vit_save_path = args.onnx_save_path or "vit_base_patch16_224.onnx" - weights_dtype = "float16" if args.fp16 else "float32" + weights_dtype = "fp16" if args.fp16 else "fp32" export_to_onnx( model, input_shape, diff --git a/examples/onnx_ptq/llm_export.py b/examples/onnx_ptq/llm_export.py index 437895b20..1b949559c 100644 --- a/examples/onnx_ptq/llm_export.py +++ b/examples/onnx_ptq/llm_export.py @@ -16,34 +16,46 @@ """This script is used to export a LLM model to ONNX and perform quantization.""" import argparse +import json import os import shutil +import tempfile import time +from contextlib import contextmanager import onnx import onnx_graphsurgeon as gs import torch from packaging.version import Version -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer -from modelopt.onnx.llm_export.utils.export_utils import ( +import modelopt +from modelopt.onnx.llm_export_utils.export_utils import ( ModelLoader, WrapperModelForCausalLM, llm_to_onnx, ) +from modelopt.onnx.llm_export_utils.quantization_utils import quantize +from modelopt.onnx.llm_export_utils.surgeon_utils import fold_fp8_qdq_to_dq +from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq, quantize_weights_to_int4 +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.quantization.utils import is_quantized_linear def llm_arguments(): """Parse the arguments for the llm export script.""" parser = argparse.ArgumentParser() parser.add_argument( - "--torch_dir", type=str, help="The folder of HF PyTorch model ckpt", required=False + "--torch_dir", + type=str, + help="The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen2.5-0.5B-Instruct')", + required=False, ) parser.add_argument( "--dtype", type=str, default="fp16", - choices=["fp16", "fp8", "nvfp4"], + choices=["fp16", "fp8", "int4_awq", "nvfp4"], help="The precision of onnx export", ) @@ -82,24 +94,50 @@ def llm_arguments(): help="The path of config.json, in case it is not with the PyTorch or ONNX file", default=None, ) + parser.add_argument( + "--calib_size", type=int, help="The size of calibration dataset", 
default=512 + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + default=False, + help="Trust remote code when loading model from HuggingFace Hub", + ) return parser def get_config_path(args): - """Look for config.json. It is recommended to keep a copy per ONNX path. - - Args: - args: argparse.Namespace - - Returns: - str: The path of config.json + """ + Get config.json file path from the arguments. + The default priority is: config_path > torch_dir/config.json > onnx_path/../config.json """ if args.config_path and os.path.exists(args.config_path): return args.config_path if args.torch_dir: - torch_config = os.path.join(args.torch_dir, "config.json") - if os.path.exists(torch_config): - return torch_config + # Check if torch_dir is a local directory + if os.path.isdir(args.torch_dir): + torch_config = os.path.join(args.torch_dir, "config.json") + if os.path.exists(torch_config): + return torch_config + else: + # For HuggingFace model names, download config temporarily + try: + # Download config from HuggingFace + config = AutoConfig.from_pretrained( + args.torch_dir, trust_remote_code=args.trust_remote_code + ) + + # Save to temporary file + temp_config_path = os.path.join( + tempfile.gettempdir(), f"config_{args.torch_dir.replace('/', '_')}.json" + ) + with open(temp_config_path, "w") as f: + json.dump(config.to_dict(), f, indent=2) + + return temp_config_path + except Exception as e: + print(f"Warning: Could not download config for {args.torch_dir}: {e}") + if args.onnx_path: onnx_config = os.path.join(os.path.dirname(args.onnx_path), "config.json") if os.path.exists(onnx_config): @@ -119,6 +157,7 @@ def export_raw_llm( wrapper_cls=WrapperModelForCausalLM, extra_inputs={}, extra_dyn_axes={}, + calib_size=512, ): """Export raw llm model to ONNX and perform quantization. 
@@ -132,6 +171,7 @@ def export_raw_llm( wrapper_cls: class, Used for wrapping the model extra_inputs: dict, Used for extra inputs extra_dyn_axes: dict, Used for extra dynamic axes + calib_size: int, Used for quantization calibration size """ os.makedirs(output_dir, exist_ok=True) @@ -143,17 +183,23 @@ def export_raw_llm( ) shutil.copy(config_path, os.path.join(output_dir, "config.json")) - # Need to quantize model to fp8, int4 or nvfp4 - if dtype in ["fp8", "nvfp4"]: - # Avoid import modelopt when no quantization is needed - from modelopt.onnx.llm_export.utils.quantization_utils import quantize - from modelopt.torch.export import export_hf_checkpoint - from modelopt.torch.quantization.utils import is_quantized_linear + # Need to quantize model to fp8, int4_awq or nvfp4 + if dtype in ["fp8", "int4_awq", "nvfp4"]: + tokenizer = AutoTokenizer.from_pretrained( + torch_dir, trust_remote_code=args.trust_remote_code + ) + # Only check for local modelopt_state if torch_dir is a local directory + if os.path.isdir(torch_dir): + modelopt_state = os.path.join(torch_dir, "modelopt_state.pth") + model_needs_quantization = not os.path.exists(modelopt_state) + else: + # For HuggingFace model names, always quantize as we can't have local state files + model_needs_quantization = True - tokenizer = AutoTokenizer.from_pretrained(torch_dir, trust_remote_code=True) - modelopt_state = os.path.join(torch_dir, "modelopt_state.pth") - if not os.path.exists(modelopt_state): - model = quantize(model, tokenizer, dtype, lm_head_precision, dataset_dir) + if model_needs_quantization: + model = quantize( + model, tokenizer, dtype, lm_head_precision, dataset_dir, calib_size=calib_size + ) if dtype == "nvfp4": # This is required for nvfp4 ONNX export @@ -164,7 +210,7 @@ def export_raw_llm( module.input_quantizer._onnx_quantizer_type = "dynamic" module.weight_quantizer._onnx_quantizer_type = "static" - if dtype in {"fp8", "nvfp4"}: + if dtype in {"fp8", "int4_awq", "nvfp4"}: print(f"Exporting {dtype} ONNX model from quantized PyTorch model...") llm_to_onnx( wrapper_cls( @@ -205,14 +251,13 @@ def surgeon_llm( """ t0 = time.time() + onnx.shape_inference.infer_shapes_path(raw_onnx_path) graph = gs.import_onnx(onnx.load(raw_onnx_path)) t1 = time.time() print(f"Importing ONNX graph takes {t1 - t0}s.") graph.fold_constants().cleanup().toposort() if dtype == "fp8" or lm_head_precision == "fp8": - from modelopt.onnx.llm_export.utils.surgeon_utils import fold_fp8_qdq_to_dq - graph = fold_fp8_qdq_to_dq(graph) os.makedirs(output_dir, exist_ok=True) @@ -220,13 +265,20 @@ def surgeon_llm( onnx_model = gs.export_onnx(graph) + @contextmanager + def time_operation(operation_name): + start_time = time.time() + yield + end_time = time.time() + print(f"{operation_name} takes {end_time - start_time}s.") + if dtype == "nvfp4": - t4 = time.time() - from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq + with time_operation("quantizing weights to nvfp4"): + onnx_model = fp4qdq_to_2dq(onnx_model, verbose=True) - onnx_model = fp4qdq_to_2dq(onnx_model, verbose=True) - t5 = time.time() - print(f"nvfp4 qdq to 2 dqs inserted in {t5 - t4}.") + elif dtype == "int4_awq": + with time_operation("quantizing weights to int4"): + onnx_model = quantize_weights_to_int4(onnx_model) output_onnx_name = f"{output_dir}/model.onnx" print( @@ -253,12 +305,16 @@ def surgeon_llm( ) if os.path.exists(config_path): - if config_path.endswith("config.json"): + if os.path.isfile(config_path) and config_path.endswith("config.json"): + # config_path is already a 
config.json file shutil.copy(config_path, os.path.join(output_dir, "config.json")) - else: + elif os.path.isdir(config_path): + # config_path is a directory containing config.json shutil.copy( os.path.join(config_path, "config.json"), os.path.join(output_dir, "config.json") ) + else: + print(f"Warning: Unexpected config_path format: {config_path}") t3 = time.time() print(f"Surgeon LLM completed in {t3 - t2}s.") @@ -273,8 +329,6 @@ def check_dtype_support(args): def get_modelopt_version(): try: - import modelopt - return Version(modelopt.__version__) except Exception as e: print(f"Modelopt version cannot be parsed. Reason: {e!s}") @@ -324,6 +378,7 @@ def main(args): wrapper_cls=WrapperModelForCausalLM, extra_inputs=extra_inputs, extra_dyn_axes=extra_dyn_axes, + calib_size=args.calib_size, ) # Providing the config path to config.json results in a hf validation error for internvl_chat. diff --git a/examples/onnx_ptq/torch_quant_to_onnx.py b/examples/onnx_ptq/torch_quant_to_onnx.py index 5bcb51b1c..418f1d7e5 100644 --- a/examples/onnx_ptq/torch_quant_to_onnx.py +++ b/examples/onnx_ptq/torch_quant_to_onnx.py @@ -144,7 +144,6 @@ def main(): # Quantize model quantized_model = quantize_model(model, config, data_loader) - use_autocast = args.quantize_mode not in ["mxfp8", "int4_awq"] # Export to ONNX export_to_onnx( @@ -152,8 +151,7 @@ def main(): input_shape, args.onnx_save_path, device, - weights_dtype="float16", - use_autocast=use_autocast, + weights_dtype="fp16", ) print(f"Quantized ONNX model is saved to {args.onnx_save_path}") diff --git a/examples/speculative_decoding/ar_validate.py b/examples/speculative_decoding/ar_validate.py index be59f4ee4..bfbcb2239 100644 --- a/examples/speculative_decoding/ar_validate.py +++ b/examples/speculative_decoding/ar_validate.py @@ -17,6 +17,7 @@ from accelerate import Accelerator from datasets import load_dataset +from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.opt as mto @@ -25,6 +26,31 @@ mto.enable_huggingface_checkpointing() +def validate_ar(model, tokenizer, ds, steps=3, osl=20, num_samples=20, device=None): + validator = HFARValidation(model, tokenizer) + num_samples = min(num_samples, len(ds)) + ars = [] + for i in tqdm(range(num_samples), desc="Validating AR"): + prompt = ds[i]["prompt"][0] + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + # Apply chat template to the prompt, continuing with assistant response + if hasattr(tokenizer, "apply_chat_template"): + chat_messages = [ + {"role": "user", "content": prompt}, + ] + prompt = tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + if device: + input_ids = input_ids.to(device) + + # validate AR + _, ar = validator.validate(osl, input_ids=input_ids, steps=steps) + ars.append(ar) + return ars + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, required=True, help="Path to model directory") @@ -35,6 +61,12 @@ def main(): parser.add_argument( "--num_samples", type=int, default=20, help="Number of MT-Bench samples to use" ) + parser.add_argument( + "--ar_lower_bound", + type=float, + default=None, + help="AR lower bound for validation. 
If provided, will throw error if AR is below threshold.", + ) args = parser.parse_args() accelerator = Accelerator() @@ -43,32 +75,20 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) - validator = HFARValidation(model, tokenizer) # Load MT-Bench prompts from HuggingFace ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] - num_samples = min(args.num_samples, len(ds)) - ars = [] - - for i in range(num_samples): - prompt = ds[i]["prompt"][0] - input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device) - # Apply chat template to the prompt, continuing with assistant response - if hasattr(tokenizer, "apply_chat_template"): - chat_messages = [ - {"role": "user", "content": prompt}, - ] - prompt = tokenizer.apply_chat_template( - chat_messages, tokenize=False, add_generation_prompt=True + ars = validate_ar( + model, tokenizer, ds, args.steps, args.osl, args.num_samples, accelerator.device + ) + # Optionally, throw error if AR is below lower bound + if args.ar_lower_bound: + mean_ar = sum(ars) / len(ars) + if mean_ar < args.ar_lower_bound: + raise ValueError( + f"AR is below lower bound {args.ar_lower_bound}. Mean AR: {mean_ar:.4f}" ) - input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device) - - # validate AR - _, ar = validator.validate(args.osl, input_ids=input_ids, steps=args.steps) - ars.append(ar) - if accelerator.is_main_process: - print(f"[{i + 1}/{num_samples}] Prompt: {prompt[:60]}... | AR: {ar:.4f}") - + # Print results if ars and accelerator.is_main_process: avg_ar = sum(ars) / len(ars) print("\n==== AR Validation Results on MT-Bench ====") diff --git a/examples/speculative_decoding/calibrate_draft_vocab.py b/examples/speculative_decoding/calibrate_draft_vocab.py index 90ebe2e3d..1211d4eb6 100644 --- a/examples/speculative_decoding/calibrate_draft_vocab.py +++ b/examples/speculative_decoding/calibrate_draft_vocab.py @@ -27,7 +27,13 @@ def main(): parser = argparse.ArgumentParser(description="Calibrate draft vocab and save to .pt file") parser.add_argument("--model", type=str, required=True, help="Model name or path for tokenizer") parser.add_argument("--data", type=str, required=True, help="Path to training data (jsonl)") - parser.add_argument("--draft_vocab_size", type=int, required=True, help="Draft vocab size") + parser.add_argument( + "--eagle_config", + type=str, + required=True, + default="eagle_config.json", + help="Path to eagle_config.json", + ) parser.add_argument( "--calibrate_size", type=int, @@ -39,6 +45,12 @@ def main(): ) args = parser.parse_args() + with open(args.eagle_config) as f: + eagle_config = json.load(f) + if "draft_vocab_size" not in eagle_config: + print("No draft vocab size specified in eagle_config.json, no need to calibrate for d2t.") + return + print("Calibrating vocab...") tokenizer = AutoTokenizer.from_pretrained(args.model) with open(args.data) as f: @@ -47,7 +59,7 @@ def main(): conversations = conversations[: args.calibrate_size] conversations = [item for sublist in conversations for item in sublist] - d2t = calibrate_frequent_vocab(tokenizer, conversations, args.draft_vocab_size) + d2t = calibrate_frequent_vocab(tokenizer, conversations, eagle_config["draft_vocab_size"]) model_name = os.path.basename(os.path.normpath(args.model)) vocab_path = os.path.join(args.save_dir, model_name, "d2t.pt") os.makedirs(os.path.dirname(vocab_path), exist_ok=True) diff --git a/examples/speculative_decoding/eagle_config.json 
b/examples/speculative_decoding/eagle_config.json new file mode 100644 index 000000000..55ff948ed --- /dev/null +++ b/examples/speculative_decoding/eagle_config.json @@ -0,0 +1,3 @@ +{ + "draft_vocab_size": 32000 +} diff --git a/examples/speculative_decoding/launch.sh b/examples/speculative_decoding/launch.sh index b4ffd8bba..fa9e899c8 100755 --- a/examples/speculative_decoding/launch.sh +++ b/examples/speculative_decoding/launch.sh @@ -62,13 +62,9 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi MEDUSA_NUM_LAYERS="${1#*=}" ;; - --eagle_num_layers*) + --eagle_config*) if [[ "$1" != *=* ]]; then shift; fi - EAGLE_NUM_LAYERS="${1#*=}" - ;; - --draft_vocab_size*) - if [[ "$1" != *=* ]]; then shift; fi - DRAFT_VOCAB_SIZE="${1#*=}" + EAGLE_CONFIG="${1#*=}" ;; --fsdp_transformer_layer_cls_to_wrap*) if [[ "$1" != *=* ]]; then shift; fi @@ -106,8 +102,6 @@ LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-4} MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -EAGLE_NUM_LAYERS=${EAGLE_NUM_LAYERS:-1} -DRAFT_VOCAB_SIZE=${DRAFT_VOCAB_SIZE:-0} REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1} REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1} FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} @@ -117,18 +111,20 @@ TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" -elif [[ "$MODE" == "eagle" ]]; then - SPECULATIVE_ARGS="--eagle_num_layers $EAGLE_NUM_LAYERS --draft_vocab_size $DRAFT_VOCAB_SIZE" +elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then + if [[ -n "$EAGLE_CONFIG" ]]; then + SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" + else + SPECULATIVE_ARGS="" + fi else - echo "Only medusa and eagle supported for now!" + echo "Only medusa, eagle1, eagle3 supported for now!" exit 1 fi if [[ "$NUM_GPU" == 1 ]]; then - FSDP_ARGS="" MULTI_GPU="" else - FSDP_ARGS="--fsdp 'full_shard auto_wrap' --fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" MULTI_GPU="--multi_gpu" fi @@ -151,10 +147,9 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 1 \ + --logging_steps 100 \ --tf32 True \ --data_path $DATA \ - $FSDP_ARGS \ $SPECULATIVE_ARGS " diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 52d26fa6b..7ce46c55d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -36,14 +36,15 @@ import torch import transformers +from ar_validate import validate_ar +from datasets import load_dataset from eagle_utils import make_eagle_supervised_data_module from medusa_utils import make_medusa_supervised_data_module -from transformers import Trainer +from transformers import Trainer, TrainerCallback from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.utils import calibrate_frequent_vocab from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -66,10 +67,6 @@ class DataArguments: default="draft_vocab_cache", metadata={"help": "Path to the d2t cache directory."}, ) - calibrate_size: int = field( - default=None, - metadata={"help": "Size of the calibration data. 
If None, use entire training set."}, - ) @dataclass @@ -85,7 +82,8 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle", "medusa"] = "medusa" + mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" + ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."}) @dataclass @@ -96,11 +94,7 @@ class MedusaArguments: @dataclass class EagleArguments: - eagle_num_layers: int | None = field(default=1) - use_input_layernorm_in_first_layer: bool | None = field(default=True) - use_last_layernorm: bool | None = field(default=True) - use_aux_hidden_state: bool | None = field(default=True) - draft_vocab_size: int | None = field(default=32000) + eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"}) def train(): @@ -156,57 +150,78 @@ def train(): "medusa_num_layers": medusa_args.medusa_num_layers, } mtsp.convert(model, [("medusa", config)]) - elif training_args.mode == "eagle": + elif training_args.mode in ["eagle1", "eagle3"]: + from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG + + # Load default config config = { - "eagle_num_layers": eagle_args.eagle_num_layers, - "use_input_layernorm_in_first_layer": eagle_args.use_input_layernorm_in_first_layer, - "use_last_layernorm": eagle_args.use_last_layernorm, - "use_aux_hidden_state": eagle_args.use_aux_hidden_state, - "draft_vocab_size": eagle_args.draft_vocab_size, - } + "eagle1": EAGLE1_DEFAULT_CFG, + "eagle3": EAGLE3_DEFAULT_CFG, + }[training_args.mode]["config"] + + # overwrite config with custom config + if eagle_args.eagle_config: + with open(eagle_args.eagle_config) as f: + custom_config = json.load(f) + config["eagle_architecture_config"].update(custom_config) + + # Hidden size and vocab size must match base model + config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": custom_config["draft_vocab_size"] + if eagle_args.eagle_config and "draft_vocab_size" in custom_config + else model.config.vocab_size, + } + ) mtsp.convert(model, [("eagle", config)]) - if eagle_args.draft_vocab_size > 0 and ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - ): - model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path)) - - vocab_cache_path = os.path.join( - data_args.draft_vocab_cache_dir, model_name, "d2t.pt" - ) - if os.path.exists(vocab_cache_path): - vocab_cache = torch.load(vocab_cache_path) - if len(vocab_cache) == eagle_args.draft_vocab_size: - model.eagle_module.d2t = vocab_cache - print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.") - else: - print_rank_0( - "No matching draft vocab cache found, calibrating vocab using training set..." 
- ) - with open(data_args.data_path) as f: - calibrate_conversations = [json.loads(line)["conversations"] for line in f] - if data_args.calibrate_size: - calibrate_conversations = calibrate_conversations[ - : data_args.calibrate_size - ] - calibrate_conversations = [ - item for sublist in calibrate_conversations for item in sublist - ] - - model.eagle_module.d2t = calibrate_frequent_vocab( - tokenizer, calibrate_conversations, eagle_args.draft_vocab_size + # read draft vocab cache + if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: + try: + model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path)) + vocab_cache_path = os.path.join( + data_args.draft_vocab_cache_dir, model_name, "d2t.pt" ) + vocab_cache = torch.load(vocab_cache_path) + model.eagle_module.d2t = vocab_cache + print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.") + except Exception as e: + raise e else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) - elif training_args.mode == "eagle": + elif training_args.mode in ["eagle1", "eagle3"]: data_module = make_eagle_supervised_data_module(tokenizer, data_args) - trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module) + class ARValidationCallback(TrainerCallback): + def __init__(self, ar_validate_steps: int = 500): + self.ar_validate_steps = ar_validate_steps + + def on_step_end(self, args, state, control, **kwargs): + if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: + print_rank_0("Running AR validation...") + ars = validate_ar( + model=kwargs["model"], + tokenizer=kwargs["processing_class"], + ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], + device=kwargs["model"].device, + ) + print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") + return control + + trainer = Trainer( + model=model, + processing_class=tokenizer, + args=training_args, + callbacks=[ARValidationCallback(training_args.ar_validate_steps)], + **data_module, + ) trainer._move_model_to_device(model, trainer.args.device) # Manually enable this to return loss in eval diff --git a/modelopt/deploy/llm/generate.py b/modelopt/deploy/llm/generate.py index f7f768455..6b97cb1a8 100644 --- a/modelopt/deploy/llm/generate.py +++ b/modelopt/deploy/llm/generate.py @@ -16,6 +16,7 @@ """A wrapper over the TensorRT-LLM high level API runner.""" import json +import warnings from collections.abc import Iterable from pathlib import Path from typing import Any @@ -178,7 +179,26 @@ def __init__( self._build_torch_llm_from_config( checkpoint_dir, tokenizer, tp, trust_remote_code, max_batch_size ) - self._max_seq_len = config["max_position_embeddings"] + + def _find_max_position_embeddings(cfg: dict) -> int | None: + if "max_position_embeddings" in cfg: + return cfg["max_position_embeddings"] + for v in cfg.values(): + if isinstance(v, dict): + res = _find_max_position_embeddings(v) + if res is not None: + return res + return None + + # Some VLMs may have a sub-config for max_position_embeddings, so we need to find it. 
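The eagle1/eagle3 branch above builds its conversion config by overlaying the optional `eagle_config.json` onto a library default and then forcing the fields that must match the base model. A minimal stand-alone sketch of that merge, using plain dicts (the real defaults are `EAGLE1_DEFAULT_CFG`/`EAGLE3_DEFAULT_CFG` from `modelopt.torch.speculative.config`; the field values below are illustrative):

```python
import copy
import json

# Illustrative stand-in for EAGLE3_DEFAULT_CFG["config"]; the real dict has many more fields.
DEFAULT_EAGLE_CFG = {"eagle_architecture_config": {"draft_vocab_size": 128256}}


def build_eagle_config(eagle_config_path: str | None, base_hidden_size: int, base_vocab_size: int) -> dict:
    """Overlay a user-supplied eagle_config.json onto the default EAGLE config."""
    config = copy.deepcopy(DEFAULT_EAGLE_CFG)
    custom = {}
    if eagle_config_path:
        with open(eagle_config_path) as f:
            custom = json.load(f)
        config["eagle_architecture_config"].update(custom)
    # Hidden size and vocab size must always match the base model; draft_vocab_size
    # falls back to the full vocab when no reduced draft vocab was requested.
    config["eagle_architecture_config"].update(
        {
            "hidden_size": base_hidden_size,
            "vocab_size": base_vocab_size,
            "draft_vocab_size": custom.get("draft_vocab_size", base_vocab_size),
        }
    )
    return config


print(build_eagle_config(None, base_hidden_size=4096, base_vocab_size=128256))
```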
+ self._max_seq_len = _find_max_position_embeddings(config) + if self._max_seq_len is None: + warnings.warn( + "max_position_embeddings not found in config.json, using default value 8192" + ) + self._max_seq_len = 8192 + else: + print(f"max_position_embeddings: {self._max_seq_len}") self._max_beam_width = 1 self._gather_context_logits = False diff --git a/modelopt/torch/speculative/mtp/__init__.py b/modelopt/onnx/llm_export_utils/__init__.py similarity index 75% rename from modelopt/torch/speculative/mtp/__init__.py rename to modelopt/onnx/llm_export_utils/__init__.py index 44de7f211..2ab754f31 100644 --- a/modelopt/torch/speculative/mtp/__init__.py +++ b/modelopt/onnx/llm_export_utils/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Eagle Optimization Method.""" - -from .conversion import * -from .mtp_model import * +"""Utilities for exporting LLM models to ONNX.""" diff --git a/modelopt/onnx/llm_export/utils/export_utils.py b/modelopt/onnx/llm_export_utils/export_utils.py similarity index 100% rename from modelopt/onnx/llm_export/utils/export_utils.py rename to modelopt/onnx/llm_export_utils/export_utils.py diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py new file mode 100644 index 000000000..98e95e6a3 --- /dev/null +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantization utilities for LLM models.""" + +import time + +import modelopt.torch.quantization as mtq +from modelopt.torch.utils.dataset_utils import get_dataset_dataloader + + +def _quantize_model(model, quant_config, calib_dataloader=None): + """The calibration loop for the model can be set up using the modelopt API. + + Example usage: + from modelopt.torch.utils.dataset_utils import create_forward_loop + model = ... # Initialize the model + tokenizer = ... # Initialize the tokenizer + quant_cfg = ...
# Setup quantization configuration + forward_loop = create_forward_loop(model=model, dataset_name="cnn_dailymail", tokenizer=tokenizer) + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + """ + + def calibrate_loop(model): + """Adjusts weights and scaling factors based on selected algorithms.""" + for idx, data in enumerate(calib_dataloader): + if idx % 10 == 0: + print(f"Calibrating batch {idx}...") + if isinstance(data, dict): + data = {k: v.to(model.device) for k, v in data.items()} + model(**data) + else: + data = data.to(model.device) + model(data) + + print("Starting quantization...") + start_time = time.time() + mtq.quantize(model, quant_config, forward_loop=calibrate_loop) + end_time = time.time() + print(f"Quantization finishes in {end_time - start_time}s.") + + return model + + +def get_quant_config(precision, lm_head_precision="fp16"): + """Get the quantization configuration.""" + if precision == "fp8": + quant_cfg = mtq.FP8_DEFAULT_CFG + + elif precision == "nvfp4": + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + elif precision == "int4_awq": + quant_cfg = mtq.INT4_AWQ_CFG + + else: + raise ValueError(f"Unsupported precision: {precision}") + + config_dict = quant_cfg["quant_cfg"] # type: dict + + if lm_head_precision == "fp8": + config_dict["*lm_head.input_quantizer"] = {"num_bits": (4, 3), "axis": None} + config_dict["*lm_head.weight_quantizer"] = {"num_bits": (4, 3), "axis": None} + elif lm_head_precision == "nvfp4": + config_dict["*lm_head.input_quantizer"] = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + } + config_dict["*lm_head.weight_quantizer"] = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + } + return quant_cfg + + +def quantize( + model, tokenizer, precision, lm_head_precision="fp16", dataset_dir=None, calib_size=512 +): + """Quantize the PyTorch model to fp8 or int4_awq.""" + assert precision in [ + "fp8", + "int4_awq", + "nvfp4", + ], ( + f"Only fp8(W8A8), int4_awq(W4A16), nvfp4(W4A4) is supported. You passed an unsupported precision: {precision}." + ) + + assert lm_head_precision in ["fp16"], ( + f"Only fp16(unquantized) is supported for lm_head. You passed an unsupported precision: {lm_head_precision}." 
+ ) + + if tokenizer.pad_token != "": # nosec B105 + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if not dataset_dir: + dataset_dir = "cnn_dailymail" + + batch_size = 1 + data_loader = get_dataset_dataloader( + dataset_name=dataset_dir, tokenizer=tokenizer, batch_size=batch_size, num_samples=calib_size + ) + quant_config = get_quant_config(precision, lm_head_precision) + quantized_model = _quantize_model(model, quant_config, data_loader) + mtq.print_quant_summary(quantized_model) + return quantized_model diff --git a/modelopt/onnx/llm_export/utils/surgeon_utils.py b/modelopt/onnx/llm_export_utils/surgeon_utils.py similarity index 100% rename from modelopt/onnx/llm_export/utils/surgeon_utils.py rename to modelopt/onnx/llm_export_utils/surgeon_utils.py diff --git a/modelopt/onnx/logging_config.py b/modelopt/onnx/logging_config.py index a15533343..fd0c306a6 100644 --- a/modelopt/onnx/logging_config.py +++ b/modelopt/onnx/logging_config.py @@ -58,14 +58,11 @@ def configure_logging(level=logging.INFO, log_file=None): file=sys.stderr, ) print("[modelopt][onnx] - INFO - Falling back to console logging.", file=sys.stderr) - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - else: - # If no log_file is specified, log to stdout - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) + + # Setup handler to print log in stdout + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) # Prevent log messages from propagating to the root logger logger.propagate = False diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index ccef8e99f..84abc3e0f 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -912,6 +912,36 @@ def find_nodes_from_mha_to_exclude( return [*set(nodes_to_exclude)] # type: ignore[arg-type] +def validate_op_types_spelling(onnx_path, op_types_to_quantize, op_types_to_exclude) -> None: + """Validate spelling in op types.""" + + def find_item_ignore_case(target, arr): + target_lower = target.lower() + for item in arr: + if item.lower() == target_lower: + return item + return None + + model = onnx.load(onnx_path, load_external_data=True) + op_types = {node.op_type for node in model.graph.node} + for op_type in op_types: + if op_types_to_quantize: + op_to_quant = find_item_ignore_case(op_type, op_types_to_quantize) + if op_type not in op_types_to_quantize and op_to_quant is not None: + logger.warning( + f"Model contains '{op_type}' ops, but you're requesting '{op_to_quant}' " + f"to be quantized, which is not a match. Please ensure that the lower/uppercasing is correct." + ) + if op_types_to_exclude: + op_to_exclude = find_item_ignore_case(op_type, op_types_to_exclude) + if op_type not in op_types_to_exclude and op_to_exclude is not None: + logger.warning( + f"Model contains '{op_type}' ops, but you're requesting '{op_to_exclude}' " + f"to be excluded from quantization, which is not a match. " + f"Please ensure that the lower/uppercasing is correct." 
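Returning to the new `quantize` helper in `quantization_utils.py` above, a hedged usage sketch: the model name is purely illustrative, a CUDA device is assumed, and the import path follows the new file location in this diff (`modelopt/onnx/llm_export_utils/quantization_utils.py`).

```python
# Hypothetical end-to-end use of the helper defined above; any HF causal LM should fit.
from transformers import AutoModelForCausalLM, AutoTokenizer

from modelopt.onnx.llm_export_utils.quantization_utils import quantize

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # assumption: illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Calibrates on the default cnn_dailymail set; precision must be fp8, int4_awq, or nvfp4.
quantized_model = quantize(model, tokenizer, precision="nvfp4", calib_size=256)
```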
+ ) + + def cast_custom_ops(onnx_model: onnx.ModelProto, ops_to_cast: dict) -> onnx.ModelProto: """Adds cast_to_fp16 nodes to the inputs and cast_to_fp32 to the outputs of a layer in the requested indices.""" logger.info("Casting custom ops in the requested inputs and outputs") diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 1369813b2..33a13d313 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -63,6 +63,16 @@ cupy_warning_msg = f"Using slower INT4 ONNX quantization using numpy: {e}" +# By default, Gather nodes will not be quantized to INT4. If the caller supplies a specific axis +# value (0 / 1 / -1) for the kwarg "gather_quantize_axis", then Gather nodes will be inspected for +# INT4 quantization. +DEFAULT_GATHER_QUANTIZE_AXIS = None + +# Using a block-size of 128 has been seen to degrade accuracy with some models like DeepSeek-7B. +# So, setting the default to 32. The user can change it via the kwarg "gather_block_size" as needed +# for further tuning or experiments. +DEFAULT_GATHER_BLOCK_SIZE = 32 + NUM_BITS = 4 INT4_SCALE = 7.0 INT4_MIN = -(2 ** (NUM_BITS - 1)) # -8 @@ -79,36 +89,56 @@ def _next_block_size_multiple(x: float, block_size: int) -> float: return math.ceil(x / block_size) * block_size -def _pad(w: np.ndarray, block_size: int) -> np.ndarray: - """Pads `w` to next largest multiple of block_size, on axis 0.""" - if w.shape[0] % block_size == 0: +def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray: + """Pads `w` to next largest multiple of block_size, on quantize_axis.""" + assert quantize_axis <= len(w.shape), ( + f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}" + ) + + if w.shape[quantize_axis] % block_size == 0: return w - pad_width = _next_block_size_multiple(w.shape[0], block_size) - w.shape[0] + pad_width = ( + _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis] + ) pads = [(0, 0) for _ in range(len(w.shape))] - pads[0] = (0, pad_width) + pads[quantize_axis] = (0, pad_width) return np.pad(w, pads, mode="constant", constant_values=0) -def _depad(w: np.ndarray, orig_shape: tuple) -> np.ndarray: - """Depad axis 0 to original shape.""" +def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray: + """Depad quantize_axis to original shape.""" if w.shape == orig_shape: return w - return w[0 : orig_shape[0], ...] + ans = None + if quantize_axis == 0: + ans = w[0 : orig_shape[0], ...]
+ elif quantize_axis == 1: + ans = w[..., 0 : orig_shape[1]] + else: + raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array") + return ans -def find_scales(w: np.ndarray, block_size: int, alpha: float = 1.0, use_zero_point: bool = False): +def find_scales( + w: np.ndarray, + block_size: int, + quantize_axis: int = 0, + alpha: float = 1.0, + use_zero_point: bool = False, +): """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`.""" - w = _pad(w, block_size) - w = w.T + w = _pad(w, block_size, quantize_axis) + if quantize_axis == 0: + w = w.T s_last_dim = w.shape[-1] // block_size s_shape = list(w.shape) s_shape[-1] = s_last_dim + z = None if not use_zero_point: w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1) s = (w_amax * alpha) / INT4_SCALE - s = s.reshape(s_shape).T - z = None + s = s.reshape(s_shape) else: max_val = w.reshape(-1, block_size).max(axis=-1) min_val = w.reshape(-1, block_size).min(axis=-1) @@ -122,42 +152,125 @@ def find_scales(w: np.ndarray, block_size: int, alpha: float = 1.0, use_zero_poi temp = temp.clip(min=min_int, max=max_int) z = temp assert s.shape == z.shape, "s and z shape mismatch" - s = s.reshape(s_shape).T - z = z.reshape(s_shape).T + s = s.reshape(s_shape) + z = z.reshape(s_shape) + assert z is None or use_zero_point is True, "zero-point value and use-zero-point not in sync" + if quantize_axis == 0: + s = s.T + if z is not None: + z = z.T return s, z -def rtn(w: np.ndarray, s: np.ndarray, block_size: int, zp: np.ndarray = None) -> np.ndarray: +def rtn( + w: np.ndarray, s: np.ndarray, block_size: int, quantize_axis: int = 0, zp: np.ndarray = None +) -> np.ndarray: """Quantizes `w` with scale factors `s` via Round-to-Nearest. Ties are broken by rounding to the nearest even number. 
""" - w_padded = _pad(w, block_size) - num_blocks = w_padded.shape[0] // s.shape[0] + w_padded = _pad(w, block_size, quantize_axis) + num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] if zp is None: w_padded = ( - np.rint(w_padded / s.repeat(num_blocks, axis=0)) + np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) .clip(INT4_MIN, INT4_MAX) .astype(np.int8) ) else: w_padded = ( - (np.rint(w_padded / s.repeat(num_blocks, axis=0)) + zp.repeat(num_blocks, axis=0)) + ( + np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) + + zp.repeat(num_blocks, axis=quantize_axis) + ) .clip(UINT4_MIN, UINT4_MAX) .astype(np.int8) ) - return _depad(w_padded, w.shape) + return _depad(w_padded, w.shape, quantize_axis) -def dq_tensor(w: np.ndarray, s: np.ndarray, block_size: int, zp: np.ndarray = None) -> np.ndarray: +def dq_tensor( + w: np.ndarray, s: np.ndarray, block_size: int, quantize_axis: int = 0, zp: np.ndarray = None +) -> np.ndarray: """Dequantizes `w` with scale factors `s`.""" - w_padded = _pad(w, block_size) - num_blocks = w_padded.shape[0] // s.shape[0] + w_padded = _pad(w, block_size, quantize_axis) + num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] if zp is None: - w_padded = w_padded * s.repeat(num_blocks, axis=0) + w_padded = w_padded * s.repeat(num_blocks, axis=quantize_axis) else: - w_padded = (w_padded - zp.repeat(num_blocks, axis=0)) * s.repeat(num_blocks, axis=0) - return _depad(w_padded, w.shape) + w_padded = (w_padded - zp.repeat(num_blocks, axis=quantize_axis)) * s.repeat( + num_blocks, axis=quantize_axis + ) + return _depad(w_padded, w.shape, quantize_axis) + + +def _quantize_gather_nodes( + graph: onnx.GraphProto, + nodes_to_exclude: list[str], + gather_quantize_axis: int, + block_size: int, + use_zero_point: bool, + dq_only: bool, +): + """Return scale, zero-point, and weights for quantizable gather nodes using INT4 RTN.""" + t = time.time() + weights_map = {} + scales_map = {} + zero_point_map = {} + for node in graph.nodes: + if node.op == "Gather" and node.name not in nodes_to_exclude: + for in_tensor in node.inputs: + if not isinstance(in_tensor, gs.Constant): + continue + if len(in_tensor.values.shape) == 1: + # 1D blocked quantization not supported. 
+ continue + name = in_tensor.name + w = in_tensor.values + s, zp = find_scales( + np.asarray(w), + block_size, + quantize_axis=gather_quantize_axis, + use_zero_point=use_zero_point, + ) + s = s.astype(w.dtype) + scales_map[name] = s + weight_dtype = numpy.int8 + if zp is not None: + assert use_zero_point is True, ( + "Found zero-point tensor but zero-point is disabled" + ) + weight_dtype = numpy.uint8 + zp = zp.astype(weight_dtype) + zero_point_map[name] = zp + if dq_only: + qw = rtn( + np.asarray(w), + s, + block_size, + quantize_axis=gather_quantize_axis, + zp=zp if zp is None else zp.astype(w.dtype), + ) + weights_map[name] = qw.astype(weight_dtype) + else: + weights_map[name] = in_tensor + if has_cupy: + for name in scales_map: + scales_map[name] = np.asnumpy(scales_map[name]) + for name in zero_point_map: + zero_point_map[name] = np.asnumpy(zero_point_map[name]) + if dq_only: + for name in weights_map: + weights_map[name] = np.asnumpy(weights_map[name]) + + num_gather_nodes_quantized = len(weights_map) + if num_gather_nodes_quantized > 0: + logger.info( + f"Quantizing {num_gather_nodes_quantized} Gather nodes took {time.time() - t} seconds" + ) + else: + logger.info("Found 0 Gather nodes to quantize") + return weights_map, scales_map, zero_point_map def quantize_rtn( @@ -165,6 +278,7 @@ def quantize_rtn( block_size: int, dq_only: bool = False, nodes_to_exclude: list[str] = [], + **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the RTN (Round-to-Nearest) algorithm. @@ -218,6 +332,20 @@ def quantize_rtn( # Import the update graph graph = gs.import_onnx(onnx_model) + gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) + gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) + gather_w_map = None + gather_s_map = None + if gather_quantize_axis is not None: + gather_w_map, gather_s_map, _ = _quantize_gather_nodes( + graph, + nodes_to_exclude, + gather_quantize_axis, + gather_block_size, + use_zero_point=False, + dq_only=dq_only, + ) + if dq_only: # Calculate actual quantized weights. logger.info("Computing quantized weights for DQ-only mode") @@ -231,20 +359,32 @@ def quantize_rtn( gemm_weights_quantized[name] = numpy.asarray(qw) qdq.insert_dq_nodes(graph, scales, quantized_weights=gemm_weights_quantized) + if gather_w_map is not None: + assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" + qdq.insert_dq_nodes(graph, gather_s_map, quantized_weights=gather_w_map) else: if has_cupy: for name in scales: scales[name] = np.asnumpy(scales[name]) qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors) + if gather_w_map is not None: + assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" + qdq.insert_qdq_nodes(graph, gather_s_map, weight_map=gather_w_map) logger.info(f"RTN quantization completed in {time.time() - t_start:.2f} seconds") return gs.export_onnx(graph) -def quant_tensor(w: np.ndarray, block_size: int, alpha: float = 1.0, use_zero_point: bool = False): +def quant_tensor( + w: np.ndarray, + block_size: int, + quantize_axis: int = 0, + alpha: float = 1.0, + use_zero_point: bool = False, +): """Quantize a tensor using alpha etc. 
and return the quantized tensor.""" - scale, zp = find_scales(w, block_size, alpha, use_zero_point) - wq = rtn(w, scale, block_size, zp) + scale, zp = find_scales(w, block_size, quantize_axis, alpha, use_zero_point) + wq = rtn(w, scale, block_size, quantize_axis, zp) return wq, scale, zp @@ -324,7 +464,7 @@ def _clip_search( # Compute loss for each alpha value for alpha in awq_clip.loss: # Perform QDQ on the whole original weight tensor - qw, scales, _ = quant_tensor(w_copy, block_size, alpha) + qw, scales, _ = quant_tensor(w_copy, block_size, alpha=alpha) cur_w = dq_tensor(qw, scales, block_size) # Reshape before getting the batch of size co_bsz to multiply with input @@ -509,7 +649,7 @@ def _quantize_awq_clip( w = np.asarray(w) alpha = alphas.get(weight_tensor.name, 1) - qw, scale, _ = quant_tensor(w, block_size, alpha) + qw, scale, _ = quant_tensor(w, block_size, alpha=alpha) if has_cupy: qw = np.asnumpy(qw) scale = np.asnumpy(scale) @@ -527,12 +667,33 @@ def _quantize_awq_clip( logger.info(f"Quantizing actual weights took {time.time() - t} seconds") - t = time.time() graph_gs = gs.import_onnx(onnx_model) + + gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) + gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) + gather_w_map = None + gather_s_map = None + if gather_quantize_axis is not None: + gather_w_map, gather_s_map, _ = _quantize_gather_nodes( + graph, + nodes_to_exclude, + gather_quantize_axis, + gather_block_size, + use_zero_point=False, + dq_only=True, + ) + + t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} qdq.insert_dq_nodes( graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes ) + if gather_w_map is not None: + assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" + gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} + qdq.insert_dq_nodes( + graph_gs, scales, quantized_weights=gather_w_map, attributes=gather_dq_node_attributes + ) logger.info(f"Inserting DQ nodes took {time.time() - t} seconds") logger.info("Exporting the quantized graph") @@ -685,7 +846,7 @@ def run_awq_scale_search_per_node( w_scaled = w * awq_scale[:, np.newaxis] qw, scale, zp = quant_tensor(w_scaled, block_size, use_zero_point=use_zero_point) - dqw = dq_tensor(qw, scale, block_size, zp) + dqw = dq_tensor(qw, scale, block_size, zp=zp) out_curr = x_scaled.__matmul__(dqw) loss = np.mean(np.power((out_actual - out_curr), 2)) del out_curr @@ -850,7 +1011,7 @@ def run_awq_scale_search_per_subgraph( x_scaled = x * 1.0 / awq_scale w_scaled = w * awq_scale[:, np.newaxis] qw, scale, zp = quant_tensor(w_scaled, block_size, use_zero_point=use_zero_point) - dqw = dq_tensor(qw, scale, block_size, zp) + dqw = dq_tensor(qw, scale, block_size, zp=zp) out_curr = x_scaled.__matmul__(dqw) loss += np.mean(np.power((out_act - out_curr), 2)) del out_curr, out_act @@ -1076,7 +1237,9 @@ def _quantize_awq_lite( assert enable_weight_clipping or (alpha == 1), ( "clip range enabled without enabling weight-clipping param" ) - qw, scale, zp = quant_tensor(w_scaled, block_size, alpha, use_zero_point=use_zero_point) + qw, scale, zp = quant_tensor( + w_scaled, block_size, alpha=alpha, use_zero_point=use_zero_point + ) assert use_zero_point is True or zp is None, "zp is not according to use-zero-point setting" if do_transpose: qw = qw.T @@ -1179,8 +1342,25 @@ def _quantize_awq_lite( logger.info( "Inserting DQ nodes and input_pre_quant_scale 
node using quantized weights and scales" ) - t = time.time() + graph_gs = gs.import_onnx(onnx_model) + + gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) + gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) + gather_w_map = None + gather_s_map = None + gather_zp_map = None + if gather_quantize_axis is not None: + gather_w_map, gather_s_map, gather_zp_map = _quantize_gather_nodes( + graph_gs, + nodes_to_exclude, + gather_quantize_axis, + gather_block_size, + use_zero_point=use_zero_point, + dq_only=True, + ) + + t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} qdq.insert_dq_nodes( graph_gs, @@ -1189,6 +1369,19 @@ def _quantize_awq_lite( attributes=dq_node_attributes, zero_points=zero_points if use_zero_point else None, ) + if gather_w_map is not None: + assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" + assert not use_zero_point or gather_zp_map, ( + "zero-point setting and zero-point map not in sync for quantizable gather nodes" + ) + gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} + qdq.insert_dq_nodes( + graph_gs, + gather_s_map, + quantized_weights=gather_w_map, + attributes=gather_dq_node_attributes, + zero_points=gather_zp_map if use_zero_point else None, + ) if pre_quant_scale: qdq.insert_pre_quant_scale_nodes(graph_gs, input_tensors, pre_quant_scale) @@ -1269,6 +1462,10 @@ def quantize( Default: 0.5. - **awqclip_bsz_col** (int): Batch size for processing the column dimension in awq-clip. Default: 1024. + - **gather_quantize_axis** (int): Quantization axis for Gather nodes. + Default: None (Gather nodes not quantized). + - **gather_block_size** (int): Block-size for Gather nodes quantization. + Default: 32. **Returns**: A quantized ONNX model in ONNX ModelProto format. """ configure_logging(level=log_level.upper()) @@ -1309,6 +1506,7 @@ def quantize( block_size, dq_only="dq" in calibration_method, nodes_to_exclude=nodes_to_exclude, + **kwargs, ) elif calibration_method in ["awq_lite", "awq_full"]: do_weight_clipping = False diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index daf6c612d..e14179c73 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -55,6 +55,7 @@ find_nodes_from_mha_to_exclude, print_stat, remove_redundant_cast_nodes, + validate_op_types_spelling, ) from modelopt.onnx.quantization.int4 import quantize as quantize_int4 from modelopt.onnx.quantization.int8 import quantize as quantize_int8 @@ -421,6 +422,9 @@ def quantize( nodes_to_quantize = nodes_to_quantize or [] nodes_to_exclude = nodes_to_exclude or [] + # Check op types spelling in 'op_types_to_exclude' and '_to_quantize' + validate_op_types_spelling(onnx_path, op_types_to_quantize, op_types_to_exclude) + # (1) If disable_mha_qdq is set, don't add Q/DQ layers to MatMuls in MHA pattern. # (2) else when quantize_mode == "int8", if seq_len > 512, don't add Q/DQ layers to # MatMuls in MHA pattern. 
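The `validate_op_types_spelling` call added here only warns, but the check is easy to picture: compare the requested names against the op types actually present in the graph while ignoring case. A toy illustration in plain Python (no ONNX model needed):

```python
# Op types as they appear in an ONNX graph vs. what a user might pass on the command line.
op_types_in_model = {"MatMul", "Gather", "LayerNormalization"}
op_types_to_quantize = ["matmul", "Gather"]  # "matmul" is a casing near-miss

for op_type in op_types_in_model:
    near_miss = next(
        (req for req in op_types_to_quantize if req.lower() == op_type.lower() and req != op_type),
        None,
    )
    if near_miss is not None:
        print(f"Model contains '{op_type}' ops, but you requested '{near_miss}' - check the casing.")
```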
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 0f4babf96..a922bb19f 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -334,8 +334,7 @@ def get_onnx_bytes_and_metadata( dynamo_export: bool = False, onnx_opset: int = DEFAULT_ONNX_OPSET, dq_only: bool = False, - weights_dtype: str = "float32", - use_autocast: bool = False, + weights_dtype: str = "fp32", ) -> tuple[bytes, ModelMetadata]: """Get onnx model in bytes from input pytorch model together with the input/output of model. @@ -353,7 +352,6 @@ def get_onnx_bytes_and_metadata( onnx_opset: The onnx opset version to use for exporting the model. dq_only: If True, the exported onnx model is converted to a dq_only model. weights_dtype: The dtype of the weights in the onnx model. - use_autocast: If True, the model is exported using torch.autocast(). Returns: bytes: Onnx model in bytes. @@ -365,8 +363,8 @@ def get_onnx_bytes_and_metadata( if not isinstance(model, nn.Module): raise ValueError("Only PyTorch model compilation is supported.") - assert weights_dtype in ["float32", "float16"], ( - "weights_dtype must be one of float32, or float16" + assert weights_dtype in ["fp32", "fp16", "bf16"], ( + "weights_dtype must be one of fp32, fp16, or bf16" ) # unwrap DDP and DP models @@ -394,9 +392,10 @@ def get_onnx_bytes_and_metadata( # during inference. input_none_names = list(set(tree_spec_input.names) - set(input_names)) - # Get output once (we export in inference mode - so also using inference mode here!) - autocast = torch.autocast("cuda") if use_autocast else nullcontext() + use_torch_autocast = not (is_fp4_quantized(model) or is_mxfp8_quantized(model)) + autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext() + # Get output once (we export in inference mode - so also using inference mode here!) 
with torch.inference_mode(), autocast: output = model(*named_args.values()) @@ -408,7 +407,7 @@ def get_onnx_bytes_and_metadata( if onnx_load_path != "": onnx_model = OnnxBytes(onnx_load_path) - onnx_model_graph = onnx.load(os.path.join(onnx_load_path)) + onnx_model_graph = onnx.load(onnx_load_path) model_metadata = create_model_metadata( tree_spec_input, tree_spec_output, input_none_names, onnx_model_graph, model ) @@ -479,8 +478,14 @@ def get_onnx_bytes_and_metadata( if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) - if weights_dtype == "float16": - if not use_autocast: + try: + # TODO: Single-precision torch model assumed + param_dtype = next(model.parameters()).dtype + except StopIteration: + param_dtype = torch.float32 + if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32: + if is_mxfp8_quantized(model): + assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, keep_io_types=False, @@ -488,7 +493,9 @@ def get_onnx_bytes_and_metadata( check_fp16_ready=False, ) else: - onnx_opt_graph = convert_to_f16(onnx_opt_graph, keep_io_types=False) + onnx_opt_graph = convert_to_f16( + onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False + ) # If the onnx model contains external data store the external tensors in one file and save the onnx model if has_external_data(onnx_save_path): diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 55508f230..b0c5d8adc 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -18,7 +18,6 @@ Some of the logics in this file are empirical and needs constant update if exceptions occur. """ -from contextlib import contextmanager from warnings import warn import torch @@ -995,67 +994,165 @@ def module_match_name_list(module, name_list): return ["linear_fc1", "linear_fc2"] elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]): return ["w1_linear", "w2_linear", "v1_linear"] + elif module_match_name_list(module, ["GptOssMoE"]): + # GPT-OSS MoE modules use gate_up_proj and down_proj + return ["gate_up_proj", "down_proj"] else: # assuing w1, w2, w3 by default return ["w1", "w2", "w3"] -def set_amax_for_uncalibrated_experts(experts: nn.Module, set_amax_value: float | None = None): - """Set amax of uncalibrated experts to a given value or the max of existing amax value from other experts. +def set_expert_quantizer_amax( + modules: nn.Module | list[nn.Module], + quantizer_attrs: str | list[str] | None = None, + fallback_value: float = 0.5, + device: torch.device | None = None, +) -> list[nn.Module]: + """Set amax values for expert quantizers using smart fallback logic. + + Uses smart fallback logic: + + 1. Use max from existing quantizers in current batch (best - direct from calibration) + 2. If no existing values found, then: + - For weight quantizers: calculate from weight statistics + - For input quantizers: use max from other experts, fallback if none found + 3. Use fallback value as last resort + + This ensures we always have semantically appropriate amax values for export. Args: - experts: a list of experts - set_amax_value: set amax value to the given value. - If None, set amax value to the max of existing amax value from other experts. + modules: Single module or list of modules containing quantizers + quantizer_attrs: Specific quantizer attributes to handle. + If None, defaults to ["input_quantizer"] for backward compatibility. 
+ fallback_value: Final fallback value when other methods fail (default: 0.5) + device: Target device for tensors (auto-detected if None) Returns: - uncalibrated_experts: a list of uncalibrated experts + uncalibrated_modules: a list of uncalibrated experts """ - uncalibrated_experts = [] - # get the max amax value from all experts - if set_amax_value is None: - amax_values = [ - module.input_quantizer.amax - for module in experts - if ( - hasattr(module, "input_quantizer") - and module.input_quantizer is not None - and module.input_quantizer.is_enabled - ) - and module.input_quantizer.amax is not None - ] - if len(amax_values) == 0: - return uncalibrated_experts - set_amax_value = torch.max(torch.stack(amax_values)) + import warnings - for module in experts: - if ( - hasattr(module, "input_quantizer") - and module.input_quantizer is not None - and module.input_quantizer.is_enabled - ) and module.input_quantizer.amax is None: - warn( - f"Missing amax value for {module} input_quantizer. Setting it to {set_amax_value} for export. " - f"This typically occurs in MoE models when certain experts are not activated during calibration. " - f"Consider increasing your calibration dataset size to ensure all experts are exercised." - ) - # Use float32 dtype explicitly to ensure we create a floating point tensor - module.input_quantizer.amax = torch.tensor( - set_amax_value, dtype=torch.float32, device=module.weight.device - ) - uncalibrated_experts.append(module) + # Normalize inputs + if not isinstance(modules, list): + modules = [modules] + if quantizer_attrs is None: + quantizer_attrs = ["input_quantizer"] + elif isinstance(quantizer_attrs, str): + quantizer_attrs = [quantizer_attrs] -@contextmanager -def set_amax_for_uncalibrated_experts_context( - experts: nn.Module, set_amax_value: float | None = None -): - """Set amax for uncalibrated experts in a context manager.""" - uncalibrated_experts = set_amax_for_uncalibrated_experts(experts, set_amax_value) - yield - if uncalibrated_experts: - for module in uncalibrated_experts: - delattr(module.input_quantizer, "_amax") + uncalibrated_modules = [] + + # Determine target device if not provided + if device is None: + first_module = next(iter(modules)) + if hasattr(first_module, "weight"): + target_device = first_module.weight.device + else: + target_device = torch.device("cpu") + else: + target_device = device + + # Collect all valid quantizers + all_quantizers = [] + + for module in modules: + for attr_name in quantizer_attrs: + if hasattr(module, attr_name): + quantizer = getattr(module, attr_name) + if ( + quantizer is not None + and hasattr(quantizer, "is_enabled") + and quantizer.is_enabled + ): + all_quantizers.append((module, attr_name, quantizer)) + + target_amax = None + + # Collect ANY existing amax values from current batch (most direct source) + valid_amax_values = [] + for _, attr_name, quantizer in all_quantizers: + existing_amax = getattr(quantizer, "amax", None) + if existing_amax is not None: + # Convert to tensor and add to collection + if isinstance(existing_amax, torch.Tensor): + valid_amax_values.append(existing_amax.to(target_device)) + else: + valid_amax_values.append( + torch.tensor(existing_amax, dtype=torch.float32, device=target_device) + ) + + # Use existing values from current batch if any found + if len(valid_amax_values) > 0: + target_amax = torch.max(torch.stack(valid_amax_values)) + + # If no existing values in current batch, apply type-specific fallback logic + elif target_amax is None: + has_input_quantizers = 
any("input_quantizer" in attr for _, attr, _ in all_quantizers) + has_weight_quantizers = any("weight_quantizer" in attr for _, attr, _ in all_quantizers) + + if has_weight_quantizers and not has_input_quantizers: + # For weight quantizers: calculate from weight statistics + weight_amax_values = [] + for module, _, _ in all_quantizers: + # Try to find a weight tensor in the module + weight_tensor = None + for weight_attr in ["weight", "gate_up_proj", "down_proj"]: + if hasattr(module, weight_attr): + weight_tensor = getattr(module, weight_attr) + break + + if weight_tensor is not None: + weight_amax_values.append(torch.max(torch.abs(weight_tensor))) + + if weight_amax_values: + target_amax = torch.max(torch.stack(weight_amax_values)).item() + elif has_input_quantizers: + # For input quantizers: ideally search other experts for existing input amax values + # TODO: Implement broader expert search - currently function only has access to current batch + # For now, this will fall through to fallback value + pass + + # Final fallback + if target_amax is None: + target_amax = fallback_value + has_input_quantizers = any("input_quantizer" in attr for _, attr, _ in all_quantizers) + + # Apply target amax to quantizers that need it + for module, attr_name, quantizer in all_quantizers: + # Check if quantizer needs amax (use property for consistency) + needs_amax = getattr(quantizer, "amax", None) is None + + # Skip dynamic quantizers for input quantizers + if "input_quantizer" in attr_name and getattr(quantizer, "_dynamic", False): + needs_amax = False + + if needs_amax: + # Create tensor with appropriate value (using function-wide target_device) + if isinstance(target_amax, torch.Tensor): + amax_tensor = target_amax.clone().to(dtype=torch.float32, device=target_device) + else: + amax_tensor = torch.tensor(target_amax, dtype=torch.float32, device=target_device) + + # Set amax value using property for proper validation and tensor handling + quantizer.amax = amax_tensor + + uncalibrated_modules.append(module) + amax_val = amax_tensor.item() if isinstance(amax_tensor, torch.Tensor) else amax_tensor + + if len(valid_amax_values) > 0: + warnings.warn( + f"Missing amax value for {attr_name} in {type(module).__name__}. " + f"Setting it to {amax_val:.6f} (max from existing quantizers in current batch). " + f"This typically occurs when certain experts are not activated during calibration." + ) + elif amax_val != fallback_value and "input_quantizer" not in attr_name: + warnings.warn( + f"Missing amax value for {attr_name} in {type(module).__name__}. " + f"Setting it to {amax_val:.6f} (computed from weights)." 
+ ) + + return uncalibrated_modules def build_stacked_experts( @@ -1084,15 +1181,19 @@ def build_stacked_experts( resmooth_only=True, ) - # Set amax to 0 for uncalibrated experts - with set_amax_for_uncalibrated_experts_context( - [ - expert_getter(experts, i, module_name) - for module_name in linear_names - for i in range(num_experts) - ], - 0, # set amax to 0 for uncalibrated experts as we will calculate max across all experts later - ): + # Set amax to 0 for uncalibrated experts (calculate max across all experts later) + expert_modules = [ + expert_getter(experts, i, module_name) + for module_name in linear_names + for i in range(num_experts) + ] + uncalibrated_experts = set_expert_quantizer_amax( + modules=expert_modules, + quantizer_attrs=["input_quantizer"], + fallback_value=0, + ) + + try: # Pre-fuse W1 and W3 if len(linear_names) == 3: for i in range(num_experts): @@ -1150,10 +1251,17 @@ def build_stacked_experts( experts_weight_3.weights_scaling_factor_2, ) - # Explicitly move weight to CPU to reduce GPU memory requirement. - experts_weight_1.weight = experts_weight_1.weight.cpu() - experts_weight_2.weight = experts_weight_2.weight.cpu() - return experts_weight_1, experts_weight_2 + # Explicitly move weight to CPU to reduce GPU memory requirement. + experts_weight_1.weight = experts_weight_1.weight.cpu() + experts_weight_2.weight = experts_weight_2.weight.cpu() + return experts_weight_1, experts_weight_2 + + finally: + # Cleanup: restore original amax values (same logic as old context manager) + if uncalibrated_experts: + for module in uncalibrated_experts: + if hasattr(module.input_quantizer, "_amax"): + delattr(module.input_quantizer, "_amax") def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig: diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py old mode 100644 new mode 100755 index 6081c3c1e..af62fd526 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -49,6 +49,7 @@ "Nemotron": "gpt", "Deepseek": "deepseek", "Whisper": "whisper", + "gptoss": "gptoss", } __doc__ = f"""Utility functions for model type detection and classification. 
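The fallback order implemented in `set_expert_quantizer_amax` above can be summarized in a few lines: prefer amax values that calibration already produced on sibling experts, otherwise derive one from weight statistics (for weight quantizers), otherwise use the fixed fallback. A stand-alone sketch of just that decision, not the library code (which also handles devices, dynamic quantizers, and warnings):

```python
import torch


def pick_amax(existing_amaxes: list[torch.Tensor], weights: list[torch.Tensor], fallback: float = 0.5) -> torch.Tensor:
    if existing_amaxes:                                   # 1) best: values from calibration
        return torch.max(torch.stack(existing_amaxes))
    if weights:                                           # 2) weight quantizers: use max |w|
        return torch.max(torch.stack([w.abs().max() for w in weights]))
    return torch.tensor(fallback)                         # 3) last resort


calibrated = [torch.tensor(1.7), torch.tensor(2.3)]       # amax from experts that saw tokens
print(pick_amax(calibrated, weights=[]))                  # tensor(2.3000)
print(pick_amax([], weights=[torch.randn(4, 4)]))         # max |w| of the uncalibrated expert
print(pick_amax([], weights=[]))                          # tensor(0.5000)
```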
diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index f4eb0e736..74cf006a6 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -18,6 +18,7 @@ from typing import Any from .mcore_deepseek import deepseek_causal_lm_export, deepseek_causal_lm_import +from .mcore_gptoss import gptoss_causal_lm_export, gptoss_causal_lm_import from .mcore_llama import ( eagle3_llama_causal_lm_export, eagle_llama_causal_lm_export, @@ -50,6 +51,7 @@ "Qwen3ForCausalLM": qwen3_causal_lm_export, "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, + "GptOssForCausalLM": gptoss_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -61,4 +63,5 @@ "Qwen3ForCausalLM": qwen3_causal_lm_import, "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, + "GptOssForCausalLM": gptoss_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 369a86961..666dfa3f3 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -17,6 +17,7 @@ """Custom Megatron mapping and safetensors utility.""" import json +import math import os from dataclasses import dataclass from pathlib import Path @@ -162,6 +163,30 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] ) +class PackNameRemappingGPT(CustomModuleMapping): + """A custom module mapping that packs module after name remapping.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that packs and renames module.""" + super().__init__( + func_name="pack_name_remapping_gpt_oss", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + +class UnpackNameRemappingGPT(CustomModuleMapping): + """A custom module mapping that unpacks module after name remapping.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that unpacks module after name remapping.""" + super().__init__( + func_name="unpack_name_remapping_gpt_oss", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + def save_safetensors(state_dict, save_directory: str | os.PathLike): """Save safetensors with pipeline model parallel support.""" pp_rank = get_pipeline_model_parallel_rank() @@ -279,9 +304,10 @@ def _get_safetensor_slices( tensor_slice = f.get_slice(key) assert tensor_slice is not None shape = tensor_slice.get_shape() - assert len(shape) in (1, 2, 3), f"Shape {shape} is not supported!" + assert len(shape) in (1, 2, 3, 4), f"Shape {shape} is not supported!" # 1 for bias case - # 3 for packed MoE case + # 3 for packed MoE case (Llama4) + # 4 for packed MoE case with mxfp4 dtype (gpt-oss) # MCore tensor parallel model sharding sharding_dim = parallel_config.sharding_dim @@ -292,9 +318,11 @@ def _get_safetensor_slices( tp_size = get_expert_tensor_parallel_world_size() if tp_size > 1: raise ValueError("Packed MoE import only supports ETP=1.") - if len(shape) != 3: + if len(shape) not in (2, 3, 4): raise ValueError( - "Packed MoE import only supports 3D tensor in shape [num_experts, in_dim, out_dim]." + "For LLama4, Packed MoE import only supports 3D tensor in shape [num_experts, in_dim, out_dim]." 
+ "For gpt-oss, Packed MoE import only supports 2D tensor of MOE bias, " + "3D tensor of MOE scales, 4D tensor of MOE blocks" ) if sharding_dim != 0: raise ValueError("Packed MoE import only supports sharding_dim=0.") @@ -308,7 +336,12 @@ def _get_safetensor_slices( key, shape, sharding_dim, ep_size ) ) - tensor = tensor_slice[rank_offset : rank_offset + per_rank_size, :, :] + if len(shape) == 2: + tensor = tensor_slice[rank_offset : rank_offset + per_rank_size, :] + elif len(shape) == 3: + tensor = tensor_slice[rank_offset : rank_offset + per_rank_size, :, :] + elif len(shape) == 4: + tensor = tensor_slice[rank_offset : rank_offset + per_rank_size, :, :, :] else: if parallel_group == "TP": tp_rank = get_tensor_model_parallel_rank() @@ -334,6 +367,7 @@ def _get_safetensor_slices( tensor = tensor_slice[rank_offset : rank_offset + per_rank_size, :] elif len(shape) == 3: # For packed ETP case, Llama4 uses 3D tensor for local experts + # For gpt-oss case, gpt-oss uses 3D tensor for scales of local experts if sharding_dim == 1: tensor = tensor_slice[:, rank_offset : rank_offset + per_rank_size, :] elif sharding_dim == 2: @@ -345,6 +379,12 @@ def _get_safetensor_slices( elif len(shape) == 1: # For bias case tensor = tensor_slice[rank_offset : rank_offset + per_rank_size] + elif len(shape) == 4: + # For packed ETP case, gpt-oss uses 4D tensor for local experts + if sharding_dim == 1: + tensor = tensor_slice[:, rank_offset : rank_offset + per_rank_size, :, :] + elif sharding_dim == 2: + tensor = tensor_slice[:, :, rank_offset : rank_offset + per_rank_size, :] else: raise ValueError(f"Unsupported shape: {shape}") return tensor @@ -417,3 +457,74 @@ def get_safetensor( tensor = padded_tensor return tensor.contiguous() + + +def dequantize_mxfp4_to_bf16( + blocks: torch.Tensor, + scales: torch.Tensor, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + """Dequantize MXFP4 blocks and scales to BF16 format. 
+ + Args: + blocks: The quantized blocks tensor (U8 dtype) + scales: The scales tensor (U8 dtype) + dtype: Target dtype for dequantization (default: torch.bfloat16) + rows_per_chunk: Number of rows to process per chunk + + Returns: + Dequantized tensor in the specified dtype + """ + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" + + # Convert scales from U8 to int32 and subtract 127 (bias) + scales = scales.to(torch.int32) - 127 + + fp4_values = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + lut = torch.tensor(fp4_values, dtype=dtype, device=blocks.device) + + *prefix_shape, g, b = blocks.shape + rows_total = math.prod(prefix_shape) * g + + blocks = blocks.reshape(rows_total, b) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, b * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + return out.reshape(*prefix_shape, g, b * 2).view(*prefix_shape, g * b * 2) diff --git a/modelopt/torch/export/plugins/mcore_gptoss.py b/modelopt/torch/export/plugins/mcore_gptoss.py new file mode 100644 index 000000000..c16347fbf --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_gptoss.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
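The core of `dequantize_mxfp4_to_bf16` above is a 16-entry FP4 lookup table plus a per-block power-of-two scale (the uint8 scale minus a bias of 127). A tiny stand-alone check of that decoding for a single packed byte:

```python
import torch

# 16 FP4 code points; the low nibble indexes first, matching the function above.
FP4_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.bfloat16,
)


def decode_mxfp4_byte(block_byte: int, scale_byte: int) -> torch.Tensor:
    """Decode one packed uint8 (two FP4 values) with its biased power-of-two scale."""
    lo, hi = block_byte & 0x0F, block_byte >> 4
    values = torch.stack([FP4_LUT[lo], FP4_LUT[hi]])
    exponent = torch.tensor(scale_byte - 127)  # scales are stored with a bias of 127
    return torch.ldexp(values, exponent)       # values * 2**exponent


print(decode_mxfp4_byte(0x73, 128))  # low nibble 0x3 -> 1.5, high 0x7 -> 6.0, scaled by 2**1 -> [3., 12.]
```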
+ +"""Custom mapping from GPT-OSS Hugging Face models to Megatron Core models.""" + +from .mcore_custom import ( + COL_TP, + PACK_EP, + REPLICATE, + ROW_TP, + CustomModuleMapping, + NameRemapping, + PackNameRemappingGPT, + QKVMerging, + QKVSlicing, + UnpackNameRemappingGPT, +) + +gptoss_causal_lm_export: dict[str, CustomModuleMapping | bool] = { + "word_embeddings": NameRemapping("model.embed_tokens."), + "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), + "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), + "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks"), + "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "use_packed_local_experts": True, + "local_experts.linear_fc1": PackNameRemappingGPT( + "model.layers.{}.mlp.experts.gate_up_proj", + {"layer_type": "linear_fc1"}, + ), + "local_experts.linear_fc2": PackNameRemappingGPT( + "model.layers.{}.mlp.experts.down_proj", + {"layer_type": "linear_fc2"}, + ), + "router": NameRemapping("model.layers.{}.mlp.router."), + "final_layernorm": NameRemapping("model.norm."), + "output_layer": NameRemapping("lm_head."), +} + +gptoss_causal_lm_import: dict[str, CustomModuleMapping | bool] = { + "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP), + "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), + "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), + "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks", COL_TP), + "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "router": NameRemapping("model.layers.{}.mlp.router.", REPLICATE), + "use_packed_local_experts": True, + "local_experts.linear_fc1_ep": UnpackNameRemappingGPT( + "model.layers.{}.mlp.experts.gate_up_proj", + PACK_EP | {"layer_type": "linear_fc1"}, + ), + "local_experts.linear_fc2_ep": UnpackNameRemappingGPT( + "model.layers.{}.mlp.experts.down_proj", + PACK_EP | {"layer_type": "linear_fc2"}, + ), + "final_layernorm": NameRemapping("model.norm.", REPLICATE), + "output_layer": NameRemapping("lm_head.", COL_TP), +} diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 80c94bba1..696af6323 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -26,7 +26,12 @@ from modelopt.torch.utils import import_plugin from .mcore_common import all_mcore_hf_import_mapping -from .mcore_custom import CustomModuleMapping, ParallelConfig, get_safetensor +from .mcore_custom import ( + CustomModuleMapping, + ParallelConfig, + dequantize_mxfp4_to_bf16, + get_safetensor, +) with import_plugin("transformers", verbose=False): import transformers @@ -34,7 +39,7 @@ has_mcore = False with import_plugin("megatron"): from megatron.core.parallel_state import ( - get_expert_model_parallel_world_size, + get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size, ) from megatron.core.ssm.mamba_layer import MambaLayer @@ -108,6 +113,7 @@ def _custom_mapping_to_lambda(mapping): "qkv_merging": self._qkv_merging, "gated_mlp_merging": self._gated_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, + "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } func = method_map[mapping.func_name] prefix = 
mapping.target_name_or_prefix @@ -136,7 +142,8 @@ def _name_remapping( parallel_config: ParallelConfig | None = None, ): if isinstance(module, torch.Tensor): - module.data.copy_(self._get_safetensor(prefix)) + tensor = self._get_safetensor(prefix, parallel_config=parallel_config) + module.data.copy_(tensor) return weight = module.state_dict().get("weight", None) @@ -174,7 +181,18 @@ def _name_remapping( state_dict[key] = val else: source_key = mapping.get(key, key) - tensor = self._get_safetensor(prefix + source_key, parallel_config=parallel_config) + # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding + # since bias should always be replicated, not sharded + if ( + key == "bias" + and parallel_config is not None + and parallel_config.sharding_dim == 1 + ): + tensor = self._get_safetensor(prefix + source_key, parallel_config=None) + else: + tensor = self._get_safetensor( + prefix + source_key, parallel_config=parallel_config + ) state_dict[key] = tensor.to(dtype=self.dtype).to(device=val.device) module.load_state_dict(state_dict) @@ -373,6 +391,72 @@ def _unpack_name_remapping( linear_module.load_state_dict(state_dict) + def _unpack_name_remapping_gpt_oss( + self, + module, + prefix, + layer_type: str, + parallel_config: ParallelConfig | None = None, + ): + tensor_blocks = self._get_safetensor(prefix + "_blocks", parallel_config=parallel_config) + tensor_bias = self._get_safetensor(prefix + "_bias", parallel_config=parallel_config) + tensor_scales = self._get_safetensor(prefix + "_scales", parallel_config=parallel_config) + tensor = dequantize_mxfp4_to_bf16(tensor_blocks, tensor_scales, dtype=self.dtype) + + for idx, sub_module in enumerate(module.children()): + state_dict = {} + linear_module = getattr(sub_module, layer_type) + weight = linear_module.state_dict().get("weight", None) + sub_tensor = tensor[idx] + if weight is None: + raise ValueError(f"{linear_module!s} does not contain weight!") + # TODO (yueshen): Handle weight_scale case + else: + if layer_type == "linear_fc1": + # HF checkpoint has interleaved weights, need to de-interleave + # Pattern: [0,2,4,...,5758] -> [0,1,2,...,2879] and [1,3,5,...,5759] -> [2880,2881,...,5759] + height, width = sub_tensor.shape + half_height = height // 2 + + # Create de-interleaved tensor + deinterleaved_tensor = torch.zeros_like(sub_tensor) + deinterleaved_tensor[:half_height] = sub_tensor[ + ::2 + ] # Even indices -> first half + deinterleaved_tensor[half_height:] = sub_tensor[ + 1::2 + ] # Odd indices -> second half + sub_tensor = deinterleaved_tensor + + state_dict["weight"] = sub_tensor.to(dtype=self.dtype).to(device=weight.device) + + for key, val in linear_module.state_dict().items(): + if key in {"weight", "weight_quantizer._scale"}: + continue + elif "extra_state" in key: + state_dict[key] = val + elif "bias" in key: + sub_tensor_bias = tensor_bias[idx] + + if layer_type == "linear_fc1": + # HF checkpoint has interleaved bias, need to de-interleave + bias_len = sub_tensor_bias.shape[0] + half_bias_len = bias_len // 2 + + # Create de-interleaved bias tensor + deinterleaved_bias = torch.zeros_like(sub_tensor_bias) + deinterleaved_bias[:half_bias_len] = sub_tensor_bias[ + ::2 + ] # Even indices -> first half + deinterleaved_bias[half_bias_len:] = sub_tensor_bias[ + 1::2 + ] # Odd indices -> second half + sub_tensor_bias = deinterleaved_bias + + state_dict["bias"] = sub_tensor_bias.to(dtype=self.dtype).to(device=val.device) + + linear_module.load_state_dict(state_dict) + def _import_state_dict(self): model = 
self.model @@ -428,6 +512,10 @@ def _import_state_dict(self): self.rules["k_layernorm"](attention.k_layernorm, layer_id) self.rules["linear_qkv"](attention.linear_qkv, layer_id) self.rules["linear_proj"](attention.linear_proj, layer_id) + if hasattr(attention.core_attention, "softmax_offset"): + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id + ) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) @@ -458,20 +546,23 @@ def _import_state_dict(self): self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) # We only support either EP or ETP for now - elif get_expert_model_parallel_world_size() > 1: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( layer.mlp.experts.local_experts, layer_id ) - self.rules["local_experts.linear_fc2_ep"]( + self.rules["local_experts.linear_fc2_etp"]( layer.mlp.experts.local_experts, layer_id ) else: - # ETP supports for packed MoE - self.rules["local_experts.linear_fc1_etp"]( + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( layer.mlp.experts.local_experts, layer_id ) - self.rules["local_experts.linear_fc2_etp"]( + self.rules["local_experts.linear_fc2_ep"]( layer.mlp.experts.local_experts, layer_id ) else: diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py old mode 100644 new mode 100755 index 33ff83b1f..5a2334673 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -88,6 +88,38 @@ def get_scaling_factor_from_weight(weight, group_size) -> torch.tensor: return weights_scaling_factor +def maybe_transpose_expert_weight_dimensions( + weight: torch.Tensor, + weight_scale: torch.Tensor | None = None, + is_bmm_expert_weight: bool = True, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Transpose the last two dimensions of expert weights. + + This function transposes expert weights between the two layouts: + - (num_experts, input_dim, output_dim) ↔ (num_experts, output_dim, input_dim) + + Since transpose(-2, -1) is self-inverse, this function can be used for both + forward and backward transformations. This is needed for quantization functions + that expect the last dimension to be the input dimension for block quantization. + Specifically used for bmm-style expert weights in models like llama4 and gpt-oss. + + Args: + weight: The weight tensor to transpose. 
Expected shape for experts: (num_experts, dim1, dim2) + weight_scale: Optional weight scaling factor tensor to transpose alongside weight + is_bmm_expert_weight: Whether this is an expert weight (3D tensor) that needs transposition + + Returns: + Tuple of (transposed_weight, transposed_weight_scale) + """ + if not is_bmm_expert_weight or weight.dim() != 3: + return weight, weight_scale + + transposed_weight = weight.transpose(-2, -1) + transposed_weight_scale = weight_scale.transpose(-2, -1) if weight_scale is not None else None + + return transposed_weight, transposed_weight_scale + + def resmooth_and_get_scale( merged_weights: torch.Tensor, pre_quant_scales: list[torch.Tensor], @@ -571,6 +603,11 @@ def process_layer_quant_config(layer_config_dict): # Get layer name for constructing quantized_layers dictionary under per_layer_config prefix = ".".join(k.rsplit(".", 1)[:-1]) + awq_key = prefix + ".awq_block_size" + + # Get the corresponding AWQ block size + block_size_value = layer_config_dict.get(awq_key, 0) + if v == "fp8": layer_config = {"quant_algo": "FP8"} elif v == "fp8_pc_pt": @@ -578,14 +615,14 @@ def process_layer_quant_config(layer_config_dict): elif v == "int4_awq": layer_config = { "quant_algo": "W4A16_AWQ", - "group_size": layer_config_dict[prefix + ".awq_block_size"], + "group_size": block_size_value, "has_zero_point": False, "pre_quant_scale": True, } elif v == "w4a8_awq": layer_config = { "quant_algo": "W4A8_AWQ", - "group_size": layer_config_dict[prefix + ".awq_block_size"], + "group_size": block_size_value, "has_zero_point": False, "pre_quant_scale": True, } @@ -594,12 +631,12 @@ def process_layer_quant_config(layer_config_dict): elif v == "nvfp4": layer_config = { "quant_algo": "NVFP4", - "group_size": layer_config_dict[prefix + ".awq_block_size"], + "group_size": block_size_value, } elif v == "nvfp4_awq": layer_config = { "quant_algo": "NVFP4_AWQ", - "group_size": layer_config_dict[prefix + ".awq_block_size"], + "group_size": block_size_value, "has_zero_point": False, "pre_quant_scale": True, } @@ -613,7 +650,7 @@ def process_layer_quant_config(layer_config_dict): elif v == "w4a8_mxfp4_fp8": layer_config = { "quant_algo": "W4A8_MXFP4_FP8", - "group_size": layer_config_dict[prefix + ".awq_block_size"], + "group_size": block_size_value, } else: layer_config = {"quant_algo": v} @@ -996,9 +1033,33 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st kv_cache_format = QUANTIZATION_NONE for name, module in dict(named_modules).items(): - if hasattr(module, "input_quantizer") or hasattr(module, "weight_quantizer"): + # Check for standard quantizers or any quantizers from weight attributes + has_quantizers = ( + hasattr(module, "input_quantizer") + or hasattr(module, "weight_quantizer") + or any( + hasattr(module, quantizer_attr_names(weight_name).weight_quantizer) + or hasattr(module, quantizer_attr_names(weight_name).input_quantizer) + for weight_name in weight_attr_names(module) + ) + ) + if has_quantizers: quantization_format = get_quantization_format(module) - block_size = get_weight_block_size(module) + + # For MoE expert modules, we need to extract block size from the correct weight quantizer + # Try to get block size from each weight attribute (e.g., gate_up_proj, down_proj) + block_size = 0 + weight_names = list(weight_attr_names(module)) + + for weight_name in weight_names: + weight_block_size = get_weight_block_size(module, weight_name) + if weight_block_size > 0: + block_size = weight_block_size + break + + # Fallback to default weight 
quantizer if no specific weight quantizer found + if block_size == 0: + block_size = get_weight_block_size(module) # Construct per layer config dictionary layer_config_dict[name + ".quantization"] = quantization_format diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py old mode 100644 new mode 100755 index fbe7b16ca..b18ae2619 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -29,6 +29,7 @@ from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer +from modelopt.torch.quantization.qtensor import NVFP4QTensor from modelopt.torch.quantization.utils import quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format @@ -38,7 +39,7 @@ is_layernorm, is_moe, is_quantlinear, - set_amax_for_uncalibrated_experts, + set_expert_quantizer_amax, ) from .model_config import ( KV_CACHE_FP8, @@ -60,6 +61,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + maybe_transpose_expert_weight_dimensions, postprocess_state_dict, preprocess_linear_fusion, to_quantized_weight, @@ -290,64 +292,44 @@ def _export_quantized_weight( weight_scale: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale, None) weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None) - quantized_weight = to_quantized_weight( - weight.to(dtype), - weight_scale, - quantization_format, - weight_scale_2, - block_size, - ) - setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False)) + # Transpose weight for bmm-style expert quantization (llama4, gpt-oss) + if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]: + # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim) + # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization + is_bmm_expert_weight = weight.dim() == 3 and any( + expert_type in type(sub_module).__name__ + for expert_type in ["Llama4TextExperts", "GptOssExperts"] + ) + weight, _ = maybe_transpose_expert_weight_dimensions( + weight, is_bmm_expert_weight=is_bmm_expert_weight + ) + weight_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + block_size=block_size, + weights_scaling_factor_2=weight_scale_2, + )[0] + + quantized_weight = to_quantized_weight( + weight.to(dtype), + weight_scale, + quantization_format, + weight_scale_2, + block_size, + ) + quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions( + quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight + ) + else: + quantized_weight = to_quantized_weight( + weight.to(dtype), + weight_scale, + quantization_format, + weight_scale_2, + block_size, + ) -def _handle_llama4_experts_amax(module: nn.Module): - """Handle the amax values for the experts in the Llama4 model.""" - # Handle uncalibrated input quantizers that have None amax values - input_quantizers = [ - module.gate_up_proj_input_quantizer, - module.down_proj_input_quantizer, - ] - - # Only handle enabled input quantizers - enabled_input_quantizers = [q for q in input_quantizers if q.is_enabled] - - # Only handle amax for non-dynamic quantizers - non_dynamic_quantizers = [ - q for q in enabled_input_quantizers if not getattr(q, "_dynamic", False) - ] - - if non_dynamic_quantizers: - # Find the maximum amax value from non-None quantizers - 
valid_amax_values = [ - quantizer.amax for quantizer in non_dynamic_quantizers if quantizer.amax is not None - ] - - device = module.gate_up_proj.device - - # If all quantizers have None amax, set a default value - if not valid_amax_values: - default_amax = torch.tensor(1.0, dtype=torch.float32, device=device) - warnings.warn( - "All input quantizers have None amax values. Setting default amax to 1.0. " - "This typically occurs when experts are not activated during calibration. " - "Consider increasing your calibration dataset size to ensure all experts are exercised." - ) - for quantizer in non_dynamic_quantizers: - if quantizer.amax is None: - quantizer.amax = default_amax.clone() - else: - # Set None amax values to the maximum of existing values - max_amax = torch.max(torch.stack(valid_amax_values)) - if max_amax.device != device: - max_amax = max_amax.to(device) - for quantizer in non_dynamic_quantizers: - if quantizer.amax is None: - warnings.warn( - f"Missing amax value for input quantizer. Setting it to {max_amax.item()} for export. " - "This typically occurs when certain experts are not activated during calibration. " - "Consider increasing your calibration dataset size to ensure all experts are exercised." - ) - quantizer.amax = max_amax.clone() + setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False)) def _export_hf_checkpoint( @@ -392,12 +374,28 @@ def _export_hf_checkpoint( if hasattr(experts_mlp, linear_name): linear_modulelist = getattr(experts_mlp, linear_name) if hasattr(linear_modulelist, "__iter__"): - set_amax_for_uncalibrated_experts(list(linear_modulelist)) + set_expert_quantizer_amax( + modules=list(linear_modulelist), + quantizer_attrs=["input_quantizer"], + ) + elif "QuantGptOssExperts" in type(sub_module.experts).__name__: + # Handle GPT-OSS experts specifically + # GPT-OSS experts use gate_up_proj and down_proj + gpt_oss_linear_names = ["gate_up_proj", "down_proj"] + for linear_name in gpt_oss_linear_names: + if hasattr(sub_module.experts, linear_name): + linear_module = getattr(sub_module.experts, linear_name) + if hasattr(linear_module, "input_quantizer"): + set_expert_quantizer_amax( + modules=[linear_module], + quantizer_attrs=["input_quantizer"], + ) elif isinstance(sub_module.experts, collections.abc.Iterable): # For other MoE models (like Mixtral) with iterable experts try: - set_amax_for_uncalibrated_experts( - [getattr(expert, linear_name) for expert in sub_module.experts] + set_expert_quantizer_amax( + modules=[getattr(expert, linear_name) for expert in sub_module.experts], + quantizer_attrs=["input_quantizer"], ) except AttributeError as e: # Provide more helpful debugging information @@ -460,8 +458,22 @@ def _export_hf_checkpoint( has_quantized_layers = True if is_quantlinear(sub_module): _export_quantized_weight(sub_module, dtype) - elif "Llama4TextExperts" in type(sub_module).__name__: - _handle_llama4_experts_amax(sub_module) + elif ( + "Llama4TextExperts" in type(sub_module).__name__ + or "GptOssExperts" in type(sub_module).__name__ + ): + # TODO: consolidate uncalibrated experts handling logic + # Handle weight quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], + ) + # Handle input quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], + ) + # Export the 
quantized weights for weight_name in ["gate_up_proj", "down_proj"]: _export_quantized_weight(sub_module, dtype, weight_name) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 93ed40e3a..a90201a4f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -174,8 +174,8 @@ class GPTModelExporter: pretrained model hosted inside a model repo on huggingface.co; or a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - export_extra_modules: If True, export extra modules like medusa_heads, - eagle_module, or mtp. Otherwise, only export the base model. + export_extra_modules: If True, export extra modules like medusa_heads or + eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. """ @@ -216,6 +216,10 @@ def __init__( self.dtype = dtype self.trust_remote_code = trust_remote_code self.arch = self._hf_config.architectures[0] + # TODO: May modify this later according to what quantization exported ckpt is, currently only support BF16. + if self.arch == "GptOssForCausalLM": + if hasattr(self._hf_config, "quantization_config"): + del self._hf_config.quantization_config self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -232,7 +236,7 @@ def __init__( self.rules = self.all_rules["MedusaLlamaForCausalLM"] if mode == "eagle" and export_extra_modules: - is_eagle3 = mode_cfg["config"]["use_aux_hidden_state"] + is_eagle3 = mode_cfg["config"]["eagle_architecture_config"]["use_aux_hidden_state"] architectures = "LlamaForCausalLMEagle3" if is_eagle3 else "LlamaForCausalLMEagle" @@ -246,12 +250,18 @@ def __init__( eagle_config = { "use_input_layernorm_in_first_layer": mode_cfg["config"][ - "use_input_layernorm_in_first_layer" + "eagle_architecture_config" + ]["use_input_layernorm_in_first_layer"], + "use_last_layernorm": mode_cfg["config"]["eagle_architecture_config"][ + "use_last_layernorm" ], - "use_last_layernorm": mode_cfg["config"]["use_last_layernorm"], - "use_mtp_layernorm": mode_cfg["config"]["use_mtp_layernorm"], - "use_aux_hidden_state": mode_cfg["config"]["use_aux_hidden_state"], - "eagle_aux_hidden_state_layer_ids": model.eagle_aux_hidden_state_layer_ids, + "use_mtp_layernorm": mode_cfg["config"]["eagle_architecture_config"][ + "use_mtp_layernorm" + ], + "use_aux_hidden_state": mode_cfg["config"]["eagle_architecture_config"][ + "use_aux_hidden_state" + ], + "eagle_aux_hidden_state_layer_ids": model.eagle_config.eagle_aux_hidden_state_layer_ids, } eagle_config_update = { @@ -263,7 +273,9 @@ def __init__( "max_position_embeddings": self._hf_text_config.max_position_embeddings, "num_attention_heads": model.eagle_module.config.num_attention_heads, "num_key_value_heads": model.eagle_module.config.num_query_groups, - "num_hidden_layers": mode_cfg["config"]["eagle_num_layers"], + "num_hidden_layers": mode_cfg["config"]["eagle_architecture_config"][ + "num_hidden_layers" + ], "vocab_size": self._hf_text_config.vocab_size, # Unset any special token ids given that the tokenizer can change here. 
"bos_token_id": None, @@ -272,35 +284,13 @@ def __init__( "sep_token_id": None, # The following attributes are EAGLE specific "eagle_config": eagle_config, - } - - # [TODO] (yeyu): there is also target_hidden_size - if mode_cfg["config"]["draft_vocab_size"] > 0: - eagle_config_update["draft_vocab_size"] = mode_cfg["config"][ + "draft_vocab_size": mode_cfg["config"]["eagle_architecture_config"][ "draft_vocab_size" - ] - else: - eagle_config_update["draft_vocab_size"] = None + ], + } self._hf_extra_config.update(eagle_config_update) - if mode == "mtp" and export_extra_modules: - mtp_config = { - "hidden_size": self._hf_config.hidden_size, - "head_dim": self._hf_config.head_dim, - "intermediate_size": self._hf_config.intermediate_size, - "max_position_embeddings": self._hf_config.max_position_embeddings, - "num_attention_heads": self._hf_config.num_attention_heads, - "num_hidden_layers": mode_cfg["config"]["mtp_num_layers"], - "num_mtp_module": mode_cfg["config"]["mtp_num_module"], - "num_key_value_heads": self._hf_config.num_key_value_heads, - "rms_norm_eps": self._hf_config.rms_norm_eps, - "rope_theta": self._hf_config.rope_theta, - "use_input_layernorm_in_first_layer": True, - "use_last_layernorm": False, - } - self._hf_config.mtp = mtp_config - def save_pretrained( self, save_directory: str | os.PathLike, @@ -314,6 +304,13 @@ def save_pretrained( pp_rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() + # We use the 1st PP rank to handle VLM because vision_models + # and vision_proj only exist in the first stage. + is_first_stage_main_rank = pp_rank == 0 + # We use the last PP rank to write the config because + # medusa_heads and eagle_module only exist in the last stage. + is_last_stage_main_rank = pp_rank == pp_size - 1 + # Main export process state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict quantization_format = get_quantization_format(self.model) @@ -330,12 +327,10 @@ def save_pretrained( elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" - # TODO (chenhany): need to handle Medusa and EAGLE meatadata - if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1: + # We use the last PP rank and the 1st EP rank to write the config because + # medusa_heads and eagle_module only exist in the last stage. 
+ if is_last_stage_main_rank: if self.export_extra_modules and self._hf_extra_config is not None: - # os.makedirs(save_directory, exist_ok=True) - # with open(save_directory + "/config.json", 'w') as file: - # json.dump(self._hf_extra_config, file, indent=4) self._hf_extra_config.save_pretrained(save_directory) else: self._hf_config.save_pretrained(save_directory) @@ -365,7 +360,7 @@ def save_pretrained( except (OSError, ValueError, ImportError): pass - if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1: + if is_last_stage_main_rank: hf_quant_config = { "producer": { "name": "modelopt", @@ -381,7 +376,7 @@ def save_pretrained( json.dump(hf_quant_config, f, indent=4) if ( - torch.distributed.get_rank() == 0 + is_first_stage_main_rank and self.is_multimodal and pretrained_model_name_or_path is not None ): @@ -453,10 +448,11 @@ def save_pretrained( torch.distributed.barrier() if self.export_extra_modules: - if pp_rank == pp_size - 1: + if is_last_stage_main_rank: save_file( state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"} ) + torch.distributed.barrier() return save_safetensors(state_dict, save_directory) @@ -473,7 +469,6 @@ def extra_state_dict(self): if len(self._state_dict) == 0: self._get_medusa_heads_state_dict() self._get_eagle_module_state_dict() - self._get_mtp_state_dict() return self._state_dict def _populate_rule_book(self): @@ -485,6 +480,7 @@ def _custom_mapping_to_lambda(mapping): "qkv_slicing": self._qkv_slicing, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, + "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, } func = method_map[mapping.func_name] prefix = mapping.target_name_or_prefix @@ -799,6 +795,120 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): if merged_input_scale is not None: self._state_dict[prefix + "_input_scale"] = merged_input_scale + def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): + """Pack name remapping into one tensor.""" + weight_list = [] + weight_scale_list = [] + weight_scale_2_list = [] + input_scale_list = [] + bias_list = [] + + for expert in module: + assert layer_type is not None, "layer_type is required for pack_name_remapping" + name_to_value, qformat, block_size = get_quantized_state( + getattr(expert, layer_type), self.dtype + ) + weight = name_to_value.pop("weight") + bias = name_to_value.pop("bias", None) + weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) + input_scale = ( + name_to_value.pop("input_scale") if "input_scale" in name_to_value else None + ) + + weight_list.append(weight) + weight_scale_list.append(weight_scale) + weight_scale_2_list.append(weight_scale_2) + input_scale_list.append(input_scale) + bias_list.append(bias) + + merged_weight = torch.stack(weight_list, dim=0) + + # Transpose the last two dimensions to match HuggingFace format (except for GptOssForCausalLM) + # NeMo format: [num_experts, out_features, in_features] + # HF format: [num_experts, in_features, out_features] + + # TODO: Need to decide if we want to transpose the weight or not. 
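# Editor's note (illustrative sketch, not part of the patch; helper names are made up):
# the export path below stacks per-expert linear_fc1 weights, transposes them, and then
# interleaves the gate/up halves along the output dimension (Megatron keeps the two halves
# stacked, the HF GPT-OSS checkpoint keeps them interleaved); the importer earlier in this
# diff undoes it with the matching even/odd de-interleave. A toy round trip on the last
# dimension, assuming an even output dimension:
import torch

def interleave_last_dim(w: torch.Tensor) -> torch.Tensor:
    half = w.shape[-1] // 2
    out = torch.empty_like(w)
    out[..., ::2] = w[..., :half]   # first half -> even indices
    out[..., 1::2] = w[..., half:]  # second half -> odd indices
    return out

def deinterleave_last_dim(w: torch.Tensor) -> torch.Tensor:
    half = w.shape[-1] // 2
    out = torch.empty_like(w)
    out[..., :half] = w[..., ::2]   # even indices -> first half
    out[..., half:] = w[..., 1::2]  # odd indices -> second half
    return out

w = torch.randn(2, 4, 8)  # (num_experts, in_features, out_features) after the transpose
assert torch.equal(deinterleave_last_dim(interleave_last_dim(w)), w)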
+ merged_weight = merged_weight.transpose(-2, -1).contiguous() + + # Apply interleaving for GptOssForCausalLM linear_fc1 to match HF format + if layer_type == "linear_fc1": + # Megatron has de-interleaved format, need to interleave for HF + # Pattern: first half -> even indices, second half -> odd indices + num_experts, in_features, out_features = merged_weight.shape + half_out = out_features // 2 + + # Create interleaved tensor + interleaved_weight = torch.zeros_like(merged_weight) + interleaved_weight[:, :, ::2] = merged_weight[ + :, :, :half_out + ] # First half -> even indices + interleaved_weight[:, :, 1::2] = merged_weight[ + :, :, half_out: + ] # Second half -> odd indices + merged_weight = interleaved_weight + + # Handle bias tensors + merged_bias = None + if bias_list[0] is not None: + merged_bias = torch.stack(bias_list, dim=0) + + # Apply interleaving for GptOssForCausalLM linear_fc1 bias to match HF format + if layer_type == "linear_fc1": + num_experts, bias_len = merged_bias.shape + half_bias_len = bias_len // 2 + + # Create interleaved bias tensor + interleaved_bias = torch.zeros_like(merged_bias) + interleaved_bias[:, ::2] = merged_bias[ + :, :half_bias_len + ] # First half -> even indices + interleaved_bias[:, 1::2] = merged_bias[ + :, half_bias_len: + ] # Second half -> odd indices + merged_bias = interleaved_bias + + if weight_scale_2_list[0] is None: + merged_weight_scale_2 = None + if weight_scale_list[0] is not None: + merged_weight_scale = torch.max(torch.stack(weight_scale_list, dim=0), dim=0)[0] + else: + merged_weight_scale = None + else: + # NVFP4 + merged_weight_scale_2 = torch.max(torch.stack(weight_scale_2_list, dim=0), dim=0)[0] + merged_weight_scale = torch.stack(weight_scale_list, dim=0) + # Transpose the scaling factors to match the transposed weights + # TODO: Need to decide if we want to transpose the weight or not. + merged_weight_scale = merged_weight_scale.transpose(-2, -1).contiguous() + + if input_scale_list[0] is not None: + merged_input_scale = torch.max(torch.stack(input_scale_list, dim=0), dim=0)[0] + else: + merged_input_scale = None + + # Save the merged weights + if merged_weight_scale is None: + # TODO: May need to modify the key name later. + self._state_dict[prefix] = merged_weight + else: + self._state_dict[prefix] = to_quantized_weight( + merged_weight, + merged_weight_scale, + qformat, + merged_weight_scale_2, + block_size, + ) + self._state_dict[prefix + "_weight_scale"] = merged_weight_scale + if merged_weight_scale_2 is not None: + self._state_dict[prefix + "_weight_scale_2"] = merged_weight_scale_2 + if merged_input_scale is not None: + self._state_dict[prefix + "_input_scale"] = merged_input_scale + + # Save bias tensors if they exist + if merged_bias is not None: + # TODO: May need to modify the key name later. 
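# Editor's note (illustrative sketch, not part of the patch; helper names and the E2M1
# value table are spelled out here as assumptions, since the diff only references them via
# get_e2m1_values): the FP4 paths later in this diff (`_unpack_tensor`, the Triton
# `fp4_dequantize_kernel`) share one packing convention: a uint8 holds two E2M1 codes,
# low nibble = even element, high nibble = odd element. A minimal pack/unpack round trip:
import torch

E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    """Pack an even-length last dim of 4-bit codes (0..15) into uint8, two codes per byte."""
    return (codes[..., 0::2] | (codes[..., 1::2] << 4)).to(torch.uint8)

def unpack_fp4_codes(packed: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    unpacked = torch.empty((*packed.shape[:-1], packed.shape[-1] * 2), dtype=torch.uint8)
    unpacked[..., 0::2] = packed & 0x0F  # low nibble -> even positions
    unpacked[..., 1::2] = packed >> 4    # high nibble -> odd positions
    return E2M1_VALUES[unpacked.long()].to(dtype)

codes = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_fp4_codes(pack_fp4_codes(codes)), E2M1_VALUES[codes.long()])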
+ self._state_dict[prefix + "_bias"] = merged_bias + def _get_medusa_heads_state_dict(self): medusa_heads = getattr(self.model, "medusa_heads", None) if medusa_heads is None: @@ -819,23 +929,24 @@ def _get_eagle_module_state_dict(self): # self.rules["word_embeddings"](self.model.embedding.word_embeddings) self.rules["fc"](eagle_module.fc) - if self.model.use_aux_hidden_state: + if self.model.eagle_config.use_aux_hidden_state: self.rules["enorm"](eagle_module.enorm) - elif self.model.use_mtp_layernorm: + elif self.model.eagle_config.use_mtp_layernorm: self.rules["enorm"](eagle_module.enorm) self.rules["hnorm"](eagle_module.hnorm) - if self.model.use_last_layernorm: + if self.model.eagle_config.use_last_layernorm: self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) - if self.model.draft_vocab_size > 0: + if hasattr(self.model.eagle_module, "eagle_output_layer"): self.rules["output_layer"](eagle_module.eagle_output_layer) + if hasattr(self.model.eagle_module, "dt2"): self.rules["d2t"](eagle_module.d2t) for layer in eagle_module.decoder.layers: layer_id = layer.layer_number - 1 - if layer_id > 0 or self.model.use_input_layernorm_in_first_layer: + if layer_id > 0 or self.model.eagle_config.use_input_layernorm_in_first_layer: self.rules["input_layernorm"](layer.input_layernorm, layer_id) if "MLASelfAttention" in str(type(layer.self_attention)): @@ -889,72 +1000,6 @@ def _get_eagle_module_state_dict(self): self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - def _get_mtp_state_dict(self): - mtp = getattr(self.model, "mtp", None) - if mtp is None: - return - - for module_id, module in enumerate(mtp): - self.rules["mtp.fc"](module.fc, module_id) - self.rules["mtp.enorm"](module.enorm, module_id) - self.rules["mtp.hnorm"](module.hnorm, module_id) - for layer in module.decoder.layers: - layer_id = layer.layer_number - 1 - self.rules["mtp.input_layernorm"](layer.input_layernorm, module_id, layer_id) - - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - layer.self_attention.linear_q_proj, module_id, layer_id - ) - else: - self.rules["mtp.linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, module_id, layer_id - ) - self.rules["mtp.linear_q_layernorm"]( - layer.self_attention.q_layernorm, module_id, layer_id - ) - self.rules["mtp.linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, module_id, layer_id - ) - - self.rules["mtp.linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, module_id, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, module_id, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, module_id, layer_id - ) - else: - self.rules["mtp.linear_qkv"]( - layer.self_attention.linear_qkv, module_id, layer_id - ) - - self.rules["mtp.linear_proj"](layer.self_attention.linear_proj, module_id, layer_id) - self.rules["mtp.pre_mlp_layernorm"](layer.pre_mlp_layernorm, module_id, layer_id) - - if "MoE" in str(type(layer.mlp)): - self.rules["mtp.router"](layer.mlp.router, module_id, layer_id) - if hasattr(layer.mlp, "shared_experts"): - self.rules["mtp.shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, module_id, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, module_id, layer_id - ) - for expert_id, expert in 
enumerate(layer.mlp.experts.local_experts): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, module_id, layer_id, expert_id - ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, module_id, layer_id, expert_id - ) - else: - self.rules["mtp.linear_fc1"](layer.mlp.linear_fc1, module_id, layer_id) - self.rules["mtp.linear_fc2"](layer.mlp.linear_fc2, module_id, layer_id) - def _get_state_dict(self): model = self.model @@ -1029,6 +1074,10 @@ def _get_state_dict(self): self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + if hasattr(layer.self_attention.core_attention, "softmax_offset"): + self.rules["softmax_offset"]( + layer.self_attention.core_attention.softmax_offset, layer_id + ) if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) @@ -1084,8 +1133,8 @@ def export_mcore_gpt_to_hf( pretrained model hosted inside a model repo on huggingface.co; or a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - export_extra_modules: If True, export extra modules like medusa_heads, - eagle_module, or mtp. Otherwise, only export the base model. + export_extra_modules: If True, export extra modules like medusa_heads or + eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. export_dir: The target export path. """ diff --git a/modelopt/torch/opt/plugins/megatron_model_config.py b/modelopt/torch/opt/plugins/megatron_model_config.py deleted file mode 100644 index c54202915..000000000 --- a/modelopt/torch/opt/plugins/megatron_model_config.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron-Core model config (TransformerConfig+).""" - -from collections.abc import Callable -from dataclasses import dataclass - -import torch.nn.functional as F -from megatron.core.transformer.transformer_config import TransformerConfig - - -@dataclass -class Llama31Config8B(TransformerConfig): - """Configuration class for GPT models. - - Extends TransformerConfig with additional parameters specific to GPT models - and provides utility methods for model configuration. 
- """ - - # From megatron.core.models.gpt.gpt_model.GPTModel - transformer_layer_spec = None - vocab_size: int = None - max_sequence_length: int = 8192 - position_embedding_type = "rope" - rotary_percent: float = 1.0 - rotary_base: int = 500000 - rope_scaling: bool = True - rope_scaling_factor: float = 8.0 - - # Specific TransformerConfig - seq_length: int = 8192 - num_layers: int = 32 - hidden_size: int = 4096 - ffn_hidden_size: int = 14336 - kv_channels: int = 128 - num_attention_heads: int = 32 - num_query_groups: int = 8 - init_method_std: float = 0.01 - normalization: str = "RMSNorm" - layernorm_epsilon: float = 1.0e-05 - activation_func: Callable = F.silu - gated_linear_unit: bool = True - add_bias_linear: bool = False - attention_dropout: float = 0.0 - hidden_dropout: float = 0.0 - - # Different from the default values in TransformerConfig - attention_softmax_in_fp32: bool = False - gradient_accumulation_fusion: bool = False diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index 717e3c5d0..dbff2b67d 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -24,6 +24,7 @@ from modelopt.torch.quantization.config import FP8_DEFAULT_CFG from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import FP8QTensor, QTensorWrapper +from modelopt.torch.quantization.utils import reduce_amax from .utils import fp8_compatible @@ -31,44 +32,67 @@ FP8_MAX = torch.finfo(torch.float8_e4m3fn).max -def _to_fp8(x, scale): - return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) - - def fp8_per_tensor_gemm(quant_module, input, bias=None): """GEMM function for fp8 per tensor quantization.""" - weight_amax = quant_module.weight_quantizer.amax - if weight_amax is None: - weight_amax = quant_module.weight.abs().amax() + + @torch.compile(dynamic=True) + def _to_fp8(x, scale): + return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + + @torch.compile(dynamic=True) + def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None): + input_shape = input.shape + input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1]) + weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t() + output = torch._scaled_mm( + input_fp8, + weight_fp8_t, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + out_dtype=input.dtype, + use_fast_accum=True, + ) + return output.reshape(*input_shape[:-1], output.shape[-1]) + + cached_scale_a = hasattr(quant_module, "_scale_a") input_amax = quant_module.input_quantizer.amax if input_amax is None: - input_amax = input.abs().amax() + cached_scale_a = False + input_amax = reduce_amax(input) + assert input_amax != 0 + + if not cached_scale_a: + quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device) + + cached_scale_b = ( + hasattr(quant_module, "_scale_b") and quant_module.weight.dtype == torch.float8_e4m3fn + ) + weight_amax = quant_module.weight_quantizer.amax + if weight_amax is None: + cached_scale_b = False + weight_amax = reduce_amax(quant_module.weight) + assert weight_amax != 0 + + if not cached_scale_b: + quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device) + if quant_module.weight.dtype != torch.float8_e4m3fn: - weight_fp8 = _to_fp8(quant_module.weight, weight_amax / 448.0) + weight_fp8 = _to_fp8(quant_module.weight, quant_module._scale_b) else: - 
weight_fp8 = quant_module.weight - # If input_amax is 0, it means the input is all zeros. So we set it to a small value to avoid division by zero. - if input_amax == 0: - input_amax = torch.tensor(1e-5, device=input_amax.device, dtype=input_amax.dtype) - weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t() - input_fp8 = _to_fp8(input, input_amax / 448.0) - - scale_a = (input_amax / 448.0).to(device=input_fp8.device, dtype=torch.float32) - scale_b = (weight_amax / 448.0).to(device=input_fp8.device, dtype=torch.float32) - - ouptut = torch._scaled_mm( - input_fp8.reshape(-1, input_fp8.shape[-1]), - weight_fp8_t, - scale_a=scale_a, - scale_b=scale_b, + weight_fp8 = quant_module.weight.data + + output = _fp8_gemm_impl( + input, + weight_fp8, + scale_a=quant_module._scale_a, + scale_b=quant_module._scale_b, bias=bias if input.dtype != torch.float32 else None, - out_dtype=input.dtype, - use_fast_accum=False, ) # _scaled_mm does not support bias for float32 input, so we add it manually if input.dtype == torch.float32 and bias is not None: - ouptut += bias - return ouptut.reshape(*input.shape[:-1], ouptut.shape[-1]) + output += bias + return output.reshape(*input.shape[:-1], output.shape[-1]) def _fp8_availability_check(module, input, args, kwargs): @@ -114,7 +138,13 @@ class Fp8PerTensorLinear(Function): @staticmethod def forward( - ctx, quant_module, input_tensor, weight, bias=None, allreduce_dgrad=False, tp_group=None + ctx, + quant_module, + input_tensor, + weight, + bias=None, + allreduce_dgrad=False, + tp_group=None, ): """Forward method.""" ctx.save_for_backward( diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index 1381e7676..ae5fab291 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -34,7 +34,7 @@ def _to_nvfp4(inputs, amax=None): vec_size = 16 if amax is None: amax = inputs.abs().amax() - global_scale = 448.0 * 6.0 / (amax.float()) + global_scale = 448.0 * 6.0 / amax.float() fp4, scale = torch.ops.trtllm.fp4_quantize(inputs, global_scale, vec_size, False) return fp4, scale, global_scale @@ -65,13 +65,14 @@ def nvfp4_gemm(quant_module, input_tensor, bias=None): if input_amax is None: input_amax = input_tensor.abs().amax() input_global_scale = 448.0 * 6.0 / input_amax.float() + alpha = 1.0 / (weight_global_scale * input_global_scale) output = torch.ops.auto_deploy.torch_quant_fp4_linear( input_tensor, weight_fp4, bias=bias, input_scale=input_global_scale, weight_scale=weight_scale, - alpha=1.0 / weight_global_scale / input_global_scale, + alpha=alpha, ) return output diff --git a/modelopt/torch/quantization/nn/modules/quant_linear.py b/modelopt/torch/quantization/nn/modules/quant_linear.py index 16259450d..23795ad0c 100644 --- a/modelopt/torch/quantization/nn/modules/quant_linear.py +++ b/modelopt/torch/quantization/nn/modules/quant_linear.py @@ -173,9 +173,10 @@ def forward(self, input, *args, **kwargs): # Note: We cache the real-quant GEMM function to avoid matching overhead. # This assumes that the function will not change after the first call. 
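# Editor's note (illustrative sketch, not part of the patch; fp8_emulated_linear is a made-up
# helper): the per-tensor FP8 path above reduces to scale = amax / 448 (448 is the E4M3 max),
# a one-time cast of the weight to float8_e4m3fn, and cached _scale_a / _scale_b reused by
# torch._scaled_mm. The emulation below only illustrates that scale convention; it is not
# numerically identical to a fused FP8 GEMM.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def fp8_emulated_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    s_a = x.abs().amax().float() / FP8_MAX
    s_b = w.abs().amax().float() / FP8_MAX
    x8 = (x / s_a).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    w8 = (w / s_b).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    # Dequantize and matmul in higher precision; torch._scaled_mm fuses this on FP8-capable GPUs.
    return (x8.to(x.dtype) * s_a) @ (w8.to(w.dtype) * s_b).t()

x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)  # (out_features, in_features)
print(fp8_emulated_linear(x, w).shape)  # torch.Size([4, 32])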
if self._real_quant_gemm_impl: - output = self._real_quant_gemm_impl( - self, input, self.weight, self.bias, *args, **kwargs - ) + with torch.cuda.nvtx.range("RealQuantLinear gemm"): + output = self._real_quant_gemm_impl( + self, input, self.weight, self.bias, *args, **kwargs + ) return ( self.output_quantizer(output) if hasattr(self, "output_quantizer") else output ) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ffbdb067b..560e7e4bc 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -630,7 +630,7 @@ def _real_quantize(self, inputs): outputs, _weights_scaling_factor, _weights_scaling_factor_2 = NVFP4QTensor.quantize( inputs, self._block_sizes[-1], - weights_scaling_factor_2=self.amax.float() / 448.0 / 6.0 + weights_scaling_factor_2=self.amax.float() / (448.0 * 6.0) if self.amax is not None else None, try_tensorrt=True, @@ -698,6 +698,8 @@ def _fake_quantize(self, inputs): self._narrow_range, self._trt_high_precision_dtype, self._pass_through_bwd, + self.block_sizes.get(-1) if self.block_sizes else None, + self.axis[0] if isinstance(self.axis, tuple) else self.axis, ) return outputs diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 4ea9d6d31..1df4c3415 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -324,6 +324,27 @@ def convert(cls, module: nn.Module) -> "_QuantConv1D": return dyn_cls.convert(module) +class _TransposedQuantization(torch.autograd.Function): + """Applies transposed quantization. + + This is useful for weight quantization of some MoEs such as gpt-oss or Llama4 which has expert weights + of shape (num_experts, in_dim, out_dim). Per-channel/Per-block quantization from ModelOpt + assumes that `in_dim` is -1 dim. Hence for quantizing such MoE weights, lets use transposed quantization. 
+ """ + + # Note: TransposedQuantization uses STE with no clipping + @staticmethod + def forward(ctx, inputs, quantizer): + return quantizer(inputs.transpose(-1, -2).contiguous()).transpose(-1, -2) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +_transposed_quantize = _TransposedQuantization.apply + + class _QuantMoeSparseMoe(QuantModule): def _setup(self): pass @@ -351,12 +372,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm( self.gate_up_proj_input_quantizer(hidden_states), - self.gate_up_proj_weight_quantizer(self.gate_up_proj), + _transposed_quantize(self.gate_up_proj, self.gate_up_proj_weight_quantizer), ) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors next_states = torch.bmm( self.down_proj_input_quantizer(up * self.act_fn(gate)), - self.down_proj_weight_quantizer(self.down_proj), + _transposed_quantize(self.down_proj, self.down_proj_weight_quantizer), ) next_states = next_states.view(-1, self.hidden_size) return next_states @@ -528,17 +549,15 @@ def top_k(self, value): except ImportError: pass +try: + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock -# TODO: Use autograd wrapped matmul instead of this for memory efficiency -class _TransposedQuantization(torch.autograd.Function): - # Note: TransposedQuantization uses STE with no clipping - @staticmethod - def forward(ctx, inputs, quantizer): - return quantizer(inputs.transpose(-1, -2).contiguous()).transpose(-1, -2) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None + if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( + _QuantMoeSparseMoe + ) +except ImportError: + pass class _QuantGptOssExperts(_QuantFunctionalMixin): @@ -557,7 +576,7 @@ def _get_quantized_weight(quantizer, module, weight): if module._enable_weight_quantization: if hasattr(quantizer, "_cached_quant_val"): return getattr(quantizer, "_cached_quant_val") - quantizer._cached_quant_val = _TransposedQuantization.apply(weight, quantizer) + quantizer._cached_quant_val = _transposed_quantize(weight, quantizer) return quantizer._cached_quant_val return weight diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 7f503d6a7..589f30b91 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,8 +22,10 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +from megatron.core.utils import get_tensor_model_parallel_group_if_none from modelopt.torch.opt.plugins.megatron import ( _MegatronMLP, @@ -421,6 +423,17 @@ class forward(). 
This is not desired since _forward_impl introduces much more ar ): allreduce_dgrad = kwargs.get("allreduce_dgrad", False) tp_group = kwargs.get("tp_group") + sequence_parallel = kwargs.get("sequence_parallel", False) + + tp_group = get_tensor_model_parallel_group_if_none(tp_group) + + if sequence_parallel: + input = gather_from_sequence_parallel_region( + input, tensor_parallel_output_grad=True, group=tp_group + ) + else: + input = input + return RealQuantLinear.forward( self, input, diff --git a/modelopt/torch/quantization/plugins/peft.py b/modelopt/torch/quantization/plugins/peft.py index 15aae8d8f..a5a696d24 100644 --- a/modelopt/torch/quantization/plugins/peft.py +++ b/modelopt/torch/quantization/plugins/peft.py @@ -25,7 +25,7 @@ from modelopt.torch.quantization.qtensor.base_qtensor import QTensorWrapper from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer -from .huggingface import _TransposedQuantization +from .huggingface import _transposed_quantize __all__ = [] @@ -137,7 +137,7 @@ def _activate_lora(self, active_adapters: list[str]): quantizer = getattr(base_layer, self.parameter_name + "_weight_quantizer") with base_layer.reset_dynamic_attributes(): base_param = getattr(base_layer, self.parameter_name) - quantized_val = _TransposedQuantization.apply( + quantized_val = _transposed_quantize( base_param if delta_weight is None else base_param + delta_weight, quantizer ) delattr(base_layer, self.parameter_name) diff --git a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py b/modelopt/torch/quantization/qtensor/mxfp4_tensor.py index e1298e3c6..022825e40 100644 --- a/modelopt/torch/quantization/qtensor/mxfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/mxfp4_tensor.py @@ -46,10 +46,9 @@ def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple: def cast_fp4(x): sign = torch.sign(x) sign_bit = (2 - sign) // 2 - # TODO: Optimize this, currently its on cpu and is slower ord_ = torch.sum( - (x.abs().unsqueeze(-1).cpu() - MXFP4QTensor.E2M1_bounds) > 0, dim=-1 - ).to(x.device) + (x.abs().unsqueeze(-1) - MXFP4QTensor.E2M1_bounds.to(x.device)) > 0, dim=-1 + ) fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8) return fp4_val diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index b67f90950..29131ff7a 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -15,12 +15,11 @@ """Implements NVFP4 quantization for efficient tensor storage and computation.""" -import numpy as np import torch from ..backends.utils import fp4_compatible from ..qtensor.base_qtensor import BaseQuantizedTensor -from ..utils import reduce_block_padding +from ..utils import reduce_amax, reduce_block_amax, reduce_block_padding # Define conversion tables e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) @@ -37,6 +36,7 @@ class NVFP4QTensor(BaseQuantizedTensor): """ e2m1_values_on_device = {} + e2m1_bounds_on_device = {} @classmethod def get_e2m1_values(cls, device): @@ -45,12 +45,19 @@ def get_e2m1_values(cls, device): cls.e2m1_values_on_device[device] = e2m1_values.to(device) return cls.e2m1_values_on_device[device] + @classmethod + def get_e2m1_bounds(cls, device): + """Returns the e2m1 values on the device.""" + if device not in cls.e2m1_bounds_on_device: + cls.e2m1_bounds_on_device[device] = e2m1_bounds.to(device) + return cls.e2m1_bounds_on_device[device] + @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, 
weight_quantizer): """Returns per tensor weight scaling factor from the weight_quantizer amax.""" # Assert that weight_quantizer has attribute amax assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax" - return weight_quantizer._amax.float() / 6.0 / 448.0 + return weight_quantizer._amax.float() / (6.0 * 448.0) @classmethod def get_weights_scaling_factor( @@ -65,31 +72,27 @@ def get_weights_scaling_factor( weights_scaling_factor_2 = cls.get_weights_scaling_factor_2(input) # Get per_block amax - [n, k] = input.shape[-2:] assert block_size != 0, "Block size is zero. Cannot return per_block amax for given input." - assert k % block_size == 0, ( + assert input.shape[-1] % block_size == 0, ( "Weight shape is not divisible for block size for block quantiation." ) - input = input.reshape((*tuple(input.shape[:-2]), n, k // block_size, block_size)) # Get per block amax - per_block_amax = input.abs().amax(dim=-1).float() + per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size}).float() # Get per-block-scale - per_block_scale = per_block_amax / 6.0 - # Quantize per_block_scale to FP8 - q_per_block_scale = per_block_scale / weights_scaling_factor_2 + per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2) # Set all zero values in scale to 1.0 - q_per_block_scale[per_block_scale == 0] = 1.0 + per_block_scale[per_block_scale == 0] = 1.0 # Convert to torch.float8_e4m3fn if not keep_high_precision: - q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) - return q_per_block_scale, weights_scaling_factor_2 + per_block_scale = per_block_scale.to(torch.float8_e4m3fn) + return per_block_scale, weights_scaling_factor_2 @classmethod def get_weights_scaling_factor_2(cls, input: torch.Tensor): """Returns per tensor weight scaling factor.""" - return input.abs().amax().float() / 6.0 / 448.0 + return reduce_amax(input).float() / (6.0 * 448.0) @classmethod def get_activation_scaling_factor(cls, quantizer): @@ -103,8 +106,7 @@ def get_activation_scaling_factor(cls, quantizer): if amax is None: return None - activation_scaling_factor = amax.float() / (quantizer.maxbound) - activation_scaling_factor = activation_scaling_factor / 448.0 + activation_scaling_factor = amax.float() / (quantizer.maxbound * 448.0) assert torch.all(activation_scaling_factor > 0), ( f" activation scaling factor {activation_scaling_factor} not positive." 
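# Editor's note (illustrative sketch, not part of the patch; nvfp4_scales is a made-up helper
# and requires a torch build with float8_e4m3fn): the two-level NVFP4 scaling computed above,
# a per-tensor FP32 scale of amax / (6 * 448) plus FP8 per-block scales, written out for a
# single 2-D weight with block size 16 along the last dimension:
import torch

def nvfp4_scales(w: torch.Tensor, block_size: int = 16) -> tuple[torch.Tensor, torch.Tensor]:
    scale_2 = w.abs().amax().float() / (6.0 * 448.0)        # per-tensor scale (FP32)
    blocks = w.reshape(*w.shape[:-1], -1, block_size)
    block_amax = blocks.abs().amax(dim=-1).float()
    per_block = block_amax / (6.0 * scale_2)                # per-block scale, stored as FP8
    per_block[block_amax == 0] = 1.0                        # avoid zero scales
    return per_block.to(torch.float8_e4m3fn), scale_2

w = torch.randn(128, 64, dtype=torch.bfloat16)
per_block_scale, scale_2 = nvfp4_scales(w)
print(per_block_scale.shape, scale_2)  # torch.Size([128, 4]) and a 0-dim FP32 tensor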
@@ -112,26 +114,28 @@ def get_activation_scaling_factor(cls, quantizer): return activation_scaling_factor - @staticmethod - def _cast_fp4(weight: torch.Tensor): + @classmethod + def _cast_fp4(cls, weight: torch.Tensor): """Converts tensor to uint4.""" - # Get device device = weight.device - # Define mask to perform rounding - mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device) - mask_shape = list(weight.shape) - mask = mask.expand([*mask_shape, 7]) - + # Extract sign and compute absolute values in one pass sign_bit = (weight < 0).to(torch.uint8) - weight_abs = weight.abs_() - # Calculate the ordinal value based on the bounds - ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8) - # All values equal to e2m1_bounds at odd indices are rounded up and even indices are rounded down - round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1) - fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8) - return fp4_val + + # Get bounds and compute ordinal values + e2m1_bounds = cls.get_e2m1_bounds(device) + ord = torch.searchsorted(e2m1_bounds, weight_abs, out_int32=True).to(torch.uint8) + + # Efficiently check for rounding at odd-indexed bounds [0.75, 1.75, 2.5] + # Only need to check bounds at indices 1, 3, 5 + odd_bounds = e2m1_bounds[[1, 3, 5]] # [0.75, 1.75, 2.5] + equals_odd_bounds = torch.any(weight_abs.unsqueeze(-1) == odd_bounds, dim=-1).to( + torch.uint8 + ) + + # Combine sign, ordinal, and rounding adjustment + return (sign_bit << 3) + ord + equals_odd_bounds @classmethod def quantize( @@ -171,6 +175,8 @@ def quantize( and weights_scaling_factor is None and try_tensorrt and block_size == 16 + and input.is_cuda + and input.dtype in [torch.half, torch.bfloat16] ): try: import tensorrt_llm # noqa: F401 @@ -200,6 +206,7 @@ def quantize( ) # Reshape the weight and scale factors + original_shape = input.shape input = input.view((*tuple(input.shape[:-1]), -1, block_size)) # Scale weights @@ -208,7 +215,7 @@ def quantize( ) # Reshape weights to original - scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1)) + scaled_weight = scaled_weight.view(original_shape) if keep_high_precision: return scaled_weight @@ -222,17 +229,16 @@ def quantize( weights_scaling_factor_2, ) - def dequantize(self, dtype: torch.dtype = None, **kwarg): + def dequantize(self, dtype: torch.dtype = None, fast=False, **kwarg): """Dequantze NVFP4 packed tensor to a target dtype.""" if dtype is None: dtype = self.metadata["dtype"] def _unpack_tensor(input: torch.Tensor): # Initalize storage for unpacked tensor - unpacked = torch.empty( - [input.shape[0], input.shape[1] * 2], dtype=dtype, device=input.device - ) - unpacked_shape = unpacked.shape + unpacked_shape = list(input.shape) + unpacked_shape[-1] = unpacked_shape[-1] * 2 + unpacked = torch.empty(unpacked_shape, dtype=dtype, device=input.device) unpacked[..., 1::2] = input >> 4 unpacked[..., 0::2] = input & 0x0F @@ -257,26 +263,33 @@ def _unpack_tensor(input: torch.Tensor): raise ImportError( "This tensor is quantized by trtllm, but tensorrt_llm cannot be imported." 
) from e - q_per_block_scale = ( - kwarg["scale"].to(torch.float32) - if kwarg["scale"].dtype == torch.float8_e4m3fn - else kwarg["scale"] - ) - block_sizes = kwarg["block_sizes"][-1] - per_block_quant_scale = kwarg["double_scale"] - # Dequantize scales - per_block_scale = q_per_block_scale * per_block_quant_scale + if fast: + from ..triton.fp4_kernel import fp4_dequantize + + return fp4_dequantize( + self._quantized_data, + kwarg["scale"], + kwarg["double_scale"], + block_size=kwarg["block_sizes"][-1], + dtype=dtype, + ).reshape(self.metadata["shape"]) + else: + q_per_block_scale = ( + kwarg["scale"].to(torch.float32) + if kwarg["scale"].dtype == torch.float8_e4m3fn + else kwarg["scale"] + ) + block_size = kwarg["block_sizes"][-1] + per_block_quant_scale = kwarg["double_scale"] - # Unpack and unscale weights - deq_data = _unpack_tensor(self._quantized_data) + # Dequantize scales + per_block_scale = q_per_block_scale * per_block_quant_scale - deq_data = deq_data.view( - deq_data.shape[0], deq_data.shape[1] // block_sizes, -1 - ) * per_block_scale.unsqueeze(-1) + # Unpack and unscale weights + deq_data = _unpack_tensor(self._quantized_data) - return ( - deq_data.view(-1)[: np.prod(self.metadata["shape"])] - .reshape(self.metadata["shape"]) - .to(dtype) - ) + deq_data = deq_data.view( + (*tuple(deq_data.shape[:-1]), -1, block_size) + ) * per_block_scale.unsqueeze(-1) + return deq_data.reshape(self.metadata["shape"]).to(dtype) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 36b6cc0f3..bd4cf644c 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -317,14 +317,19 @@ def _fake_quant_backward_function(ctx, grad_outputs, num_args=1): def _save_for_backward_if_needed(ctx, pass_through_bwd, inputs, amax): if not pass_through_bwd and amax is not None: - ctx.save_for_backward(inputs, torch.tensor(amax).to(inputs.device, inputs.dtype)) + amax = ( + amax + if isinstance(amax, torch.Tensor) + else torch.tensor(amax, device=inputs.device, dtype=inputs.dtype) + ) + ctx.save_for_backward(inputs, amax) class FakeTensorQuantFunction(Function): """Fake version of TensorQuantFunction use CUDA extension.""" @staticmethod - @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s", "b") + @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s", "b", "i", "i") def symbolic( g, inputs, @@ -335,9 +340,16 @@ def symbolic( narrow_range=True, trt_high_precision_dtype=None, pass_through_bwd=False, + block_size=None, + axis=None, ): """ONNX symbolic function.""" - from .export_onnx import export_int8 + from .export_onnx import export_int4, export_int8 + + if num_bits == 4: + return export_int4( + g, inputs, amax, num_bits, trt_high_precision_dtype, block_size, axis + ) return export_int8( g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype @@ -354,6 +366,8 @@ def forward( narrow_range=True, trt_high_precision_dtype=None, pass_through_bwd=False, + block_size=None, + axis=None, ): """Forward method.""" if bias is not None: @@ -391,7 +405,7 @@ def legacy_quant_func(): @staticmethod def backward(ctx, grad_outputs): """Implements straight through estimation with clipping.""" - return _fake_quant_backward_function(ctx, grad_outputs, num_args=8) + return _fake_quant_backward_function(ctx, grad_outputs, num_args=10) class ScaledE4M3Function(Function): diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 
7ac296fe0..33b38b7a9 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -148,7 +148,7 @@ def fp4_fake_quant_block( triton.cdiv(M, meta["TILE_SIZE"]), triton.cdiv(N, meta["TILE_SIZE"]), ) - global_scale = (global_amax / 6.0) / 448.0 + global_scale = global_amax.float() / (6.0 * 448.0) num_fp4_blocks = tile_size // block_size fp4_fake_quant_kernel[grid]( x, @@ -162,3 +162,142 @@ def fp4_fake_quant_block( ) y = y.reshape(x_shape).contiguous().to(dtype=x_dtype) return y + + +@triton.jit +def fp4_dequantize_kernel( + packed_ptr, + scale_ptr, + global_scale_ptr, + output_ptr, + N, + BLOCK_SIZE: tl.constexpr, + TILE_SIZE: tl.constexpr, +): + """Dequantizes FP4 packed data using per-block scaling factors. + + Args: + packed_ptr (tl.pointer): Pointer to packed uint8 tensor (M x N//2) + scale_ptr (tl.pointer): Pointer to per-block scale tensor (M x N//BLOCK_SIZE) + output_ptr (tl.pointer): Pointer to output tensor (M x N) + global_scale_ptr (tl.pointer): Pointer to global scale tensor + N (int): Number of columns in unpacked tensor + BLOCK_SIZE (tl.constexpr): Size of each FP4 quantization block + TILE_SIZE (tl.constexpr): Size of the processing tile (in packed elements) + """ + # Get program ID for processing packed elements + pid = tl.program_id(0) + + # Calculate packed element offsets (each packed element contains 2 FP4 values) + packed_start = pid * TILE_SIZE + packed_offs = packed_start + tl.arange(0, TILE_SIZE) + + # Calculate 2D coordinates for packed data + packed_row_idx = packed_offs // (N // 2) + packed_col_idx = packed_offs % (N // 2) + + # Create mask for packed data bounds checking + packed_mask = packed_col_idx < (N // 2) + + # Load global scale + global_scale = tl.load(global_scale_ptr) + + # Load packed data + packed_data = tl.load(packed_ptr + packed_offs, mask=packed_mask, other=0) + + # Unpack packed FP4 values (uint8) to float16x2 + x_f16x2_packed = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 byte0, byte1, byte2, byte3; + mov.b32 {byte0, byte1, byte2, byte3}, $4; + cvt.rn.f16x2.e2m1x2 $0, byte0; + cvt.rn.f16x2.e2m1x2 $1, byte1; + cvt.rn.f16x2.e2m1x2 $2, byte2; + cvt.rn.f16x2.e2m1x2 $3, byte3; + } + """, + constraints="=r,=r,=r,=r,r", + args=[packed_data], + dtype=tl.uint32, + is_pure=True, + pack=4, + ) + val_low = ( + (x_f16x2_packed & 0xFFFF).cast(tl.uint16).cast(tl.float16, bitcast=True).cast(tl.float32) + ) + val_high = ( + (x_f16x2_packed >> 16).cast(tl.uint16).cast(tl.float16, bitcast=True).cast(tl.float32) + ) + + # Calculate output positions for both values + out_col_low = packed_col_idx * 2 + out_col_high = packed_col_idx * 2 + 1 + out_offs_low = packed_row_idx * N + out_col_low + out_offs_high = packed_row_idx * N + out_col_high + + # Calculate block indices for scaling + block_col_low = out_col_low // BLOCK_SIZE + block_col_high = out_col_high // BLOCK_SIZE + scale_offs_low = packed_row_idx * (N // BLOCK_SIZE) + block_col_low + scale_offs_high = packed_row_idx * (N // BLOCK_SIZE) + block_col_high + + # Load scaling factors + scale_low = tl.load(scale_ptr + scale_offs_low, mask=packed_mask & (out_col_low < N), other=1.0) + scale_high = tl.load( + scale_ptr + scale_offs_high, mask=packed_mask & (out_col_high < N), other=1.0 + ) + + # Apply scaling + result_low = val_low * scale_low.to(tl.float32) * global_scale + result_high = val_high * scale_high.to(tl.float32) * global_scale + + # Store results + out_mask_low = packed_mask & (out_col_low < N) + out_mask_high = packed_mask & (out_col_high < 
N) + + tl.store(output_ptr + out_offs_low, result_low, mask=out_mask_low) + tl.store(output_ptr + out_offs_high, result_high, mask=out_mask_high) + + +def fp4_dequantize( + packed_tensor: torch.Tensor, + scale_tensor: torch.Tensor, + global_scale: torch.Tensor, + block_size: int = 16, + tile_size: int = 128, + dtype: torch.dtype = torch.get_default_dtype(), +) -> torch.Tensor: + """Dequantizes FP4 packed tensor using per-block scaling factors. + + Args: + packed_tensor (torch.Tensor): Packed uint8 tensor of shape (M, N//2) + scale_tensor (torch.Tensor): Per-block scale tensor of shape (M, N//block_size) + global_scale (torch.Tensor): Global scaling factor tensor + block_size (int): Size of FP4 quantization blocks + tile_size (int): Size of processing tiles + + Returns: + torch.Tensor: Dequantized tensor of shape (M, N) + """ + packed_N = packed_tensor.shape[-1] + N = packed_N * 2 + # Create output tensor with proper shape handling + output_shape = list(packed_tensor.shape) + output_shape[-1] = N + output = torch.empty(output_shape, dtype=dtype, device=packed_tensor.device) + + # Calculate total number of elements and grid size + grid = lambda meta: (triton.cdiv(packed_tensor.numel(), meta["TILE_SIZE"]),) + + fp4_dequantize_kernel[grid]( + packed_tensor, + scale_tensor, + global_scale, + output, + N, + BLOCK_SIZE=block_size, + TILE_SIZE=tile_size, + ) + + return output diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index dc1277359..ac922af35 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -15,57 +15,37 @@ """Configurations for speculative decoding modes.""" +from copy import deepcopy + from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from .eagle.default_config import default_eagle_config + +eagle3_default_config = deepcopy(default_eagle_config) +eagle_mtp_default_config = deepcopy(default_eagle_config) + +eagle3_default_config.update({"use_aux_hidden_state": True, "use_last_layernorm": True}) +eagle_mtp_default_config.update({"use_last_layernorm": True, "use_mtp_layernorm": True}) + EAGLE1_DEFAULT_CFG = { "algorithm": "eagle", "config": { - "eagle_num_layers": 1, - "eagle_hidden_state_distillation": False, - "eagle_disable_moe": True, - "use_input_layernorm_in_first_layer": False, - "use_last_layernorm": False, - "use_mtp_layernorm": False, + "eagle_architecture_config": deepcopy(default_eagle_config), }, } EAGLE3_DEFAULT_CFG = { "algorithm": "eagle", "config": { - "eagle_num_layers": 1, - "eagle_hidden_state_distillation": False, - "eagle_disable_moe": True, - "use_aux_hidden_state": True, - "eagle_aux_hidden_state_layer_ids": [], - "use_input_layernorm_in_first_layer": True, - "use_last_layernorm": True, - "use_mtp_layernorm": False, + "eagle_architecture_config": eagle3_default_config, }, } EAGLE_MTP_DEFAULT_CFG = { "algorithm": "eagle", "config": { - "eagle_num_layers": 1, - "eagle_hidden_state_distillation": False, - "eagle_disable_moe": False, - "use_aux_hidden_state": False, - "eagle_aux_hidden_state_layer_ids": [], - "use_input_layernorm_in_first_layer": True, - "use_last_layernorm": True, - "use_mtp_layernorm": True, - }, -} - -MTP_DEFAULT_CFG = { - "algorithm": "eagle", - "config": { - "eagle_num_layers": 1, - "eagle_hidden_state_distillation": False, - "eagle_disable_moe": False, - "use_input_layernorm_in_first_layer": True, - "use_last_layernorm": False, - "use_mtp_layernorm": True, + "eagle_reuse_base_decoder": True, + "eagle_architecture_config": 
eagle_mtp_default_config, }, } @@ -87,81 +67,34 @@ class MedusaConfig(ModeloptBaseConfig): class EagleConfig(ModeloptBaseConfig): """Eagle config.""" - eagle_num_layers: int = ModeloptField( - default=1, - description=("The number of decoder used in the eagle model."), - ) - - use_input_layernorm_in_first_layer: bool = ModeloptField( - default=True, description=("Whether to use input_layernorm in the first decoder layer.") - ) - - use_last_layernorm: bool = ModeloptField( - default=False, description=("Whether to use a final layernorm before lm_head.") + eagle_offline: bool = ModeloptField( + default=False, description=("Whether to use detached Eagle.") ) eagle_hidden_state_distillation: bool = ModeloptField( default=False, description=("Whether to use feature hidden states distillation.") ) - use_aux_hidden_state: bool = ModeloptField( - default=False, description=("Whether to use aux hidden state (EAGLE-3).") + eagle_self_logit_distillation: bool = ModeloptField( + default=True, description=("Whether to use logit distillation.") ) - eagle_aux_hidden_state_layer_ids: list = ModeloptField( - default=[], - description=("The list of aux hidden state layers used in EAGLE-3."), + eagle_freeze_base_model: bool = ModeloptField( + default=True, description=("Whether to freeze base model during eagle module training.") ) - eagle_disable_moe: bool = ModeloptField( - default=False, description=("Whether to disable MoE in eagle module.") + eagle_report_acc: bool = ModeloptField( + default=True, description=("Whether to report eval accuracy.") ) - draft_vocab_size: int = ModeloptField( - default=0, - description=("The vocab size of the eagle module. 0 means the same as base model."), - ) - - use_mtp_layernorm: bool = ModeloptField( - default=False, - description=( - "Whether to use norms before input_hidden_states and embedding in eagle module." - ), - ) - - ffn_hidden_size: int = ModeloptField( - default=0, - description=( - "ffn_hidden_size of the eagle module. Using base model's ffn_hidden_size is set to 0." - ), - ) - - parallel_draft_step: int = ModeloptField( - default=1, - description=( - "The number of tokens generated in parallel draft. If set to 1, draft is not in parallel mode." 
- ), - ) - - -class MTPConfig(ModeloptBaseConfig): - """MTP config.""" - - mtp_num_layers: int = ModeloptField( - default=1, - description=("The number of decoder used in the mtp model."), - ) - - mtp_num_module: int = ModeloptField( - default=1, - description=("The number of mtp used in the model."), + eagle_reuse_base_decoder: bool = ModeloptField( + default=False, description=("Whether to reuse base model decoder in eagle module.") ) - mtp_freeze_list: list = ModeloptField( - default=[], - description=("The list of mtp module to freeze."), + eagle_loss_decay_factor: float = ModeloptField( + default=0.9, description=("The decay factor for multiple eagle_loss.") ) - use_last_layernorm: bool = ModeloptField( - default=False, description=("Whether to use a final layernorm before lm_head.") + eagle_architecture_config: dict = ModeloptField( + default={}, description=("The config for eagle module architecture.") ) diff --git a/modelopt/torch/speculative/eagle/__init__.py b/modelopt/torch/speculative/eagle/__init__.py index 91df77580..183246ee3 100644 --- a/modelopt/torch/speculative/eagle/__init__.py +++ b/modelopt/torch/speculative/eagle/__init__.py @@ -16,4 +16,5 @@ """Eagle Optimization Method.""" from .conversion import * +from .default_config import * from .eagle_model import * diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index 32efb9886..f048feaa1 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -24,6 +24,7 @@ from ..config import EagleConfig EagleDMRegistry = _DMRegistryCls(prefix="Eagle") # global instance for the registry +OfflineEagleDMRegistry = _DMRegistryCls(prefix="DetachedEagle") # global instance for the registry def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertReturnType: @@ -31,26 +32,25 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu # initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + registry = OfflineEagleDMRegistry if config.eagle_offline else EagleDMRegistry + original_cls = type(model) - if original_cls not in EagleDMRegistry: - for cls in EagleDMRegistry._registry: + if original_cls not in registry: + for cls in registry._registry: if issubclass(original_cls, cls): - EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls]) + registry.register({original_cls: "base_model_class"})(registry[cls]) break - eagle_model = EagleDMRegistry.convert(model) + eagle_model = registry.convert(model) eagle_model.modify( - eagle_num_layers=config.eagle_num_layers, - use_input_layernorm_in_first_layer=config.use_input_layernorm_in_first_layer, - use_last_layernorm=config.use_last_layernorm, + eagle_offline=config.eagle_offline, eagle_hidden_state_distillation=config.eagle_hidden_state_distillation, - use_aux_hidden_state=config.use_aux_hidden_state, - eagle_aux_hidden_state_layer_ids=config.eagle_aux_hidden_state_layer_ids, - eagle_disable_moe=config.eagle_disable_moe, - draft_vocab_size=config.draft_vocab_size, - use_mtp_layernorm=config.use_mtp_layernorm, - ffn_hidden_size=config.ffn_hidden_size, - parallel_draft_step=config.parallel_draft_step, + eagle_self_logit_distillation=config.eagle_self_logit_distillation, + eagle_freeze_base_model=config.eagle_freeze_base_model, + eagle_report_acc=config.eagle_report_acc, + eagle_reuse_base_decoder=config.eagle_reuse_base_decoder, + 
eagle_loss_decay_factor=config.eagle_loss_decay_factor, + eagle_architecture_config=config.eagle_architecture_config, ) # no metadata, all specifed via config. diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py new file mode 100644 index 000000000..f8c69b2ff --- /dev/null +++ b/modelopt/torch/speculative/eagle/default_config.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default EAGLE architecture config.""" + +default_eagle_config = { + "hidden_act": "silu", + "torch_dtype": "bfloat16", + "vocab_size": 128256, + "draft_vocab_size": 128256, + "max_position_embeddings": 8192, + "position_embedding_type": "rope", + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "rope_theta": 500000.0, + "num_hidden_layers": 1, + "hidden_size": 4096, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "initializer_range": 0.01, + "rms_norm_eps": 1e-05, + "mlp_bias": False, + "attention_bias": False, + "attention_dropout": 0.0, + "use_input_layernorm_in_first_layer": True, + "use_last_layernorm": False, + "use_aux_hidden_state": False, + "eagle_aux_hidden_state_layer_ids": [], + "use_mtp_layernorm": False, + "parallel_draft_step": 1, + "has_lm_head": False, +} diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index a13f0ed02..69051f6e8 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -24,44 +24,28 @@ class EagleModel(DynamicModule): """Base Eagle Model.""" def _setup(self): - self._register_temp_attribute("eagle_num_layers", 0) self._register_temp_attribute("eagle_module", None) def modify( self, - eagle_num_layers, - use_input_layernorm_in_first_layer, - use_last_layernorm, + eagle_offline, eagle_hidden_state_distillation, - use_aux_hidden_state, - eagle_aux_hidden_state_layer_ids, - eagle_disable_moe, - draft_vocab_size, - use_mtp_layernorm, - parallel_draft_step, + eagle_self_logit_distillation, + eagle_freeze_base_model, + eagle_report_acc, + eagle_reuse_base_decoder, + eagle_loss_decay_factor, + eagle_architecture_config, ): """Base Eagle Model modify function. 
Child class should implement the details.""" - self.eagle_num_layers = eagle_num_layers - self.use_input_layernorm_in_first_layer = use_input_layernorm_in_first_layer - self.use_last_layernorm = use_last_layernorm + self.eagle_offline = eagle_offline self.eagle_hidden_state_distillation = eagle_hidden_state_distillation - self.use_aux_hidden_state = use_aux_hidden_state - self.eagle_aux_hidden_state_layer_ids = eagle_aux_hidden_state_layer_ids - self.eagle_disable_moe = eagle_disable_moe - self.draft_vocab_size = draft_vocab_size - self.use_mtp_layernorm = use_mtp_layernorm - self.parallel_draft_step = parallel_draft_step - - # Use default aux_hidden_state layers if use_aux_hidden_state is True - # but no layer id is given - if self.use_aux_hidden_state and len(self.eagle_aux_hidden_state_layer_ids) == 0: - self._set_default_aux_hidden_state_layers() - - if len(self.eagle_aux_hidden_state_layer_ids) > 0: - assert not self.eagle_hidden_state_distillation, ( - "EAGLE-3 does not support hidden state distillation!" - ) - - if self.parallel_draft_step > 1: - for i in range(self.parallel_draft_step - 1): + self.eagle_self_logit_distillation = eagle_self_logit_distillation + self.eagle_freeze_base_model = eagle_freeze_base_model + self.eagle_report_acc = eagle_report_acc + self.eagle_reuse_base_decoder = eagle_reuse_base_decoder + self.eagle_loss_decay_factor = eagle_loss_decay_factor + + if eagle_architecture_config.get("parallel_draft_step", 1) > 1: + for i in range(eagle_architecture_config.get("parallel_draft_step") - 1): self.register_buffer(f"mask_token_{i}", torch.tensor(-1)) diff --git a/modelopt/torch/speculative/mode.py b/modelopt/torch/speculative/mode.py index d126dc1a1..866449e15 100644 --- a/modelopt/torch/speculative/mode.py +++ b/modelopt/torch/speculative/mode.py @@ -23,10 +23,9 @@ _ModeRegistryCls, ) -from .config import EagleConfig, MedusaConfig, MTPConfig +from .config import EagleConfig, MedusaConfig from .eagle.conversion import convert_to_eagle_model, restore_eagle_model from .medusa.conversion import convert_to_medusa_model, restore_medusa_model -from .mtp.conversion import convert_to_mtp_model, restore_mtp_model SpeculativeDecodingModeRegistry = _ModeRegistryCls("speculative") @@ -85,31 +84,3 @@ def convert(self) -> ConvertEntrypoint: def restore(self) -> RestoreEntrypoint: """The mode's entrypoint for restoring a model.""" return restore_eagle_model - - -@SpeculativeDecodingModeRegistry.register_mode -class MTPModeDescriptor(ModeDescriptor): - """Class to describe the ``"mtp"`` mode. - - The properties of this mode can be inspected via the source code. - """ - - @property - def name(self) -> str: - """Returns the value (str representation) of the mode.""" - return "mtp" - - @property - def config_class(self) -> type[ModeloptBaseConfig]: - """Specifies the config class for the mode.""" - return MTPConfig - - @property - def convert(self) -> ConvertEntrypoint: - """The mode's entrypoint for converting a model.""" - return convert_to_mtp_model - - @property - def restore(self) -> RestoreEntrypoint: - """The mode's entrypoint for restoring a model.""" - return restore_mtp_model diff --git a/modelopt/torch/speculative/mtp/conversion.py b/modelopt/torch/speculative/mtp/conversion.py deleted file mode 100644 index 81b16f8b1..000000000 --- a/modelopt/torch/speculative/mtp/conversion.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""MTP conversion/restore utilities.""" - -from torch import nn - -from modelopt.torch.opt.conversion import ModelLikeModule -from modelopt.torch.opt.dynamic import _DMRegistryCls -from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict - -from ..config import MTPConfig - -MTPDMRegistry = _DMRegistryCls(prefix="MTP") # global instance for the registry - - -def convert_to_mtp_model(model: nn.Module, config: MTPConfig) -> ConvertReturnType: - """Convert the model to a mtp model as per `config`.""" - # initialize the true module if necessary - model = model.init_modellike() if isinstance(model, ModelLikeModule) else model - - original_cls = type(model) - if original_cls not in MTPDMRegistry: - for cls in MTPDMRegistry._registry: - if issubclass(original_cls, cls): - MTPDMRegistry.register({original_cls: "base_model_class"})(MTPDMRegistry[cls]) - break - - mtp_model = MTPDMRegistry.convert(model) - mtp_model.modify( - mtp_num_layers=config.mtp_num_layers, - mtp_num_module=config.mtp_num_module, - mtp_freeze_list=config.mtp_freeze_list, - use_last_layernorm=config.use_last_layernorm, - ) - - # no metadata, all specifed via config. - metadata = {} - - return mtp_model, metadata - - -def restore_mtp_model(model: nn.Module, config: MTPConfig, metadata: MetadataDict) -> nn.Module: - """Function for restoring a previously convert model to a mtp model.""" - # the metadata should be empty - assert not metadata, "No metadata expected!" - - return convert_to_mtp_model(model, config)[0] diff --git a/modelopt/torch/speculative/mtp/mtp_model.py b/modelopt/torch/speculative/mtp/mtp_model.py deleted file mode 100644 index 784d41982..000000000 --- a/modelopt/torch/speculative/mtp/mtp_model.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""MTP model to support mtp decoding.""" - -from modelopt.torch.opt.dynamic import DynamicModule - - -class MTPModel(DynamicModule): - """Base MTP Model.""" - - def _setup(self): - self._register_temp_attribute("mtp_num_layers", 0) - self._register_temp_attribute("mtp_num_module", 0) - self._register_temp_attribute("mtp_freeze_list", []) - - def modify(self, mtp_num_layers, mtp_num_module, mtp_freeze_list, use_last_layernorm): - """Base MTP Model modify function. 
Child class should implement the details.""" - self.mtp_num_layers = mtp_num_layers - self.mtp_num_module = mtp_num_module - self.mtp_freeze_list = mtp_freeze_list - self.use_last_layernorm = use_last_layernorm diff --git a/modelopt/torch/speculative/mtp/utils.py b/modelopt/torch/speculative/mtp/utils.py deleted file mode 100644 index 10db4fc53..000000000 --- a/modelopt/torch/speculative/mtp/utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""DeepSeek V3.""" - -import torch -from torch import nn - - -class DeepseekV3RMSNorm(nn.Module): - """Deepseek V3 RMSNorm implementation.""" - - def __init__(self, hidden_size, eps=1e-6): - """DeepseekV3RMSNorm is equivalent to T5LayerNorm.""" - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - """Forward function of DeepseekV3RMSNorm.""" - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index ba050c9b6..01818a44f 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -21,6 +21,7 @@ import megatron.core import torch +import torch.nn.functional as F from megatron.core import InferenceParams, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding @@ -31,6 +32,8 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( get_data_parallel_rank, + get_expert_tensor_parallel_world_size, + get_pipeline_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) @@ -40,17 +43,16 @@ scatter_to_sequence_parallel_region, ) from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint from packaging.version import Version -from ...opt.plugins.megatron_model_config import Llama31Config8B -from ..eagle.conversion import EagleDMRegistry +from ..eagle.conversion 
import EagleDMRegistry, OfflineEagleDMRegistry from ..eagle.eagle_model import EagleModel from ..utils import ( AcceptanceRateValidation, @@ -66,6 +68,76 @@ warnings.warn("Fail to import megatron.core.post_training! EAGLE feature will be disable!") +def dict_to_config( + architecture_config, + use_cpu_initialization=None, + fp16=False, + bf16=True, + sequence_parallel=False, +): + """Helper function to convert a dictionary to TransformerConfig.""" + config = TransformerConfig( + normalization="RMSNorm", + activation_func=F.silu, + gated_linear_unit=True, + hidden_dropout=0.0, + attention_softmax_in_fp32=False, + tensor_model_parallel_size=get_tensor_model_parallel_world_size(), + pipeline_model_parallel_size=get_pipeline_model_parallel_world_size(), + expert_tensor_parallel_size=get_expert_tensor_parallel_world_size(), + sequence_parallel=sequence_parallel, + use_cpu_initialization=use_cpu_initialization, + fp16=fp16, + bf16=bf16, + params_dtype=getattr(torch, architecture_config["torch_dtype"]), + pipeline_dtype=None, + num_layers=architecture_config.get("num_hidden_layers"), + hidden_size=architecture_config.get("hidden_size"), + ffn_hidden_size=architecture_config.get("intermediate_size"), + num_attention_heads=architecture_config.get("num_attention_heads"), + kv_channels=architecture_config.get( + "head_dim", + architecture_config.get("hidden_size") + // architecture_config.get("num_attention_heads"), + ), + num_query_groups=architecture_config.get("num_key_value_heads"), + init_method_std=architecture_config.get("initializer_range"), + layernorm_epsilon=architecture_config.get("rms_norm_eps"), + add_bias_linear=architecture_config.get("mlp_bias"), + attention_dropout=architecture_config.get("attention_dropout"), + ) + + config.transformer_layer_spec = None + config.seq_length = 8192 + config.gradient_accumulation_fusion = False + config.vocab_size = architecture_config.get("vocab_size") + config.max_sequence_length = architecture_config.get("max_position_embeddings") + config.position_embedding_type = architecture_config.get("position_embedding_type") + config.rotary_percent = 1.0 + config.rotary_base = architecture_config.get("rope_theta") + config.rope_scaling = "rope_scaling" in architecture_config + config.rope_scaling_factor = ( + architecture_config.get("rope_scaling").get("factor") + if "rope_scaling" in architecture_config + else None + ) + + config.draft_vocab_size = architecture_config.get("draft_vocab_size") + config.use_input_layernorm_in_first_layer = architecture_config.get( + "use_input_layernorm_in_first_layer" + ) + config.use_last_layernorm = architecture_config.get("use_last_layernorm") + config.use_aux_hidden_state = architecture_config.get("use_aux_hidden_state") + config.eagle_aux_hidden_state_layer_ids = architecture_config.get( + "eagle_aux_hidden_state_layer_ids" + ) + config.use_mtp_layernorm = architecture_config.get("use_mtp_layernorm") + config.parallel_draft_step = architecture_config.get("parallel_draft_step") + config.has_lm_head = architecture_config.get("has_lm_head") + + return config + + def mcore_version_higher_than(target_version: str): """Check if megatron-core is least this version.""" return Version(megatron.core.__version__) > Version(target_version) @@ -382,12 +454,7 @@ def __init__( self, config, rotary_pos_emb: torch.nn.Module, - num_layers: int, - use_last_layernorm: bool, - use_input_layernorm_in_first_layer: bool = True, - use_mtp_layernorm: bool = False, bias: bool = False, - num_aux_hidden_states: int = 0, ): """Constructor. 
@@ -398,31 +465,24 @@ def __init__( Args: config: MCore transformer config - num_layers: number of Eagle layers - rotary_pos_emb: If None, use the default Llama-3.1 rope (GPT-NeoX). + rotary_pos_emb: nn.Module. """ # Override transformer_config before superclass initialization - self._num_eagle_layers = num_layers - self._use_input_layernorm_in_first_layer = use_input_layernorm_in_first_layer - self._use_mtp_layernorm = use_mtp_layernorm - self._num_aux_hidden_states = num_aux_hidden_states - eagle_config = self._get_eagle_transformer_config(config) - super().__init__(config=eagle_config) + config.pipeline_model_parallel_size = 1 + config.virtual_pipeline_model_parallel_size = None + config.num_layers_in_first_pipeline_stage = None + config.num_layers_in_last_pipeline_stage = None + super().__init__(config=config) - eagle_transformer_layer_spec = self._get_eagle_transformer_layer_spec(eagle_config) + eagle_transformer_layer_spec = self._get_eagle_transformer_layer_spec(config) + self._num_aux_hidden_states = len(self.config.eagle_aux_hidden_state_layer_ids) if self._num_aux_hidden_states > 0: - self.enorm = TENorm( - eagle_config, eagle_config.hidden_size, eagle_config.layernorm_epsilon - ) + self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon) self._embeddings = None - elif self._use_mtp_layernorm: - self.enorm = TENorm( - eagle_config, eagle_config.hidden_size, eagle_config.layernorm_epsilon - ) - self.hnorm = TENorm( - eagle_config, eagle_config.hidden_size, eagle_config.layernorm_epsilon - ) + elif self.config.use_mtp_layernorm: + self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon) + self.hnorm = TENorm(config, config.hidden_size, config.layernorm_epsilon) device = "cpu" if config.use_cpu_initialization else torch.cuda.current_device() @@ -436,9 +496,9 @@ def __init__( # parallel is used and does not allow gathering the outputs. with torch.device(device): self.fc = Linear( - eagle_config.hidden_size * fc_input_size_multiplier, - eagle_config.hidden_size, - config=eagle_config, + config.hidden_size * fc_input_size_multiplier, + config.hidden_size, + config=config, init_method=(lambda w: None), # not used bias=bias, ) @@ -448,9 +508,9 @@ def __init__( # Eagle does not use the final_layernorm in decoder. with torch.device(device): self.decoder = EagleTransformerBlock( - config=eagle_config, + config=config, spec=eagle_transformer_layer_spec, - post_layer_norm=use_last_layernorm, + post_layer_norm=config.use_last_layernorm, pre_process=True, post_process=True, ) @@ -479,21 +539,28 @@ def __init__( tp_comm_buffer_name="qkv", ) - # Sanity check - if self.decoder.layers[0].self_attention.attention_type == AttnMaskType.arbitrary: - raise ValueError("EAGLE-3 must use arbitrary attention mask.") - - def _get_eagle_transformer_config(self, base_model_config): - eagle_config = copy.deepcopy(base_model_config) - eagle_config.num_layers = self._num_eagle_layers - # Unset the PP config. - eagle_config.pipeline_model_parallel_size = 1 - eagle_config.virtual_pipeline_model_parallel_size = None - eagle_config.num_layers_in_first_pipeline_stage = None - eagle_config.num_layers_in_last_pipeline_stage = None - return eagle_config - - def _get_eagle_transformer_layer_spec(self, eagle_config): + if self.config.draft_vocab_size != self.config.vocab_size: + # Need an extra lm_head for eagle module since vocab size is reduced. + assert self.config.draft_vocab_size <= self.config.vocab_size, ( + "EAGLE module's vocab size should be <= base model vocab size!" 
+ ) + + self.register_buffer( + "d2t", torch.zeros(self.config.draft_vocab_size, dtype=torch.int64) + ) + if self.config.draft_vocab_size != self.config.vocab_size or self.config.has_lm_head: + self.eagle_output_layer = tensor_parallel.ColumnParallelLinear( + self.config.hidden_size, + self.config.draft_vocab_size, + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + gather_output=False, + skip_weight_param_allocation=False, + ) + + def _get_eagle_transformer_layer_spec(self, config): """Get the TransformerLayer implementation spec. IMPORTANT: EagleModule must use arbitrary_attention_mask since we need to @@ -501,7 +568,7 @@ def _get_eagle_transformer_layer_spec(self, eagle_config): causal mask will result in leaking. """ transformer_layer_spec = get_gpt_modelopt_spec( - eagle_config, + config, remap_te_layernorm=True, use_arbitrary_attention_mask=True, ) @@ -515,7 +582,7 @@ def _get_eagle_transformer_layer_spec(self, eagle_config): # Force TransformerLayer in case RealQuantTransformerLayer was used. eagle_transformer_layer_spec.module = TransformerLayer - if not self._use_input_layernorm_in_first_layer: + if not self.config.use_input_layernorm_in_first_layer: eagle_transformer_layer_spec.submodules.input_layernorm = IdentityOp return eagle_transformer_layer_spec @@ -556,7 +623,7 @@ def forward( if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size - if self._use_mtp_layernorm: + if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) hidden_states = self.hnorm(hidden_states) @@ -598,90 +665,30 @@ def forward( return hidden_states, next_hidden_states_input -class EagleLlama3Module(EagleModule): - """EagleLlama3Module definition. - - EagleLlama3Module is the default subclass which uses Llama3 architecture - and rotary position embedding. - """ - - def __init__( - self, - config, - num_layers: int, - use_last_layernorm: bool, - use_input_layernorm_in_first_layer: bool = True, - use_mtp_layernorm: bool = False, - bias: bool = False, - num_aux_hidden_states: int = 0, - ffn_hidden_size: int | None = 0, - ): - """Constructor.""" - eagle_config = Llama31Config8B( - # Getting ModelParallelConfig from the base model - tensor_model_parallel_size=config.tensor_model_parallel_size, - sequence_parallel=config.sequence_parallel, - expert_tensor_parallel_size=config.expert_tensor_parallel_size, - use_cpu_initialization=config.use_cpu_initialization, - fp16=config.fp16, - bf16=config.bf16, - params_dtype=config.params_dtype, - # Override hidden_size and ffn_hidden_size from the base model - hidden_size=config.hidden_size, - ffn_hidden_size=config.ffn_hidden_size, - ) - - # If base model is using MHA/GQA, then use the same config to simply KV-cache impl. - if config.kv_channels is not None: - eagle_config.kv_channels = config.kv_channels - if config.num_attention_heads > 0: - eagle_config.num_attention_heads = config.num_attention_heads - else: - eagle_config.num_attention_heads = eagle_config.hidden_size // eagle_config.kv_channels - if config.num_query_groups is not None: - eagle_config.num_query_groups = config.num_query_groups - - # Override ffn_hidden_size if provided to widen the transformer. 
- if ffn_hidden_size > 0: - eagle_config.ffn_hidden_size = ffn_hidden_size - - rotary_pos_emb = RotaryEmbedding( - kv_channels=eagle_config.kv_channels, - rotary_percent=1.0, - rotary_interleaved=False, - seq_len_interpolation_factor=None, - rotary_base=500000.0, - rope_scaling=True, - rope_scaling_factor=8.0, - use_cpu_initialization=eagle_config.use_cpu_initialization, - ) - - super().__init__( - eagle_config, - rotary_pos_emb, - num_layers, - use_last_layernorm, - use_input_layernorm_in_first_layer=use_input_layernorm_in_first_layer, - use_mtp_layernorm=use_mtp_layernorm, - bias=bias, - num_aux_hidden_states=num_aux_hidden_states, - ) - - @EagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) class _DynamicEagleGPTModel(EagleModel): """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" def _set_default_aux_hidden_state_layers(self): - num_layers = self.config.num_layers - self.eagle_aux_hidden_state_layer_ids = [1, num_layers // 2 - 1, num_layers - 4] + if hasattr(self.config, "original_num_layers"): + num_layers = self.config.original_num_layers + else: + num_layers = self.config.num_layers + self.eagle_config.eagle_aux_hidden_state_layer_ids = [ + 1, + max(0, num_layers // 2 - 1), + max(0, num_layers - 4), + ] + self.eagle_config.eagle_aux_hidden_state_layer_ids = list( + set(self.eagle_config.eagle_aux_hidden_state_layer_ids) + ) def _transformer_layer_forward_hook(self, module, input, output) -> None: if not isinstance(module, TransformerLayer): raise ValueError( "_transformer_layer_forward_hook can only be registered to TransformerLayer" ) - if module.layer_number - 1 not in self.eagle_aux_hidden_state_layer_ids: + if module.layer_number - 1 not in self.eagle_config.eagle_aux_hidden_state_layer_ids: return hidden_states = ( output.clone().detach() @@ -697,20 +704,14 @@ def _setup(self): def modify( self, - eagle_num_layers=0, - use_input_layernorm_in_first_layer=True, - use_last_layernorm=True, - eagle_hidden_state_distillation=False, - use_aux_hidden_state=False, - eagle_aux_hidden_state_layer_ids=[], - eagle_disable_moe=False, - draft_vocab_size=0, - use_mtp_layernorm=False, - parallel_draft_step=1, - eagle_self_logit_distillation=True, - eagle_freeze_base_model=True, - eagle_report_acc=True, - ffn_hidden_size=0, + eagle_offline, + eagle_hidden_state_distillation, + eagle_self_logit_distillation, + eagle_freeze_base_model, + eagle_report_acc, + eagle_reuse_base_decoder, + eagle_loss_decay_factor, + eagle_architecture_config, ): if self.config.pipeline_model_parallel_size > 1: warnings.warn( @@ -724,25 +725,46 @@ def modify( self.config.hetereogenous_dist_checkpoint = True super().modify( - eagle_num_layers=eagle_num_layers, - use_input_layernorm_in_first_layer=use_input_layernorm_in_first_layer, - use_last_layernorm=use_last_layernorm, + eagle_offline=eagle_offline, eagle_hidden_state_distillation=eagle_hidden_state_distillation, - use_aux_hidden_state=use_aux_hidden_state, - eagle_aux_hidden_state_layer_ids=eagle_aux_hidden_state_layer_ids, - eagle_disable_moe=eagle_disable_moe, - draft_vocab_size=draft_vocab_size, - use_mtp_layernorm=use_mtp_layernorm, - parallel_draft_step=parallel_draft_step, + eagle_self_logit_distillation=eagle_self_logit_distillation, + eagle_freeze_base_model=eagle_freeze_base_model, + eagle_report_acc=eagle_report_acc, + eagle_reuse_base_decoder=eagle_reuse_base_decoder, + eagle_loss_decay_factor=eagle_loss_decay_factor, + eagle_architecture_config=eagle_architecture_config, ) - self.eagle_report_acc = 
eagle_report_acc - self.eagle_self_logit_distillation = eagle_self_logit_distillation - self.eagle_freeze_base_model = eagle_freeze_base_model + + self.eagle_config = dict_to_config( + eagle_architecture_config, + self.config.use_cpu_initialization, + self.config.fp16, + self.config.bf16, + self.config.sequence_parallel, + ) + + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + assert eagle_self_logit_distillation, ( + "Only logit distillation is supported when draft_vocab_size != vocab_size!" + ) + + # Use default aux_hidden_state layers if use_aux_hidden_state is True + # but no layer id is given + if ( + self.eagle_config.use_aux_hidden_state + and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 + ): + self._set_default_aux_hidden_state_layers() + + if len(self.eagle_config.eagle_aux_hidden_state_layer_ids) > 0: + assert not self.eagle_hidden_state_distillation, ( + "EAGLE-3 does not support hidden state distillation!" + ) # EAGLE-3 auxiluary hidden_states (only work for TP+EP, does not work for PP) self._aux_hidden_states = [] - if self.position_embedding_type not in ["rope", "yarn"]: + if self.eagle_config.position_embedding_type not in ["rope", "yarn"]: raise ValueError("For EAGLE, only RoPE or YaRN embedding are supported") if not self.pre_process and self.post_process: @@ -754,7 +776,7 @@ def modify( ) # Register TransformerLayer forward hook to extract aux hidden_states. - if len(self.eagle_aux_hidden_state_layer_ids) > 0: + if len(self.eagle_config.eagle_aux_hidden_state_layer_ids) > 0: for layer in self.decoder.layers: layer.register_forward_hook(self._transformer_layer_forward_hook) @@ -766,55 +788,44 @@ def modify( # Only the last PP stage has the additional projection and decoder layer. # This is to simplify the export. if self.post_process: - if self.eagle_disable_moe: - self.eagle_module = EagleLlama3Module( - self.config, - self.eagle_num_layers, - self.use_last_layernorm, - use_input_layernorm_in_first_layer=use_input_layernorm_in_first_layer, - use_mtp_layernorm=self.use_mtp_layernorm, - num_aux_hidden_states=len(self.eagle_aux_hidden_state_layer_ids), - bias=False, - ffn_hidden_size=ffn_hidden_size, + if self.eagle_reuse_base_decoder: + eagle_config = copy.deepcopy(self.config) + # Overwrite values from the eagle config + eagle_config.num_layers = self.eagle_config.num_layers + eagle_config.use_last_layernorm = self.eagle_config.use_last_layernorm + eagle_config.use_input_layernorm_in_first_layer = ( + self.eagle_config.use_input_layernorm_in_first_layer ) - else: + eagle_config.eagle_aux_hidden_state_layer_ids = ( + self.eagle_config.eagle_aux_hidden_state_layer_ids + ) + eagle_config.use_mtp_layernorm = self.eagle_config.use_mtp_layernorm self.eagle_module = EagleModule( - self.config, + eagle_config, self.rotary_pos_emb, - self.eagle_num_layers, - self.use_last_layernorm, - use_input_layernorm_in_first_layer=use_input_layernorm_in_first_layer, - use_mtp_layernorm=self.use_mtp_layernorm, - num_aux_hidden_states=len(self.eagle_aux_hidden_state_layer_ids), bias=False, ) - - # Eagle loss functions - self.kld = logits_kld_loss - - if self.draft_vocab_size > 0: - # Need an extra lm_head for eagle module since vocab size is reduced. - assert self.draft_vocab_size <= self.vocab_size, ( - "EAGLE module's vocab size should be <= base model vocab size!" - ) - assert eagle_self_logit_distillation, ( - "Only logit distillation is supported when draft_vocab_size > 0!" 
+ else: + rotary_pos_emb = RotaryEmbedding( + kv_channels=self.eagle_config.kv_channels, + rotary_percent=self.eagle_config.rotary_percent, + rotary_interleaved=False, + seq_len_interpolation_factor=None, + rotary_base=self.eagle_config.rotary_base, + rope_scaling=self.eagle_config.rope_scaling, + rope_scaling_factor=self.eagle_config.rope_scaling_factor, + use_cpu_initialization=self.eagle_config.use_cpu_initialization, ) - self.eagle_module.register_buffer( - "d2t", torch.zeros(self.draft_vocab_size, dtype=torch.int64) - ) - self.eagle_module.eagle_output_layer = tensor_parallel.ColumnParallelLinear( - self.config.hidden_size, - self.draft_vocab_size, - config=self.output_layer.config, - init_method=self.config.init_method, + self.eagle_module = EagleModule( + self.eagle_config, + rotary_pos_emb, bias=False, - skip_bias_add=False, - gather_output=False, - skip_weight_param_allocation=False, ) + # Eagle loss functions + self.kld = logits_kld_loss + def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True): """When _aux_hidden_states is not empty, then this is EAGLE-3. @@ -860,7 +871,7 @@ def _get_eagle_module_inputs( eagle_inputs = {} - if self.parallel_draft_step > 1: + if self.eagle_config.parallel_draft_step > 1: eagle_inputs["input_ids"] = padded_input_ids eagle_inputs["position_ids"] = position_ids if rotary_pos_emb is not None: @@ -875,7 +886,7 @@ def _get_eagle_module_inputs( gathered_hidden_states = hidden_states eagle_inputs["hidden_states"] = gathered_hidden_states - for i in range(self.parallel_draft_step - 1): + for i in range(self.eagle_config.parallel_draft_step - 1): eagle_inputs["input_ids"] = torch.cat( ( eagle_inputs["input_ids"], @@ -915,7 +926,7 @@ def _get_eagle_module_inputs( ) eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, self.parallel_draft_step + attn_mask, self.eagle_config.parallel_draft_step ) elif features is None: eagle_inputs["input_ids"] = padded_input_ids @@ -1049,7 +1060,7 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits): """ # Compute lm loss (classification loss) or KLDivergence if self.eagle_self_logit_distillation: - mapping = self.eagle_module.d2t if self.draft_vocab_size > 0 else None + mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping) else: token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :]) @@ -1142,7 +1153,7 @@ def _eagle_forward( **(extra_block_kwargs or {}), ) - if self.draft_vocab_size > 0: + if hasattr(self.eagle_module, "eagle_output_layer"): eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states) else: eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight) @@ -1160,7 +1171,6 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict | None = None, return_eagle_inputs: bool = False, - loss_decay_factor: float = 0.9, **kwargs, ) -> torch.Tensor: if input_ids is not None and (position_ids is None or attention_mask is None): @@ -1194,12 +1204,18 @@ def forward( eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( hidden_states, apply_fc=False ) + + if self.config.sequence_parallel: + eagle_module_input_hidden_states = gather_from_sequence_parallel_region( + eagle_module_input_hidden_states + ) + hidden_states = gather_from_sequence_parallel_region(hidden_states) + logits_sbh = gather_from_tensor_model_parallel_region(logits_sbh) # In case of 
VLM, there will be other fields for pixels. return { - "input_ids": input_ids, - "decoder_input": decoder_input_for_eagle, - "hidden_states": eagle_module_input_hidden_states, - "logits": logits_sbh, + "input_ids": input_ids.squeeze(0).cpu(), + "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu(), + "hidden_states": hidden_states.squeeze(1).cpu(), } else: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( @@ -1234,8 +1250,8 @@ def forward( loss = self.compute_language_model_loss(labels, logits_sbh) loss = 0.0 * loss - if self.parallel_draft_step > 1: - for i in range(self.parallel_draft_step): + if self.eagle_config.parallel_draft_step > 1: + for i in range(self.eagle_config.parallel_draft_step): eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) loss_ = loss_[:, i:] @@ -1243,7 +1259,7 @@ def forward( return loss loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0) - loss[:, 1:] += loss_decay_factor * loss_0 + loss[:, 1:] += self.eagle_loss_decay_factor * loss_0 if self.eagle_report_acc and not self.training: acc = [] @@ -1252,7 +1268,7 @@ def forward( eagle_logits_0[:-1, :, :] ) eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: eagle_top1 += self.eagle_module.d2t[eagle_top1] top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel() acc.append(top1_p) @@ -1284,7 +1300,7 @@ def forward( loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1) # [b, s - 2] loss_1 = loss_1[:, 1:] - loss[:, 2:] += loss_decay_factor**2 * loss_1 + loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1 if self.eagle_report_acc and not self.training: acc = [] @@ -1293,7 +1309,7 @@ def forward( eagle_logits_1[1:-1, :, :] ) eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: eagle_top1 += self.eagle_module.d2t[eagle_top1] top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel() acc.append(top1_p) @@ -1326,7 +1342,7 @@ def forward( loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2) # [b, s - 3] loss_2 = loss_2[:, 2:] - loss[:, 3:] += loss_decay_factor**3 * loss_2 + loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2 if self.eagle_report_acc and not self.training: acc = [] @@ -1335,7 +1351,7 @@ def forward( eagle_logits_2[2:-1, :, :] ) eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: eagle_top1 += self.eagle_module.d2t[eagle_top1] top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel() acc.append(top1_p) @@ -1368,7 +1384,7 @@ def forward( loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3) # [b, s - 4] loss_3 = loss_3[:, 3:] - loss[:, 4:] += loss_decay_factor**4 * loss_3 + loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3 if self.eagle_report_acc and not self.training: acc = [] @@ -1377,7 +1393,7 @@ def forward( eagle_logits_3[3:-1, :, :] ) eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: eagle_top1 += self.eagle_module.d2t[eagle_top1] top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel() 
acc.append(top1_p) @@ -1608,7 +1624,7 @@ def pseudo_speculative_generate( # EAGLE-3 # Only the first iteration input_hidden_states are from aux_hidden_state layers - hidden_states = self._get_eagle_input_hidden_states(hidden_states) + hidden_states = self._get_eagle_input_hidden_states(hidden_states, apply_fc=True) # Remove the padding if self.config.sequence_parallel: hidden_states = gather_from_sequence_parallel_region(hidden_states) @@ -1616,8 +1632,8 @@ def pseudo_speculative_generate( draft_tokens = [] for _ in range(steps): - if self.parallel_draft_step > 1: - for i in range(self.parallel_draft_step - 1): + if self.eagle_config.parallel_draft_step > 1: + for i in range(self.eagle_config.parallel_draft_step - 1): eagle_ids = torch.cat( (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 ) @@ -1655,10 +1671,10 @@ def pseudo_speculative_generate( ) eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :] - if self.parallel_draft_step > 1: + if self.eagle_config.parallel_draft_step > 1: draft_token = ( gather_from_tensor_model_parallel_region(eagle_logits)[ - -self.parallel_draft_step :, :, : + -self.eagle_config.parallel_draft_step :, :, : ] .argmax(dim=-1) .transpose(0, 1) @@ -1669,10 +1685,10 @@ def pseudo_speculative_generate( .argmax(dim=-1) .transpose(0, 1) ) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] - if self.parallel_draft_step > 1: + if self.eagle_config.parallel_draft_step > 1: return base_token, draft_token draft_tokens.append(draft_token) @@ -1687,6 +1703,408 @@ def pseudo_speculative_generate( return base_token, draft_tokens +@OfflineEagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) +class _DetachedEagleGPTModel(_DynamicEagleGPTModel): + """A wrapper for detached Eagle module.""" + + def modify( + self, + eagle_offline, + eagle_hidden_state_distillation, + eagle_self_logit_distillation, + eagle_freeze_base_model, + eagle_report_acc, + eagle_reuse_base_decoder, + eagle_loss_decay_factor, + eagle_architecture_config, + ): + super(_DynamicEagleGPTModel, self).modify( + eagle_offline=eagle_offline, + eagle_hidden_state_distillation=eagle_hidden_state_distillation, + eagle_self_logit_distillation=eagle_self_logit_distillation, + eagle_freeze_base_model=eagle_freeze_base_model, + eagle_report_acc=eagle_report_acc, + eagle_reuse_base_decoder=eagle_reuse_base_decoder, + eagle_loss_decay_factor=eagle_loss_decay_factor, + eagle_architecture_config=eagle_architecture_config, + ) + + # Freeze all parameters + if self.eagle_freeze_base_model: + for name, param in self.named_parameters(): + param.requires_grad = False + + self.eagle_config = dict_to_config( + eagle_architecture_config, + self.config.use_cpu_initialization, + self.config.fp16, + self.config.bf16, + ) + + assert not eagle_reuse_base_decoder, ( + "_DetachedEagleGPTModel does not have a base model so eagle_reuse_base_decoder must be False!" + ) + + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + assert eagle_self_logit_distillation, ( + "Only logit distillation is supported when draft_vocab_size != vocab_size!" 
+ ) + + # Use default aux_hidden_state layers if use_aux_hidden_state is True + # but no layer id is given + # layer ids are not used in detached eagle, but we need to set this to have correct fc_input_size_multiplier + if ( + self.eagle_config.use_aux_hidden_state + and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 + ): + self._set_default_aux_hidden_state_layers() + + # Only the last PP stage has the additional projection and decoder layer. + # This is to simplify the export. + if self.post_process: + rotary_pos_emb = RotaryEmbedding( + kv_channels=self.eagle_config.kv_channels, + rotary_percent=self.eagle_config.rotary_percent, + rotary_interleaved=False, + seq_len_interpolation_factor=None, + rotary_base=self.eagle_config.rotary_base, + rope_scaling=self.eagle_config.rope_scaling, + rope_scaling_factor=self.eagle_config.rope_scaling_factor, + use_cpu_initialization=self.eagle_config.use_cpu_initialization, + ) + + self.eagle_module = EagleModule( + self.eagle_config, + rotary_pos_emb, + bias=False, + ) + + # Eagle loss functions + self.kld = logits_kld_loss + + def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True): + if apply_fc: + # [s / TP, b, 3h] -> [s / TP, b, h] + return self.eagle_module.fc(hidden_states)[0] + else: + return hidden_states + + def _get_detached_eagle_module_inputs( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + features: torch.Tensor | None = None, + ): + """Getting EAGLE module inputs.""" + b = hidden_states.shape[1] + h = hidden_states.shape[2] + + # [b, 1] + id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) + padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) + + rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) + + attn_mask = attention_mask.clone().detach() + attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] + attn_mask[:, :, -1, :] = True + attn_mask[:, :, :, -1] = True + + eagle_inputs = {} + + assert self.eagle_config.parallel_draft_step == 1, ( + "Detached Eagle module does not support parallel draft yet!" 
+ ) + if features is None: + eagle_inputs["input_ids"] = padded_input_ids + eagle_inputs["hidden_states"] = hidden_states + eagle_inputs["attention_mask"] = attn_mask + eagle_inputs["position_ids"] = position_ids + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb + elif features.shape[0] == hidden_states.shape[0]: + eagle_inputs["input_ids"] = torch.cat( + (padded_input_ids, padded_input_ids), + dim=-1, + ) + eagle_inputs["hidden_states"] = torch.cat( + ( + hidden_states, + torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), + features[:-1, :, :], + ), + dim=0, + ) + eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 2) + eagle_inputs["position_ids"] = torch.cat((position_ids, position_ids), dim=-1) + + if rotary_pos_emb is not None: + eagle_inputs["rotary_pos_emb"] = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=0) + else: + # [TODO] (yeyu): there will be problem here with MLA + eagle_inputs["rotary_pos_emb"] = None + elif features.shape[0] == hidden_states.shape[0] * 2: + eagle_inputs["input_ids"] = torch.cat( + (padded_input_ids, padded_input_ids, padded_input_ids), + dim=-1, + ) + eagle_inputs["hidden_states"] = torch.cat( + ( + hidden_states, + torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), + features[:-1, :, :], + ), + dim=0, + ) + + eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3) + eagle_inputs["position_ids"] = torch.cat( + (position_ids, position_ids, position_ids), dim=-1 + ) + + if rotary_pos_emb is not None: + eagle_inputs["rotary_pos_emb"] = torch.cat( + (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), + dim=0, + ) + else: + # [TODO] (yeyu): there will be problem here with MLA + eagle_inputs["rotary_pos_emb"] = None + else: + eagle_inputs["input_ids"] = torch.cat( + (padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids), + dim=-1, + ) + eagle_inputs["hidden_states"] = torch.cat( + ( + hidden_states, + torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), + features[:-1, :, :], + ), + dim=0, + ) + + eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4) + eagle_inputs["position_ids"] = torch.cat( + (position_ids, position_ids, position_ids, position_ids), dim=-1 + ) + + if rotary_pos_emb is not None: + eagle_inputs["rotary_pos_emb"] = torch.cat( + (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), + dim=0, + ) + else: + # [TODO] (yeyu): there will be problem here with MLA + eagle_inputs["rotary_pos_emb"] = None + + eagle_inputs["embedding"] = self.embedding( + input_ids=eagle_inputs["input_ids"], + position_ids=eagle_inputs["position_ids"], + ) + + return eagle_inputs + + def forward( + self, + input_ids: torch.Tensor = None, + position_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + decoder_input: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict | None = None, + return_eagle_inputs: bool = False, # Not used in Detached Eagle + **kwargs, + ) -> torch.Tensor: + assert "aux_hidden_states" in kwargs, ( + "aux_hidden_states is required as input to _DetachedEagleGPTModel" + ) + assert "hidden_states" in kwargs, ( + "hidden_states is required as input to _DetachedEagleGPTModel" + ) + aux_hidden_states = kwargs.get("aux_hidden_states") + hidden_states = kwargs.get("hidden_states") + + # Note: labels is 1 token shorter than logits in detached mode + + if 
position_ids is None or attention_mask is None: + attention_mask, position_ids = get_default_attention_mask_and_position_ids(input_ids) + + eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(aux_hidden_states) + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) + + eagle_inputs_0 = self._get_detached_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward( + eagle_inputs_0, + None, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + loss = torch.zeros(input_ids.shape).to(input_ids.device) + + loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0) + loss[:, 1:] += self.eagle_loss_decay_factor * loss_0 + + if self.eagle_report_acc and not self.training: + acc = [] + with torch.no_grad(): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_0[:-2, :, :] + ) + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) + + if get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}", + flush=True, + ) + + # Second round of EAGLE loss + eagle_inputs_1 = self._get_detached_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_0_pre_norm, + ) + + _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward( + eagle_inputs_1, + None, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + eagle_logits_1 = eagle_logits_2x[logits_sbh.shape[0] :, :, :] + + loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1) + # [b, s - 2] + loss_1 = loss_1[:, 1:] + loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1 + + if self.eagle_report_acc and not self.training: + acc = [] + with torch.no_grad(): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_1[1:-2, :, :] + ) + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) + + if get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}", + flush=True, + ) + + # Third EAGLE loss + eagle_inputs_2 = self._get_detached_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_2x_pre_norm, + ) + + _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward( + eagle_inputs_2, + None, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + eagle_logits_2 = 
eagle_logits_3x[-logits_sbh.shape[0] :, :, :] + + loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2) + # [b, s - 3] + loss_2 = loss_2[:, 2:] + loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2 + + if self.eagle_report_acc and not self.training: + acc = [] + with torch.no_grad(): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_2[2:-2, :, :] + ) + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) + + if get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}", + flush=True, + ) + + # Fourth EAGLE loss + eagle_inputs_3 = self._get_detached_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + features=eagle_hidden_states_3x_pre_norm, + ) + + _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward( + eagle_inputs_3, + None, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + eagle_logits_3 = eagle_logits_4x[-logits_sbh.shape[0] :, :, :] + + loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3) + # [b, s - 4] + loss_3 = loss_3[:, 3:] + loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3 + + if self.eagle_report_acc and not self.training: + acc = [] + with torch.no_grad(): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits_3[3:-2, :, :] + ) + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel() + acc.append(top1_p) + + if get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}", + flush=True, + ) + + return loss + + class MegatronARValidation(AcceptanceRateValidation): """This is the subclass for megatron model AR validation.""" diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 7743b8406..216d18034 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -36,12 +36,12 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss -from transformers import Cache, DynamicCache, PreTrainedModel +from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from ..eagle.conversion import EagleDMRegistry +from ..eagle.conversion import EagleDMRegistry, OfflineEagleDMRegistry from ..eagle.eagle_model import EagleModel from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask from ..medusa.conversion import MedusaDMRegistry @@ -173,50 +173,53 @@ def forward( class EagleModule(nn.Module): """Eagle module used in EAGLE model.""" - def __init__( - self, - config, - decoder_layer_cls, - ): + def __init__(self, config, decoder_layer_cls, bias=False):
"""Init function for EagleModule.""" super().__init__() self.config = config + + # NOTE:This is a temporary fix to support Qwen and Mixtral in current release. + # This is refactored in following MR. + config_overwrite = { + "mlp_bias": False, + "attention_bias": False, + "head_dim": self.config.hidden_size // self.config.num_attention_heads, + } + for key, value in config_overwrite.items(): + setattr(self.config, key, value) + self.layers = nn.ModuleList( - [ - decoder_layer_cls(config, layer_idx) - for layer_idx in range(config.eagle["num_hidden_layers"]) - ] + [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - if config.eagle["use_last_layernorm"]: + if config.use_last_layernorm: self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) # Optionally, we use a smaller vocab table for eagle module - if config.eagle["draft_vocab_size"] > 0: + if config.draft_vocab_size != config.vocab_size or config.has_lm_head: # Need an extra lm_head for eagle module since vocab size is reduced. - assert config.eagle["draft_vocab_size"] <= config.vocab_size, ( + assert config.draft_vocab_size <= config.vocab_size, ( "EAGLE module's vocab size should be <= base model vocab size!" ) # Initialize the buffers to zero. # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. - self.register_buffer( - "d2t", torch.zeros(config.eagle["draft_vocab_size"], dtype=torch.int64) - ) + if config.draft_vocab_size < config.vocab_size: + self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64)) self.eagle_lm_head = nn.Linear( config.hidden_size, - config.eagle["draft_vocab_size"], + config.draft_vocab_size, bias=False, ) - if not config.eagle["use_aux_hidden_state"]: + if not config.use_aux_hidden_state: # In Eagle-1, the FC concentrate input embeddings and hidden states - self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias) else: # In EAGLE-3, the FC concentrate hidden states from multiple base model layers self.fc = nn.Linear( - len(config.eagle["eagle_aux_hidden_state_layer_ids"]) * config.hidden_size, + len(config.eagle_aux_hidden_state_layer_ids) * config.hidden_size, config.hidden_size, - bias=False, + bias=bias, ) first_layer_attn = self.layers[0].self_attn @@ -246,10 +249,8 @@ def __init__( ) # In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation. - self.input_embeds_norm = LlamaRMSNorm( - config.hidden_size, eps=config.eagle["rms_norm_eps"] - ) - self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.eagle["rms_norm_eps"]) + self.input_embeds_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Disable input norm in first layer. We normed embeds and h individually before. self.layers[0].input_layernorm = nn.Identity() @@ -303,7 +304,7 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device) - if self.config.eagle["use_aux_hidden_state"]: + if self.config.use_aux_hidden_state: # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function # Also, we normalize input embeddings and hidden states before concatenating them. # The default input norm in first layer attn will be disabled. 
@@ -340,11 +341,14 @@ class HFEagleModel(EagleModel): def _set_default_aux_hidden_state_layers(self): num_layers = self.config.num_hidden_layers - default_layer_ids = [1, num_layers // 2 - 1, num_layers - 4] - # Remove negative and duplicate when base model is small - default_layer_ids = [max(0, i) for i in default_layer_ids] - default_layer_ids = list(set(default_layer_ids)) - self.eagle_aux_hidden_state_layer_ids = default_layer_ids + self.eagle_config.eagle_aux_hidden_state_layer_ids = [ + 1, + max(0, num_layers // 2 - 1), + max(0, num_layers - 4), + ] + self.eagle_config.eagle_aux_hidden_state_layer_ids = list( + set(self.eagle_config.eagle_aux_hidden_state_layer_ids) + ) def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None: """Collect auxiliary hidden states from base model intermediate layers, save them in attribute.""" @@ -360,7 +364,7 @@ def pop_aux_hidden_states(self): # In PTQ, forward method will be called with try and except to find max batch size. # This leads to uncleared aux hidden states in the front of the list. # To fix it, we only return the last num_aux_h items in the list. - num_aux_h = len(self.eagle_aux_hidden_state_layer_ids) + num_aux_h = len(self.eagle_config.eagle_aux_hidden_state_layer_ids) aux_h_list = self._aux_hidden_states[-num_aux_h:] self._aux_hidden_states.clear() @@ -368,17 +372,14 @@ def pop_aux_hidden_states(self): def modify( self, - eagle_num_layers, - use_input_layernorm_in_first_layer, - use_last_layernorm, + eagle_offline, eagle_hidden_state_distillation, - use_aux_hidden_state, - eagle_aux_hidden_state_layer_ids, - eagle_disable_moe, # Not used in HFEagleModel - draft_vocab_size, - use_mtp_layernorm, - parallel_draft_step=1, - ffn_hidden_size=0, + eagle_self_logit_distillation, + eagle_freeze_base_model, + eagle_report_acc, + eagle_reuse_base_decoder, + eagle_loss_decay_factor, + eagle_architecture_config, ): """Constructor. @@ -386,43 +387,39 @@ def modify( config: The config for eagle decoder layers. 
""" super().modify( - eagle_num_layers=eagle_num_layers, - use_input_layernorm_in_first_layer=use_input_layernorm_in_first_layer, - use_last_layernorm=use_last_layernorm, + eagle_offline=eagle_offline, eagle_hidden_state_distillation=eagle_hidden_state_distillation, - use_aux_hidden_state=use_aux_hidden_state, - eagle_aux_hidden_state_layer_ids=eagle_aux_hidden_state_layer_ids, - eagle_disable_moe=eagle_disable_moe, - draft_vocab_size=draft_vocab_size, - use_mtp_layernorm=use_mtp_layernorm, - parallel_draft_step=parallel_draft_step, + eagle_self_logit_distillation=eagle_self_logit_distillation, + eagle_freeze_base_model=eagle_freeze_base_model, + eagle_report_acc=eagle_report_acc, + eagle_reuse_base_decoder=eagle_reuse_base_decoder, + eagle_loss_decay_factor=eagle_loss_decay_factor, + eagle_architecture_config=eagle_architecture_config, + ) + self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config) + self.eagle_config._attn_implementation = "sdpa" + decoder_cls = ( + type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer ) - if use_aux_hidden_state and not eagle_aux_hidden_state_layer_ids: + # Use default aux_hidden_state layers if use_aux_hidden_state is True + # but no layer id is given + if ( + self.eagle_config.use_aux_hidden_state + and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 + ): self._set_default_aux_hidden_state_layers() - self.config.eagle = { - "num_hidden_layers": eagle_num_layers, - "num_attention_heads": self.config.num_attention_heads, - "head_dim": getattr( - self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads - ), - "intermediate_size": self.config.intermediate_size, - "hidden_size": self.config.hidden_size, - "num_key_value_heads": self.config.num_key_value_heads, - "rms_norm_eps": self.config.rms_norm_eps, - "max_position_embeddings": self.config.max_position_embeddings, - "rope_theta": self.config.rope_theta, - "use_input_layernorm_in_first_layer": use_input_layernorm_in_first_layer, - "use_last_layernorm": use_last_layernorm, - "use_aux_hidden_state": use_aux_hidden_state, - "eagle_aux_hidden_state_layer_ids": self.eagle_aux_hidden_state_layer_ids, - "draft_vocab_size": draft_vocab_size, - } + if self.config.hidden_size != self.eagle_config.hidden_size: + raise ValueError( + "EAGLE module hidden size " + f"{self.eagle_config.hidden_size} must match base model hidden size " + f"{self.config.hidden_size}!" 
+ ) self.eagle_module = EagleModule( - config=self.config, - decoder_layer_cls=LlamaDecoderLayer, + self.eagle_config, + decoder_cls, ) if hasattr(self.model.layers[-1].self_attn, "o_proj"): @@ -440,10 +437,10 @@ def modify( param.requires_grad = False # EAGLE-3 auxiluary hidden_states - if self.use_aux_hidden_state: + if self.eagle_config.use_aux_hidden_state: self._aux_hidden_states = [] for layer_idx, layer in enumerate(self.model.layers): - if layer_idx in self.eagle_aux_hidden_state_layer_ids: + if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) def _prepare_decoder_attention_mask( @@ -719,7 +716,7 @@ def _base_model_forward( base_model_loss = loss_fct(loss_logits, labels) # Map the base model logits to the draft vocab - if self.draft_vocab_size > 0 and self.training: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training: reverse_mapping = ( torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device) + self.eagle_module.d2t @@ -745,7 +742,9 @@ def _eagle_forward( position_embeddings=position_embeddings, ) eagle_lm_head = ( - self.eagle_module.eagle_lm_head if self.draft_vocab_size > 0 else self.lm_head + self.eagle_module.eagle_lm_head + if hasattr(self.eagle_module, "eagle_lm_head") + else self.lm_head ) eagle_logits = eagle_lm_head(eagle_postnorm_h) @@ -765,7 +764,6 @@ def forward( cache_position: torch.LongTensor | None = None, logits_to_keep: int = 0, loss_mask: torch.Tensor | None = None, - freeze_base_model: bool = True, classification_loss_coefficient: float | None = 1, regression_loss_coefficient: float | None = 0, **kwargs, @@ -797,7 +795,7 @@ def forward( attention_mask, position_ids, past_key_values, - freeze_base_model, + self.eagle_freeze_base_model, labels, kwargs, ) @@ -808,7 +806,7 @@ def forward( if self.training: # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers batch_size, seq_length, _ = base_model_hidden_states.shape - if self.config.eagle["use_aux_hidden_state"]: + if self.eagle_config.use_aux_hidden_state: eagle_input_hidden_states = self.eagle_module.fc( torch.cat(self.pop_aux_hidden_states(), dim=-1) ) @@ -853,8 +851,7 @@ def forward( + classification_loss_coefficient * classification_loss ) - # ====Perform training-time-testing with 3 extra eagle forward passes==== - if self.training: + # ====Perform training-time-testing with 3 extra eagle forward passes==== # ====Second step of eagle forward==== eagle_input_hidden_states_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = ( self._concat_eagle_inputs( @@ -1062,7 +1059,7 @@ def pseudo_speculative_generate( eagle_ids = torch.cat((input_ids[:, 1:], base_token), dim=-1) - if self.use_aux_hidden_state: + if self.eagle_config.use_aux_hidden_state: # EAGLE-3 # Only the first iteration input_hidden_states are from aux_hidden_state layers # Gather _aux_hidden_states from all devices before concatenation @@ -1100,7 +1097,7 @@ def pseudo_speculative_generate( ) draft_token = eagle_logits[:, -1:, :].argmax(dim=-1) - if self.draft_vocab_size > 0: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] draft_tokens.append(draft_token) @@ -1114,6 +1111,13 @@ def pseudo_speculative_generate( return base_token, draft_tokens +@OfflineEagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class DetachedHFEagleModel(HFEagleModel): + """A wrapper for 
detached Eagle module.""" + + # TODO: Implement DetachedHFEagleModel class for offline eagle. + + class HFARValidation(AcceptanceRateValidation): """This is the subclass for HF model AR validation.""" diff --git a/setup.py b/setup.py index 5039f3825..697f36aea 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'", "ml_dtypes", # for bfloat16 conversion "onnx-graphsurgeon", - "onnx>=1.18.0", + "onnx~=1.18.0", "onnxconverter-common", "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'", "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 diff --git a/tests/_test_utils/examples/run_command.py b/tests/_test_utils/examples/run_command.py index 22364c7ba..bf3a0b038 100644 --- a/tests/_test_utils/examples/run_command.py +++ b/tests/_test_utils/examples/run_command.py @@ -87,12 +87,34 @@ def run_llm_autodeploy_command( server_handler.terminate() -def run_torch_onnx_command(*, quantize_mode: str, onnx_save_path: str, **kwargs): - kwargs.update({"quantize_mode": quantize_mode, "onnx_save_path": onnx_save_path}) +def run_torch_onnx_command(*, quantize_mode: str, onnx_save_path: str, calib_size: str, **kwargs): + kwargs.update( + { + "quantize_mode": quantize_mode, + "onnx_save_path": onnx_save_path, + "calibration_data_size": calib_size, + } + ) cmd_parts = _extend_cmd_parts(["python", "torch_quant_to_onnx.py"], **kwargs) run_example_command(cmd_parts, "onnx_ptq") +def run_llm_export_command( + *, torch_dir: str, dtype: str, lm_head: str, output_dir: str, calib_size: str, **kwargs +): + kwargs.update( + { + "torch_dir": torch_dir, + "dtype": dtype, + "lm_head": lm_head, + "output_dir": output_dir, + "calib_size": calib_size, + } + ) + cmd_parts = _extend_cmd_parts(["python", "llm_export.py"], **kwargs) + run_example_command(cmd_parts, "onnx_ptq") + + def run_llm_ptq_command(*, model: str, quant: str, **kwargs): kwargs.update({"model": model, "quant": quant}) kwargs.setdefault("tasks", "build") diff --git a/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py b/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py index a7c930493..dc2ef023f 100644 --- a/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py +++ b/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py @@ -20,8 +20,14 @@ # TODO: Add accuracy evaluation after we upgrade TRT version to 10.12 @pytest.mark.parametrize( - ("quantize_mode", "onnx_save_path"), - [("nvfp4", "vit_base_patch16_224.nvfp4.onnx"), ("mxfp8", "vit_base_patch16_224.mxfp8.onnx")], + ("quantize_mode", "onnx_save_path", "calib_size"), + [ + ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1"), + ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1"), + ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1"), + ], ) -def test_torch_onnx(quantize_mode, onnx_save_path): - run_torch_onnx_command(quantize_mode=quantize_mode, onnx_save_path=onnx_save_path) +def test_torch_onnx(quantize_mode, onnx_save_path, calib_size): + run_torch_onnx_command( + quantize_mode=quantize_mode, onnx_save_path=onnx_save_path, calib_size=calib_size + ) diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 79d59d17b..689a6aae3 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -28,9 +28,8 @@ def test_llama_eagle(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_p "--lr", "1e-5", 
"--do_eval", "False", "--num_gpu", str(num_gpus), - "--mode", "eagle", + "--mode", "eagle3", "--output_dir", tmp_path / "eagle-tinyllama", - "--eagle_num_layers", "1", ], "speculative_decoding", ) diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu/torch/export/test_unified_export_megatron.py index 2d21da56a..16fb19c37 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu/torch/export/test_unified_export_megatron.py @@ -14,6 +14,7 @@ # limitations under the License. import json +from copy import deepcopy from functools import partial import pytest @@ -28,6 +29,7 @@ import modelopt.torch.speculative as mtsp from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel @@ -69,7 +71,7 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): model = mtsp.convert(model, [("medusa", config)]) assert isinstance(model, _DynamicMedusaGPTModel) elif algo == "eagle": - config = {"eagle_num_layers": 1} + config = {"eagle_architecture_config": deepcopy(default_eagle_config)} model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) diff --git a/tests/gpu/torch/quantization/backends/test_gemm_common.py b/tests/gpu/torch/quantization/backends/test_gemm_common.py index 6e5909b1e..ccfec1b0b 100644 --- a/tests/gpu/torch/quantization/backends/test_gemm_common.py +++ b/tests/gpu/torch/quantization/backends/test_gemm_common.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy + import pytest import torch +from _test_utils.torch_misc import set_seed from _test_utils.torch_quantization.models import OneLayerLinear, SimpleLinear from _test_utils.torch_quantization.quantize_common import compute_backward_grad @@ -23,14 +26,18 @@ from modelopt.torch.quantization.backends.nvfp4_gemm import Nvfp4Linear from modelopt.torch.quantization.backends.utils import fp4_compatible, fp8_compatible +set_seed() + @pytest.mark.parametrize( - ("model", "config", "gemm_forward"), + ("model", "config", "gemm_forward", "atol", "rtol"), [ pytest.param( OneLayerLinear(in_features=64, out_features=32), mtq.NVFP4_DEFAULT_CFG, Nvfp4Linear.apply, + 0.5, + 0.1, marks=[ pytest.mark.skipif(not fp4_compatible(), reason="FP4 is not supported on this GPU"), ], @@ -39,13 +46,15 @@ SimpleLinear(), mtq.FP8_DEFAULT_CFG, Fp8PerTensorLinear.apply, + 0.05, + 0.1, marks=[ pytest.mark.skipif(not fp8_compatible(), reason="FP8 is not supported on this GPU"), ], ), ], ) -def test_gemm(model, config, gemm_forward): +def test_gemm(model, config, gemm_forward, atol, rtol): model = model.to(torch.float16).cuda() calib_data = [model.get_input().to(torch.float16).cuda() for _ in range(8)] @@ -64,9 +73,9 @@ def forward_loop(model, run_backward=False): # Test without bias result_no_bias = gemm_forward(module, input_tensor, module.weight) diff = (result_no_bias - expected).abs() - assert torch.allclose(result_no_bias, expected, atol=0.5, rtol=0.1), ( + assert torch.allclose(result_no_bias, expected, atol=atol, rtol=rtol), ( f"Test without bias failed: {diff.amax()}\n" - f"{result_no_bias[diff > 0.1]} != {expected[diff > 0.1]}" + f"{result_no_bias[diff > atol]} != {expected[diff > atol]}" ) # Generate a random bias for testing @@ -76,17 +85,17 @@ def forward_loop(model, run_backward=False): # Test 1: Bias as keyword argument (kwargs) result_with_bias_kwargs = gemm_forward(module, input_tensor, module.weight, bias=bias) diff = (result_with_bias_kwargs - expected_with_bias).abs() - assert torch.allclose(result_with_bias_kwargs, expected_with_bias, atol=0.5, rtol=0.1), ( + assert torch.allclose(result_with_bias_kwargs, expected_with_bias, atol=atol, rtol=rtol), ( f"Bias as kwargs failed: {diff.amax()}\n" - f"{result_with_bias_kwargs[diff > 0.1]} != {expected_with_bias[diff > 0.1]}" + f"{result_with_bias_kwargs[diff > atol]} != {expected_with_bias[diff > atol]}" ) # Test 2: Bias as positional argument (args) result_with_bias_args = gemm_forward(module, input_tensor, module.weight, bias) diff = (result_with_bias_args - expected_with_bias).abs() - assert torch.allclose(result_with_bias_args, expected_with_bias, atol=0.5, rtol=0.1), ( + assert torch.allclose(result_with_bias_args, expected_with_bias, atol=atol, rtol=rtol), ( f"Bias as args failed: {diff.amax()}\n" - f"{result_with_bias_args[diff > 0.1]} != {expected_with_bias[diff > 0.1]}" + f"{result_with_bias_args[diff > atol]} != {expected_with_bias[diff > atol]}" ) # Verify both methods produce the same result @@ -96,11 +105,13 @@ def forward_loop(model, run_backward=False): @pytest.mark.parametrize( - ("model", "config"), + ("model", "config", "atol_bias", "atol_input"), [ pytest.param( SimpleLinear(), mtq.NVFP4_DEFAULT_CFG, + 0.5, + 0.2, marks=[ pytest.mark.skipif(not fp4_compatible(), reason="FP4 is not supported on this GPU"), ], @@ -108,19 +119,23 @@ def forward_loop(model, run_backward=False): pytest.param( SimpleLinear(), mtq.FP8_DEFAULT_CFG, + 0.02, + 0.02, marks=[ pytest.mark.skipif(not fp8_compatible(), reason="FP8 is not supported on this GPU"), ], 
), ], ) -def test_compressed_backward_to_input(model, config): +def test_compressed_backward_to_input(model, config, atol_bias, atol_input): model = model.to(torch.float16).cuda() input_tensor = model.get_input().to(torch.float16).cuda() input_tensor.requires_grad = True + _, bias_grads = compute_backward_grad(model, input_tensor, config=config, quantize=True) input_grad = input_tensor.grad input_tensor.grad = None + weight_grads_quantized, bias_grads_quantized = compute_backward_grad( model, input_tensor, config=config, compress=True ) @@ -130,12 +145,116 @@ def test_compressed_backward_to_input(model, config): for bias_grad, bias_grad_quantized in zip(bias_grads, bias_grads_quantized): diff = (bias_grad - bias_grad_quantized).abs() - assert torch.allclose(bias_grad, bias_grad_quantized, atol=0.5), ( - f"bias grad mismatch: {bias_grad[diff > 0.5]} != {bias_grad_quantized[diff > 0.5]}" + assert torch.allclose(bias_grad, bias_grad_quantized, atol=atol_bias), ( + f"bias grad mismatch: {bias_grad[diff > atol_bias]} != {bias_grad_quantized[diff > atol_bias]}" ) diff = (input_grad - input_grad_quantized).abs() - assert torch.allclose(input_grad, input_grad_quantized, atol=0.2), ( + assert torch.allclose(input_grad, input_grad_quantized, atol=atol_input), ( f"input grad mismatch: {diff.amax()}\n" - f"{input_grad[diff > 0.2]} != {input_grad_quantized[diff > 0.2]}" + f"{input_grad[diff > atol_input]} != {input_grad_quantized[diff > atol_input]}" ) + + +@pytest.mark.parametrize( + ("model", "config", "gemm_forward", "atol", "rtol"), + [ + pytest.param( + OneLayerLinear(in_features=64, out_features=32), + mtq.NVFP4_DEFAULT_CFG, + Nvfp4Linear.apply, + 0.3, + 0.1, + marks=[ + pytest.mark.skipif(not fp4_compatible(), reason="FP4 is not supported on this GPU"), + ], + ), + pytest.param( + OneLayerLinear(in_features=64, out_features=32), + mtq.FP8_DEFAULT_CFG, + Fp8PerTensorLinear.apply, + 0.1, + 0.1, + marks=[ + pytest.mark.skipif(not fp8_compatible(), reason="FP8 is not supported on this GPU"), + ], + ), + ], +) +def test_dynamic_gemm(model, config, gemm_forward, atol, rtol): + model_fp16 = model.to(torch.float16).cuda() + calib_data = [model.get_input().to(torch.float16).cuda() for _ in range(8)] + + model_dynamic_quant = copy.deepcopy(model_fp16) + mtq.quantize(model_dynamic_quant, config) + + model_dynamic_quant_compressed = copy.deepcopy(model_dynamic_quant) + mtq.compress(model_dynamic_quant_compressed) + + def forward_loop(model, run_backward=False): + for batch in calib_data: + output = model(batch) + if run_backward: + output.sum().backward() + + model_calib_quant = copy.deepcopy(model_fp16) + mtq.quantize(model_calib_quant, config, forward_loop) + + model_calib_quant_compressed = copy.deepcopy(model_calib_quant) + mtq.compress(model_calib_quant_compressed) + + result_fp16 = [model_fp16(input_tensor) for input_tensor in calib_data] + + result_dynamic_quant = [model_dynamic_quant(input_tensor) for input_tensor in calib_data] + + module = model_dynamic_quant.net[0] + result_dynamic_quant_gemm = [ + gemm_forward(module, input_tensor, module.weight, bias=module.bias) + for input_tensor in calib_data + ] + result_dynamic_quant_compressed = [ + model_dynamic_quant_compressed(input_tensor) for input_tensor in calib_data + ] + + result_calib_quant = [model_calib_quant(input_tensor) for input_tensor in calib_data] + + module = model_calib_quant.net[0] + result_calib_quant_gemm = [ + gemm_forward(module, input_tensor, module.weight, bias=module.bias) + for input_tensor in calib_data + ] + + 
result_calib_quant_compressed = [ + model_calib_quant_compressed(input_tensor) for input_tensor in calib_data + ] + + for ( + output_fp16, + output_dynamic_quant, + output_dynamic_quant_gemm, + output_dynamic_quant_compressed, + output_calib_quant, + output_calib_quant_gemm, + output_calib_quant_compressed, + ) in zip( + result_fp16, + result_dynamic_quant, + result_dynamic_quant_gemm, + result_dynamic_quant_compressed, + result_calib_quant, + result_calib_quant_gemm, + result_calib_quant_compressed, + ): + assert torch.allclose(output_fp16, output_dynamic_quant_gemm, atol=atol, rtol=rtol) + assert torch.allclose(output_fp16, output_calib_quant_gemm, atol=atol, rtol=rtol) + + # The way the weights and inputs are compressed might differ between the paths, + # e.g. we may use torch.compile in the gemms. + assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=0.05, rtol=0.1) + assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=0.05, rtol=0.1) + assert torch.allclose( + output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=0.05, rtol=0.1 + ) + assert torch.allclose( + output_calib_quant_gemm, output_calib_quant_compressed, atol=0.05, rtol=0.1 + ) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 441bbf986..904207e6b 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -17,11 +17,15 @@ import pytest import torch +from _test_utils.torch_misc import set_seed +from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import TensorQuantizer from modelopt.torch.quantization.qtensor import NVFP4QTensor +set_seed() + class TestQTensor: @pytest.mark.parametrize( @@ -89,7 +93,7 @@ def test_amax_from_tensor_quantizer( quantizer = TensorQuantizer(quant_cfg).to(device) # Mock amax - mock_amax = torch.rand(1, device=device) + mock_amax = torch.tensor(1.1, device=device) quantizer.amax = mock_amax x = torch.rand(32, 32).to(device).to(dtype=input_dtype) @@ -399,6 +403,11 @@ def _unpack_tensor(x): torch.tensor([[0.25, 0.75, 1.25], [1.75, 2.5, 3.5]], dtype=torch.float32), torch.tensor([[0.1, 2.5, 1.0, 4.8], [1.5, 1.25, 3.25, 5.0]], dtype=torch.float32), torch.tensor([[0, 0.75, 1.25], [1.75, 2.5, 5.5]], dtype=torch.float32), + torch.tensor([[-0.25, -0.75, -1.25], [-1.75, -2.5, -3.5]], dtype=torch.float32), + torch.tensor( + [[-0.1, -2.5, -1.0, -4.8], [-1.5, -1.25, -3.25, -5.0]], dtype=torch.float32 + ), + torch.tensor([[0, -0.75, -1.25], [-1.75, -2.5, -5.5]], dtype=torch.float32), ], ) def test_cast_fp4_equivalence(self, test_input, device): @@ -434,25 +443,28 @@ def _cast_fp4(weight: torch.Tensor): @pytest.mark.parametrize( "input_shape", - [(16, 32)], + [(1600, 1600)], ) def test_cast_fp4_impl_gpu_mem(self, input_shape): def _get_gpu_mem_used(): device = torch.device("cuda:0") free, total = torch.cuda.mem_get_info(device) - mem_used = (total - free) / 1024**2 + mem_used = total - free return mem_used + # Do a warmup + test_input = torch.rand((8, 8), dtype=torch.float32).to("cuda") + NVFP4QTensor._cast_fp4(test_input) + + test_input = torch.rand((input_shape), dtype=torch.float32).to("cuda") torch.cuda.empty_cache() # Define input and thresholds - test_input = torch.rand((input_shape), dtype=torch.float32).to("cuda") - # Size of input tensor in MB - input_size = (test_input.element_size() * test_input.numel())
/ (1024**2) + input_size = test_input.element_size() * test_input.numel() before_quantize = _get_gpu_mem_used() NVFP4QTensor._cast_fp4(test_input) after_quantize = _get_gpu_mem_used() - assert after_quantize - before_quantize < input_size * 10 + assert (after_quantize - before_quantize) < input_size * 2.1 @pytest.mark.parametrize( ("num_bits", "block_sizes", "axis", "input_shape", "expected_output_shape"), @@ -489,3 +501,71 @@ def test_quantized_data_shape( q_x = quantizer(x) assert q_x._quantized_data.shape == expected_output_shape + + @pytest.mark.parametrize("shape", [(128, 64), (64, 128, 32)]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_nvfp4_qdq_correctness(self, shape, input_dtype): + """Test NVFP4 quantization and dequantization with fast option.""" + block_sizes = {-1: 16, "type": "dynamic", "scale_bits": (4, 3)} + + # Create test tensor + test_tensor = torch.randn(shape, dtype=input_dtype, device="cuda") + + # Quantize tensor + qtensor, scale, double_scale = NVFP4QTensor.quantize( + test_tensor, block_sizes[-1], try_tensorrt=False + ) + + # Dequantize using standard approach + dequant_standard = qtensor.dequantize( + dtype=input_dtype, + fast=False, + scale=scale, + double_scale=double_scale, + block_sizes=block_sizes, + ) + + # Check that standard dequantization is close to original + assert torch.allclose(dequant_standard, test_tensor, atol=0.5, rtol=0.1), ( + f"Standard dequantization differs from original: " + f"max diff = {(dequant_standard - test_tensor).abs().max()}" + ) + + @pytest.mark.skipif(not fp4_compatible(), reason="FP4 is not supported on this GPU") + @pytest.mark.parametrize("shape", [(128, 64), (64, 128, 32)]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_nvfp4_dequantize_fast(self, shape, input_dtype): + """Test NVFP4 quantization and dequantization with fast option.""" + block_sizes = {-1: 16, "type": "dynamic", "scale_bits": (4, 3)} + + # Create test tensor + test_tensor = torch.randn(shape, dtype=input_dtype, device="cuda") + + # Quantize tensor + qtensor, scale, double_scale = NVFP4QTensor.quantize( + test_tensor, block_sizes[-1], try_tensorrt=False + ) + + # Dequantize using standard approach + dequant_standard = qtensor.dequantize( + dtype=input_dtype, + fast=False, + scale=scale, + double_scale=double_scale, + block_sizes=block_sizes, + ) + + # Dequantize using fast approach + dequant_fast = qtensor.dequantize( + dtype=input_dtype, + fast=True, + scale=scale, + double_scale=double_scale, + block_sizes=block_sizes, + ) + + # Check that fast and standard dequantization produce the same results + assert torch.allclose(dequant_fast, dequant_standard, atol=1e-6, rtol=1e-5), ( + f"Fast and standard dequantization differ: " + f"max diff = {(dequant_fast - dequant_standard).abs().max()}" + ) diff --git a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py index acd0db719..0ab314cd5 100644 --- a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py +++ b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
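For context on the revised test_cast_fp4_impl_gpu_mem above: torch.cuda.mem_get_info returns (free, total) bytes for a device, so used memory can be sampled before and after an operation and compared against the input size in bytes. A small sketch of that measurement pattern, with an arbitrary stand-in op instead of NVFP4QTensor._cast_fp4:

import torch

def gpu_mem_used(device: str = "cuda:0") -> int:
    # Bytes currently in use on the device: total minus free.
    free, total = torch.cuda.mem_get_info(torch.device(device))
    return total - free

x = torch.rand(1600, 1600, device="cuda")  # also initializes the CUDA context (warm-up)
input_size = x.element_size() * x.numel()

before = gpu_mem_used()
y = x.to(torch.float16)  # stand-in for the quantization kernel under test
after = gpu_mem_used()

# The extra allocation should stay within a small multiple of the input size.
assert after - before < 2.1 * input_size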
from collections import deque +from copy import deepcopy from functools import partial import pytest @@ -26,6 +27,7 @@ from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel, right_padding from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids @@ -63,7 +65,10 @@ def _test_speculative_gpt_model( # Type checking assert isinstance(model, _DynamicMedusaGPTModel) elif algo == "eagle": - config = {"eagle_num_layers": 1} + config = {"eagle_architecture_config": deepcopy(default_eagle_config)} + config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size + config["eagle_architecture_config"]["vocab_size"] = model.vocab_size + config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size model = mtsp.convert(model, [("eagle", config)]) @@ -176,8 +181,6 @@ def _test_tree_decode(tree_paths, greedy_steps, rank, size): vocab_size = 64 batch_size = 1 - config = {"eagle_num_layers": 1} - model = get_mcore_gpt_model( tensor_model_parallel_size=size, pipeline_model_parallel_size=1, @@ -190,6 +193,11 @@ def _test_tree_decode(tree_paths, greedy_steps, rank, size): normalization=normalization, ).cuda() + config = {"eagle_architecture_config": deepcopy(default_eagle_config)} + config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size + config["eagle_architecture_config"]["vocab_size"] = model.vocab_size + config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size + model = mtsp.convert(model, [("eagle", config)]) # Bfloat16 diff --git a/tests/unit/onnx/test_quantize_zint4.py b/tests/unit/onnx/test_quantize_zint4.py index 4249f1fe1..3f37bfcc6 100644 --- a/tests/unit/onnx/test_quantize_zint4.py +++ b/tests/unit/onnx/test_quantize_zint4.py @@ -45,6 +45,26 @@ def _matmul_model(w: np.ndarray, in_shape: Sequence[int], out_shape: Sequence[in return onnx_path +def _gather_model(rows, cols, tmp_path): + data_const = np.arange(rows * cols, dtype=np.float16).reshape(rows, cols) + data_node = gs.Constant("data", values=data_const) + + indices_var = gs.Variable("indices", dtype=np.int64, shape=()) + + gather_out = gs.Variable("output", dtype=np.float16, shape=(cols,)) + gather_node = gs.Node( + op="Gather", inputs=[data_node, indices_var], outputs=[gather_out], attrs={"axis": 0} + ) + + graph = gs.Graph(nodes=[gather_node], inputs=[indices_var], outputs=[gather_out]) + + onnx_model = gs.export_onnx(graph) + onnx_path = os.path.join(tmp_path, "gather_base_model.onnx") + save_onnx(onnx_model, onnx_path) + + return onnx_path + + def test_int4_rtn(tmp_path): # Test scale factor computation. # Use moq.quantize once to check that path doesnt have any bugs @@ -151,3 +171,57 @@ def test_shape_awq(tmp_path): block_size=8, use_external_data_format=False, ) # Ensure it passes. 
+ + +def test_int4_gather(tmp_path): + gather_rows = 8 + gather_cols = 16 + gather_block_size = 8 + + m = _gather_model(gather_rows, gather_cols, tmp_path=tmp_path) + + m1 = quantize_int4( + m, calibration_method="rtn_dq", gather_block_size=gather_block_size, gather_quantize_axis=1 + ) + m2 = quantize_int4( + m, calibration_method="rtn_dq", gather_block_size=gather_block_size, gather_quantize_axis=0 + ) + + def is_gather_quantized(model): + g = gs.import_onnx(model) + for node in g.nodes: + if node.op != "Gather": + continue + for inp in node.inputs: + if inp.name != "indices": + # print(f"inp={inp}, p={inp.inputs[0].op}") + assert inp.inputs[0].op == "DequantizeLinear", ( + f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!" + ) + return True + + assert is_gather_quantized(m1), "Failure in rtn_dq quantization of Gather node, quant-axis: 1" + assert is_gather_quantized(m2), "Failure in rtn_dq quantization of Gather node, quant-axis: 0" + + def is_quant_scale_with_right_shape(model, quant_axis, block_size): + assert quant_axis in [0, 1], "Incorrect quant-axis" # used for 0/1 indexing below + orig_shape = [gather_rows, gather_cols] + graph = gs.import_onnx(model) + for node in graph.nodes: + if node.op == "DequantizeLinear": + for inp in node.inputs: + if inp.name == "x_scale": + print(f"\nname={inp.name}, shape={inp.shape}\n") + c1 = (orig_shape[quant_axis] // block_size) == inp.shape[quant_axis] + c2 = orig_shape[1 - quant_axis] == inp.shape[1 - quant_axis] + assert c1 and c2, "Incorrect scale shape in DQ node for Gather" + return True + + assert is_quant_scale_with_right_shape(m1, 1, gather_block_size), ( + "DQ Scale Error in rtn_dq quantization, axis 1" + ) + assert is_quant_scale_with_right_shape(m2, 0, gather_block_size), ( + "DQ Scale Error in rtn_dq quantization, axis 0" + ) + + # Ensure above tests pass. diff --git a/tests/unit/torch/speculative/plugins/test_hf_speculative.py b/tests/unit/torch/speculative/plugins/test_hf_speculative.py index de7bda152..51d93996e 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_speculative.py +++ b/tests/unit/torch/speculative/plugins/test_hf_speculative.py @@ -14,6 +14,7 @@ # limitations under the License. 
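A quick sanity check of the scale-shape assertion in test_int4_gather above: blockwise quantization divides only the quantized axis by the block size and leaves the other axis untouched. Worked out for the 8x16 Gather table with block size 8 (pure arithmetic, no ONNX involved):

rows, cols, block_size = 8, 16, 8

def expected_scale_shape(shape, quant_axis, block_size):
    # Only the quantized axis shrinks by the block size; the other axis keeps its length.
    out = list(shape)
    out[quant_axis] = shape[quant_axis] // block_size
    return tuple(out)

assert expected_scale_shape((rows, cols), quant_axis=1, block_size=block_size) == (8, 2)   # axis 1
assert expected_scale_shape((rows, cols), quant_axis=0, block_size=block_size) == (1, 16)  # axis 0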
import os +from copy import deepcopy import pytest import torch @@ -25,6 +26,7 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG def test_medusa_model_convert_save_and_restore(tmp_path): @@ -46,13 +48,18 @@ def test_medusa_model_convert_save_and_restore(tmp_path): tf_modelopt_state_and_output_tester(model_ref, model_test) -def test_eagle_model_convert_save_and_restore(tmp_path): +@pytest.mark.parametrize("eagle_config", [EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG]) +def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config): model_ref = get_tiny_llama(num_hidden_layers=8) - config = { - "eagle_num_layers": 1, - "use_aux_hidden_state": True, - } + config = deepcopy(eagle_config["config"]) + config["eagle_architecture_config"].update( + { + "draft_vocab_size": model_ref.config.vocab_size, + "hidden_size": model_ref.config.hidden_size, + } + ) + mtsp.convert(model_ref, mode=[("eagle", config)]) assert isinstance(model_ref, mtsp.plugins.HFEagleModel) @@ -69,10 +76,11 @@ def test_eagle_model_convert_save_and_restore(tmp_path): def test_eagle_model_prepare_eagle_inputs(dtype): dummy_model = get_tiny_llama(num_hidden_layers=4) - config = { - "eagle_num_layers": 1, - "use_aux_hidden_state": True, - } + config = EAGLE3_DEFAULT_CFG["config"] + config["eagle_architecture_config"].update({ + "draft_vocab_size": dummy_model.config.vocab_size, + "hidden_size": dummy_model.config.hidden_size, + }) mtsp.convert(dummy_model, mode=[("eagle", config)]) eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long)