| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "b54dd4ac-5b4f-42fb-af1a-8f3001abd08a", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# NVIDIA ModelOpt Quantization Aware Training (QAT) Walkthrough" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "markdown", |
| 13 | + "id": "a695be45-1472-42bc-824e-5c992a487fa7", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | +    "**Quantization Aware Training (QAT)** is a method that simulates the effects of quantization during neural network training to preserve accuracy when deploying models in very low-precision formats. Unlike post-training quantization (PTQ), QAT inserts \"fake quantization\" nodes into the computational graph, mimicking the rounding and clamping operations that occur during actual quantization. Because these operations are present while the model trains, the model can adapt its weights and activations to mitigate the resulting accuracy loss.\n",
| 17 | + "\n", |
| 18 | +    "This notebook demonstrates how to apply QAT to an LLM, in this case Meta's Llama-3.1-8B-Instruct, using NVIDIA's TensorRT Model Optimizer (ModelOpt). We walk through downloading and loading the model, calibrating it on a sample of the CNN/DailyMail dataset, applying NVFP4 quantization, fine-tuning the quantized model, generating sample outputs, and exporting the quantized model."
| 19 | + ] |
| 20 | + }, |
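| 21 | +  {
| 22 | +   "cell_type": "markdown",
| 23 | +   "id": "fakequant-sketch-md",
| 24 | +   "metadata": {},
| 25 | +   "source": [
| 26 | +    "To build intuition for what a \"fake quantization\" node does, the sketch below simulates low-precision rounding and clamping on a tensor while keeping everything in floating point. This is an illustrative toy (the `fake_quantize` helper is ours, not ModelOpt's); ModelOpt inserts and manages the real quantizers for you, including the straight-through gradients needed for training."
| 27 | +   ]
| 28 | +  },
| 29 | +  {
| 30 | +   "cell_type": "code",
| 31 | +   "execution_count": null,
| 32 | +   "id": "fakequant-sketch-code",
| 33 | +   "metadata": {},
| 34 | +   "outputs": [],
| 35 | +   "source": [
| 36 | +    "import torch\n",
| 37 | +    "\n",
| 38 | +    "def fake_quantize(x: torch.Tensor, num_bits: int = 4) -> torch.Tensor:\n",
| 39 | +    "    # Simulated symmetric quantization: scale, round, clamp, then de-scale.\n",
| 40 | +    "    # The result is still a float tensor, but it only takes values on the\n",
| 41 | +    "    # low-precision grid -- which is what a QAT \"fake quant\" node emits.\n",
| 42 | +    "    qmax = 2 ** (num_bits - 1) - 1\n",
| 43 | +    "    scale = x.abs().amax() / qmax\n",
| 44 | +    "    return (x / scale).round().clamp(-qmax, qmax) * scale\n",
| 45 | +    "\n",
| 46 | +    "x = torch.randn(4)\n",
| 47 | +    "print(x)\n",
| 48 | +    "print(fake_quantize(x))"
| 49 | +   ]
| 50 | +  },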
| 21 | + { |
| 22 | + "cell_type": "markdown", |
| 23 | + "id": "c3f7f931-ac38-494e-aea8-ca2cd6d05794", |
| 24 | + "metadata": {}, |
| 25 | + "source": [ |
| 26 | +    "## Installing Prerequisites and Dependencies"
| 27 | + ] |
| 28 | + }, |
| 29 | + { |
| 30 | + "cell_type": "markdown", |
| 31 | + "id": "d7d4f25f-e569-42cf-8022-bb7cc6f9ea6e", |
| 32 | + "metadata": {}, |
| 33 | + "source": [ |
| 34 | +    "If you haven't already, install the required dependencies for this notebook. Key dependencies include:\n",
| 35 | + "\n", |
| 36 | + "- nvidia-modelopt\n", |
| 37 | + "- torch\n", |
| 38 | + "- transformers\n", |
| 39 | + "- jupyterlab\n", |
| 40 | + "\n", |
| 41 | +    "This repo contains an `examples/llm_qat/notebooks/requirements.txt` file that can be used to install all required dependencies."
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": null, |
| 47 | + "id": "ab464a07-8a19-43a9-a715-81ccef350253", |
| 48 | + "metadata": { |
| 49 | + "scrolled": true |
| 50 | + }, |
| 51 | + "outputs": [], |
| 52 | + "source": [ |
| 53 | + "!pip install -r requirements.txt" |
| 54 | + ] |
| 55 | + }, |
| 56 | + { |
| 57 | + "cell_type": "markdown", |
| 58 | + "id": "99c6ca5d-0d08-4b6c-814f-b8a92a8469f2", |
| 59 | + "metadata": {}, |
| 60 | + "source": [ |
| 61 | + "## Setting HuggingFace Token and Model for Download" |
| 62 | + ] |
| 63 | + }, |
| 64 | + { |
| 65 | + "cell_type": "markdown", |
| 66 | + "id": "09f3c6a7-1c9c-4254-9524-7e253528d9d7", |
| 67 | + "metadata": {}, |
| 68 | + "source": [ |
| 69 | +    "Set the `HF_TOKEN` environment variable, making sure to update it to include your token (e.g. `%env HF_TOKEN=hf_abdxyz...`)."
| 70 | + ] |
| 71 | + }, |
| 72 | + { |
| 73 | + "cell_type": "code", |
| 74 | + "execution_count": null, |
| 75 | + "id": "eb2071be-df85-4961-92b0-567830a37d71", |
| 76 | + "metadata": { |
| 77 | + "scrolled": true |
| 78 | + }, |
| 79 | + "outputs": [], |
| 80 | + "source": [ |
| 81 | + "%env HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN>" |
| 82 | + ] |
| 83 | + }, |
| 84 | + { |
| 85 | + "cell_type": "markdown", |
| 86 | + "id": "4eab1f6a-5855-4f2a-a982-27b7e756deca", |
| 87 | + "metadata": {}, |
| 88 | + "source": [ |
| 89 | +    "As mentioned above, we will use Meta's **Llama-3.1-8B-Instruct** in this example."
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": 6, |
| 95 | + "id": "6d25c2b1-a68b-4748-ac29-e8a893ce1762", |
| 96 | + "metadata": {}, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "model_name = \"meta-llama/Llama-3.1-8B-Instruct\"" |
| 100 | + ] |
| 101 | + }, |
| 102 | + { |
| 103 | + "cell_type": "markdown", |
| 104 | + "id": "41b7c23f-748b-4c9e-8883-d6ca24af46ed", |
| 105 | + "metadata": {}, |
| 106 | + "source": [ |
| 107 | + "## Import Required Libraries" |
| 108 | + ] |
| 109 | + }, |
| 110 | + { |
| 111 | + "cell_type": "code", |
| 112 | + "execution_count": 14, |
| 113 | + "id": "0ec71181-770a-4ee6-8760-c62cfab8340f", |
| 114 | + "metadata": {}, |
| 115 | + "outputs": [], |
| 116 | + "source": [ |
| 117 | + "import torch\n", |
| 118 | + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
| 119 | + "from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader" |
| 120 | + ] |
| 121 | + }, |
| 122 | + { |
| 123 | + "cell_type": "markdown", |
| 124 | + "id": "53f8c53a-94ab-4968-aac8-866d850ef874", |
| 125 | + "metadata": {}, |
| 126 | + "source": [ |
| 127 | + "## Download and Load Model and Tokenizer" |
| 128 | + ] |
| 129 | + }, |
| 130 | + { |
| 131 | + "cell_type": "code", |
| 132 | + "execution_count": 18, |
| 133 | + "id": "5f946576-83ac-45b5-a290-9a2167193e3d", |
| 134 | + "metadata": {}, |
| 135 | + "outputs": [ |
| 136 | + { |
| 137 | + "data": { |
| 138 | + "application/vnd.jupyter.widget-view+json": { |
| 139 | + "model_id": "37c5f366ef204794bad4711ae6056d6c", |
| 140 | + "version_major": 2, |
| 141 | + "version_minor": 0 |
| 142 | + }, |
| 143 | + "text/plain": [ |
| 144 | + "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]" |
| 145 | + ] |
| 146 | + }, |
| 147 | + "metadata": {}, |
| 148 | + "output_type": "display_data" |
| 149 | + } |
| 150 | + ], |
| 151 | + "source": [ |
| 152 | + "model = AutoModelForCausalLM.from_pretrained(model_name).cuda()\n", |
| 153 | + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", |
| 154 | + "tokenizer.pad_token = tokenizer.eos_token\n", |
| 155 | +    "tokenizer.padding_side = \"left\"  # Use left padding; right padding may hurt calibration accuracy for decoder-only models."
| 156 | + ] |
| 157 | + }, |
| 158 | + { |
| 159 | + "cell_type": "markdown", |
| 160 | + "id": "fef14995-c3cd-445d-a93a-55b60229600a", |
| 161 | + "metadata": {}, |
| 162 | + "source": [ |
| 163 | +    "## Load Calibration Dataset"
| 164 | + ] |
| 165 | + }, |
| 166 | + { |
| 167 | + "cell_type": "code", |
| 168 | + "execution_count": 19, |
| 169 | + "id": "f3b618e9-fdee-46b2-8d7e-f11f1f7ada8d", |
| 170 | + "metadata": {}, |
| 171 | + "outputs": [], |
| 172 | + "source": [ |
| 174 | + "# Calibration dataloader\n", |
| 175 | + "calib_loader = get_dataset_dataloader(\n", |
| 176 | + " dataset_name=\"cnn_dailymail\",\n", |
| 177 | + " tokenizer=tokenizer,\n", |
| 178 | + " batch_size=8,\n", |
| 179 | + " num_samples=512,\n", |
| 180 | + " device=\"cuda\"\n", |
| 181 | + ")\n", |
| 182 | + "\n", |
| 183 | + "forward_loop = create_forward_loop(dataloader=calib_loader)" |
| 184 | + ] |
| 185 | + }, |
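| 186 | +  {
| 187 | +   "cell_type": "markdown",
| 188 | +   "id": "calib-batch-peek-md",
| 189 | +   "metadata": {},
| 190 | +   "source": [
| 191 | +    "As a quick sanity check, peek at one calibration batch. The tensor names printed below (e.g. `input_ids`, `attention_mask`) are an assumption about what `get_dataset_dataloader` yields; adjust if your ModelOpt version returns a different structure."
| 192 | +   ]
| 193 | +  },
| 194 | +  {
| 195 | +   "cell_type": "code",
| 196 | +   "execution_count": null,
| 197 | +   "id": "calib-batch-peek-code",
| 198 | +   "metadata": {},
| 199 | +   "outputs": [],
| 200 | +   "source": [
| 201 | +    "# Inspect the shape of each tensor in the first calibration batch\n",
| 202 | +    "batch = next(iter(calib_loader))\n",
| 203 | +    "print({k: tuple(v.shape) for k, v in batch.items()})"
| 204 | +   ]
| 205 | +  },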
| 186 | + { |
| 187 | + "cell_type": "markdown", |
| 188 | + "id": "f971d580-495f-4fb9-b390-0e723f6d1c18", |
| 189 | + "metadata": {}, |
| 190 | + "source": [ |
| 191 | +    "## Set Quantization Config and Quantize the Model"
| 192 | + ] |
| 193 | + }, |
| 194 | + { |
| 195 | + "cell_type": "markdown", |
| 196 | + "id": "7c35bf66-fe51-48ff-809b-884de91fbaf2", |
| 197 | + "metadata": {}, |
| 198 | + "source": [ |
| 199 | +    "Applying QAT with Model Optimizer is fairly straightforward. QAT supports the same quantization formats as the PTQ workflow, including key formats such as FP8, NVFP4, MXFP4, INT8, and INT4. In this example we use the default NVFP4 config, which quantizes both weights and activations to the NVFP4 format.\n"
| 200 | + ] |
| 201 | + }, |
| 202 | + { |
| 203 | + "cell_type": "code", |
| 204 | + "execution_count": 17, |
| 205 | + "id": "51c0c1bb-2804-45ae-873f-e33388458e04", |
| 206 | + "metadata": {}, |
| 207 | + "outputs": [ |
| 208 | + { |
| 209 | + "name": "stdout", |
| 210 | + "output_type": "stream", |
| 211 | + "text": [ |
| 212 | + "Registered <class 'transformers.models.llama.modeling_llama.LlamaAttention'> to _QuantAttention for KV Cache quantization\n", |
| 213 | + "Inserted 771 quantizers\n" |
| 214 | + ] |
| 215 | + }, |
| 216 | + { |
| 217 | + "name": "stderr", |
| 218 | + "output_type": "stream", |
| 219 | + "text": [ |
| 220 | + "100%|█████████████████████████████████████████████████████████████████████████████| 64/64 [01:14<00:00, 1.16s/it]\n" |
| 221 | + ] |
| 222 | + } |
| 223 | + ], |
| 224 | + "source": [ |
| 225 | +    "import modelopt.torch.quantization as mtq\n",
| 226 | +    "\n",
| 227 | +    "# Select the default NVFP4 quantization config\n",
| 228 | +    "config = mtq.NVFP4_DEFAULT_CFG\n",
| 229 | +    "\n",
| 230 | +    "# Quantize the model (calibrating via the forward loop built above) and prepare it for QAT\n",
| 231 | +    "model = mtq.quantize(model, config, forward_loop)"
| 236 | + ] |
| 237 | + }, |
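| 238 | +  {
| 239 | +   "cell_type": "markdown",
| 240 | +   "id": "quant-summary-md",
| 241 | +   "metadata": {},
| 242 | +   "source": [
| 243 | +    "After quantization, it is useful to inspect which quantizers were inserted and at what precision. The cell below uses ModelOpt's quantization summary helper; if it is unavailable in your ModelOpt version, the cell can be skipped."
| 244 | +   ]
| 245 | +  },
| 246 | +  {
| 247 | +   "cell_type": "code",
| 248 | +   "execution_count": null,
| 249 | +   "id": "quant-summary-code",
| 250 | +   "metadata": {},
| 251 | +   "outputs": [],
| 252 | +   "source": [
| 253 | +    "# Print a per-module summary of the quantizers ModelOpt inserted\n",
| 254 | +    "mtq.print_quant_summary(model)"
| 255 | +   ]
| 256 | +  },
| 257 | +  {
| 258 | +   "cell_type": "markdown",
| 259 | +   "id": "save-state-md",
| 260 | +   "metadata": {},
| 261 | +   "source": [
| 262 | +    "## Save the ModelOpt Quantizer State\n",
| 263 | +    "\n",
| 264 | +    "Saving the ModelOpt state lets you restore the quantized model later without re-running calibration."
| 265 | +   ]
| 266 | +  },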
| 238 | + { |
| 239 | + "cell_type": "code", |
| 240 | + "execution_count": 13, |
| 241 | + "id": "c1a15f93-ee06-42a5-ab3b-ca3428a62fe7", |
| 242 | + "metadata": {}, |
| 243 | + "outputs": [], |
| 244 | + "source": [ |
| 245 | + "import modelopt.torch.opt as mto\n", |
| 246 | + "torch.save(mto.modelopt_state(model), \"modelopt_quantizer_states.pt\")" |
| 247 | + ] |
| 248 | + }, |
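| 249 | +  {
| 250 | +   "cell_type": "markdown",
| 251 | +   "id": "restore-state-md",
| 252 | +   "metadata": {},
| 253 | +   "source": [
| 254 | +    "To reload the quantized model in a later session, re-create the base model and restore the saved state with `mto.restore_from_modelopt_state`. A minimal sketch is shown below, commented out since `model` is already quantized in this session."
| 255 | +   ]
| 256 | +  },
| 257 | +  {
| 258 | +   "cell_type": "code",
| 259 | +   "execution_count": null,
| 260 | +   "id": "restore-state-code",
| 261 | +   "metadata": {},
| 262 | +   "outputs": [],
| 263 | +   "source": [
| 264 | +    "# Sketch: restore quantizers onto a freshly loaded base model\n",
| 265 | +    "# base_model = AutoModelForCausalLM.from_pretrained(model_name).cuda()\n",
| 266 | +    "# state = torch.load(\"modelopt_quantizer_states.pt\", weights_only=False)\n",
| 267 | +    "# model = mto.restore_from_modelopt_state(base_model, state)"
| 268 | +   ]
| 269 | +  },
| 270 | +  {
| 271 | +   "cell_type": "markdown",
| 272 | +   "id": "qat-finetune-md",
| 273 | +   "metadata": {},
| 274 | +   "source": [
| 275 | +    "## Fine-tune the Quantized Model (QAT)\n",
| 276 | +    "\n",
| 277 | +    "The quantized model can now be fine-tuned with the standard Hugging Face `Trainer` to recover accuracy. The `training_args` values below are illustrative, and `data_module` is a hypothetical dict supplying `train_dataset` (and optionally `eval_dataset` and `data_collator`) built from your fine-tuning corpus; construct it with your own data pipeline."
| 278 | +   ]
| 279 | +  },
| 280 | +  {
| 281 | +   "cell_type": "code",
| 282 | +   "execution_count": null,
| 283 | +   "id": "qat-finetune-setup-code",
| 284 | +   "metadata": {},
| 285 | +   "outputs": [],
| 286 | +   "source": [
| 287 | +    "from transformers import Trainer, TrainingArguments\n",
| 288 | +    "\n",
| 289 | +    "# Illustrative hyperparameters for a short QAT fine-tuning run; tune for your setup.\n",
| 290 | +    "training_args = TrainingArguments(\n",
| 291 | +    "    output_dir=\"llama3.1-8b-qat\",\n",
| 292 | +    "    per_device_train_batch_size=1,\n",
| 293 | +    "    gradient_accumulation_steps=8,\n",
| 294 | +    "    learning_rate=1e-5,\n",
| 295 | +    "    num_train_epochs=1,\n",
| 296 | +    "    bf16=True,\n",
| 297 | +    "    logging_steps=10,\n",
| 298 | +    ")\n",
| 299 | +    "\n",
| 300 | +    "# Hypothetical: supply your own tokenized dataset and collator here, e.g.\n",
| 301 | +    "# data_module = {\"train_dataset\": train_dataset, \"data_collator\": data_collator}"
| 302 | +   ]
| 303 | +  },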
| 249 | + { |
| 250 | + "cell_type": "code", |
| 251 | +   "execution_count": null,
| 252 | + "id": "95411f4c-b1d3-4e82-9afb-2608bd21a9a4", |
| 253 | + "metadata": {}, |
| 254 | +   "outputs": [],
| 264 | + "source": [ |
| 265 | +    "trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module)\n",
| 266 | +    "trainer.train()"
| 266 | + ] |
| 267 | + }, |
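| 268 | +  {
| 269 | +   "cell_type": "markdown",
| 270 | +   "id": "generate-md",
| 271 | +   "metadata": {},
| 272 | +   "source": [
| 273 | +    "## Generate Sample Output\n",
| 274 | +    "\n",
| 275 | +    "With quantization and QAT fine-tuning applied, the model can be used for generation as usual. The prompt below is illustrative."
| 276 | +   ]
| 277 | +  },
| 278 | +  {
| 279 | +   "cell_type": "code",
| 280 | +   "execution_count": null,
| 281 | +   "id": "generate-code",
| 282 | +   "metadata": {},
| 283 | +   "outputs": [],
| 284 | +   "source": [
| 285 | +    "prompt = \"Briefly explain what quantization aware training is.\"\n",
| 286 | +    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
| 287 | +    "with torch.no_grad():\n",
| 288 | +    "    output_ids = model.generate(**inputs, max_new_tokens=64)\n",
| 289 | +    "print(tokenizer.decode(output_ids[0], skip_special_tokens=True))"
| 290 | +   ]
| 291 | +  },
| 292 | +  {
| 293 | +   "cell_type": "markdown",
| 294 | +   "id": "export-md",
| 295 | +   "metadata": {},
| 296 | +   "source": [
| 297 | +    "## Export the Quantized Model\n",
| 298 | +    "\n",
| 299 | +    "Finally, export the quantized model as a Hugging Face checkpoint for deployment. This sketch assumes ModelOpt's `export_hf_checkpoint` utility from `modelopt.torch.export`; check your ModelOpt version for the exact export API."
| 300 | +   ]
| 301 | +  },
| 302 | +  {
| 303 | +   "cell_type": "code",
| 304 | +   "execution_count": null,
| 305 | +   "id": "export-code",
| 306 | +   "metadata": {},
| 307 | +   "outputs": [],
| 308 | +   "source": [
| 309 | +    "from modelopt.torch.export import export_hf_checkpoint\n",
| 310 | +    "\n",
| 311 | +    "# Export a quantized Hugging Face checkpoint (directory name is illustrative)\n",
| 312 | +    "export_hf_checkpoint(model, export_dir=\"llama3.1-8b-nvfp4\")"
| 313 | +   ]
| 314 | +  }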
| 276 | + ], |
| 277 | + "metadata": { |
| 278 | + "kernelspec": { |
| 279 | + "display_name": "Python 3 (ipykernel)", |
| 280 | + "language": "python", |
| 281 | + "name": "python3" |
| 282 | + }, |
| 283 | + "language_info": { |
| 284 | + "codemirror_mode": { |
| 285 | + "name": "ipython", |
| 286 | + "version": 3 |
| 287 | + }, |
| 288 | + "file_extension": ".py", |
| 289 | + "mimetype": "text/x-python", |
| 290 | + "name": "python", |
| 291 | + "nbconvert_exporter": "python", |
| 292 | + "pygments_lexer": "ipython3", |
| 293 | + "version": "3.12.3" |
| 294 | + } |
| 295 | + }, |
| 296 | + "nbformat": 4, |
| 297 | + "nbformat_minor": 5 |
| 298 | +} |