(torchao_hf_integration)=
# Integration with HuggingFace: Architecture and Usage Guide

```{contents}
:local:
:depth: 2
```

(configuration-system)=
## Configuration System

(huggingface-model-configuration)=
### 1. HuggingFace Model Configuration

TorchAO quantization is configured through the model's `config.json` file:

```json
{
  "model_type": "llama",
  "quant_type": {
    "default": {
      "_type": "Int4WeightOnlyConfig",
      "_data": {
        "group_size": 128,
        "use_hqq": true
      }
    }
  }
}
```

(torchao-configuration-classes)=
### 2. TorchAO Configuration Classes

All quantization methods inherit from `AOBaseConfig`:

```python
from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig

# Example configuration
config = Int4WeightOnlyConfig(
    group_size=128,
    use_hqq=True,
)
assert isinstance(config, AOBaseConfig)
```

```{note}
All quantization configurations inherit from {class}`torchao.core.config.AOBaseConfig`, which provides serialization and validation capabilities.
```
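
The `_type`/`_data` layout shown in `config.json` above is what torchao's config serialization produces. Below is a minimal round-trip sketch, assuming the `config_to_dict` and `config_from_dict` helpers in `torchao.core.config`:

```python
from torchao.core.config import config_to_dict, config_from_dict
from torchao.quantization import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)

# Serialize to a plain dict with "_type" and "_data" keys
serialized = config_to_dict(config)
print(serialized)

# Reconstruct an equivalent config object from the dict
restored = config_from_dict(serialized)
assert isinstance(restored, Int4WeightOnlyConfig)
```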

(module-level-configuration)=
### 3. Module-Level Configuration

For granular control, use `ModuleFqnToConfig`:

```python
from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

config = ModuleFqnToConfig({
    "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
    "_default": Int4WeightOnlyConfig(group_size=128)  # Default for other modules
})
```
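
Outside of the HuggingFace integration, the same mapping can be applied directly with {func}`torchao.quantization.quantize_`. A minimal sketch on a toy model (the FQNs `"0"` and `"1"` come from `nn.Sequential`; using `None` as the `"_default"` entry to leave unlisted modules unquantized is an assumption of this sketch):

```python
import torch
from torch import nn
from torchao.quantization import quantize_, ModuleFqnToConfig, Int8WeightOnlyConfig

# Toy model; the submodule FQNs are "0" and "1"
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 16)).to(torch.bfloat16)

config = ModuleFqnToConfig({
    "0": Int8WeightOnlyConfig(),  # quantize only the first linear
    "_default": None,             # leave every other module in bfloat16
})
quantize_(model, config)

print(model[0].weight)  # quantized tensor subclass
print(model[1].weight)  # plain bfloat16 parameter
```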

(usage-examples)=
## Usage Examples

First, install the required packages.

```bash
pip install git+https://github.com/huggingface/transformers@main
pip install torchao
pip install torch
pip install accelerate
```

(quantizing-models-huggingface)=
### 1. Quantizing Models with HuggingFace Integration

Below is an example of using `Float8DynamicActivationInt4WeightConfig` on the Llama-3.2-1B model.

```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationInt4WeightConfig

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Float8DynamicActivationInt4WeightConfig(group_size=128, use_hqq=True)
)

# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)
```
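
To verify that quantization was applied, we can inspect one of the linear layers; its weight should now be a torchao tensor subclass (the attribute path below assumes a Llama-style module layout):

```python
# The repr of a quantized weight shows the torchao tensor subclass
print(model.model.layers[0].self_attn.q_proj.weight)
```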

After quantizing the model, we can push it to the HuggingFace Hub.

```python
# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)
```

Here is another example using `Float8DynamicActivationFloat8WeightConfig` on the Phi-4-mini-instruct model.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

model_id = "microsoft/Phi-4-mini-instruct"

# Create quantization configuration
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and automatically quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
```

As in the previous example, we can push the quantized model and its tokenizer to the Hub.

```python
# Save quantized model (see Serialization section below for safe_serialization details)
USER_ID = "YOUR_USER_ID"
MODEL_NAME = model_id.split("/")[-1]
save_to = f"{USER_ID}/{MODEL_NAME}-float8dq"
quantized_model.push_to_hub(save_to, safe_serialization=False)

# Save tokenizer
tokenizer.push_to_hub(save_to)
```

```{seealso}
For more information on quantization configs, see {class}`torchao.quantization.Int4WeightOnlyConfig`, {class}`torchao.quantization.Float8DynamicActivationInt4WeightConfig`, and {class}`torchao.quantization.Int8WeightOnlyConfig`.
```

```{note}
For more information on supported quantization and sparsity configurations, see the [HF-TorchAO docs](https://huggingface.co/docs/transformers/main/en/quantization/torchao).
```

(serving-with-vllm)=
### 2. Serving with vLLM

```{note}
For more information on serving and inference with vLLM, please refer to [Integration with vLLM: Architecture and Usage Guide](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html) and [(Part 3) Serving on vLLM, SGLang, ExecuTorch](https://docs.pytorch.org/ao/main/serving.html) for a full end-to-end tutorial.
```
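
As a quick reference, a pre-quantized checkpoint can usually be served directly from the Hub. A minimal sketch, assuming a vLLM installation with torchao support:

```bash
vllm serve pytorch/Phi-4-mini-instruct-float8dq
```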

(Inference-with-HuggingFace-Transformers)=
### 3. Inference with HuggingFace Transformers

Instead of quantizing a model ourselves, we can also load a checkpoint that has already been quantized and pushed to the Hub.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.random.manual_seed(0)

model_path = "pytorch/Phi-4-mini-instruct-float8dq"

# Load the pre-quantized model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
```
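
Since the weights are stored in float8, the loaded model should have roughly half the memory footprint of its bfloat16 counterpart; `get_memory_footprint()` gives a quick check:

```python
# Rough size of the quantized model in GB
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```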

Now we can use the model for inference.

```python
from transformers import pipeline

# Simulate a conversation between user and assistant
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
    {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
]

# Initialize HuggingFace pipeline for text generation
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "temperature": 0.0,
    "do_sample": False,
}

# Generate output
output = pipe(messages, **generation_args)
print(output[0]['generated_text'])
```

```{seealso}
For more examples and recommended quantization methods for different hardware (e.g., A100 GPU, H100 GPU, CPU), see the [HF-TorchAO docs (quantization examples)](https://huggingface.co/docs/transformers/main/en/quantization/torchao#quantization-examples).
```

(Inference-with-HuggingFace-Diffuser)=
### 4. Inference with HuggingFace Diffusers

First, install Diffusers:

```bash
pip install git+https://github.com/huggingface/diffusers@main
```

Below is an example of how we can integrate with HuggingFace Diffusers.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")
```

We can also use [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html) to speed up inference by adding this line after initializing the transformer.

```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

```{seealso}
Please refer to the [HF-TorchAO-Diffusers docs](https://huggingface.co/docs/diffusers/en/quantization/torchao) for more examples and benchmarking results.
```

```{note}
Refer to [this PR](https://github.com/huggingface/diffusers/pull/10009) for time and memory benchmarks on a single H100 GPU.
```

(Supported-Quantization-Types)=
## Supported Quantization Types

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, quantizing activations can reduce output quality, so it is recommended to evaluate accuracy on your task after quantization.

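
The two flavors differ only in which config is applied. Below is an illustrative sketch contrasting them on a toy layer (the toy module and shapes are assumptions of this example, not part of the HuggingFace integration):

```python
import torch
from torch import nn
from torchao.quantization import (
    quantize_,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
)

# Weight-only: int8 weights, activations stay in bfloat16
m_wo = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16)
quantize_(m_wo, Int8WeightOnlyConfig())

# Dynamic activation quantization: int8 weights, activations quantized on the fly
m_dq = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16)
quantize_(m_dq, Int8DynamicActivationInt8WeightConfig())

x = torch.randn(1, 1024, dtype=torch.bfloat16)
print(m_wo(x).dtype, m_dq(x).dtype)  # outputs remain bfloat16 in both cases
```
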
Below are the supported quantization types.

| Category | Configuration classes |
|:---------|:----------------------|
| Integer quantization | `Int8DynamicActivationInt4WeightConfig`<br>`Int8DynamicActivationIntxWeightConfig`<br>`Int4DynamicActivationInt4WeightConfig`<br>`Int4WeightOnlyConfig`<br>`Int8WeightOnlyConfig`<br>`Int8DynamicActivationInt8WeightConfig` |
| Floating point quantization | `Float8DynamicActivationInt4WeightConfig`<br>`Float8WeightOnlyConfig`<br>`Float8DynamicActivationFloat8WeightConfig`<br>`Float8DynamicActivationFloat8SemiSparseWeightConfig`<br>`Float8StaticActivationFloat8WeightConfig` |
| Integer X-bit quantization | `IntxWeightOnlyConfig` |
| Floating point X-bit quantization | `FPXWeightOnlyConfig` |
| Unsigned integer quantization | `GemliteUIntXWeightOnlyConfig`<br>`UIntXWeightOnlyConfig` |

```{note}
For full definitions of the above configuration classes, see [quant_api.py](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
```

(Serialization)=
## Serialization

To serialize a quantized model, first load the model with the desired quantization configuration and then save it using the `save_pretrained()` method.

**Using Transformers**:
```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

quant_config = Int8WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)
# Save the quantized model
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```

**Using Diffusers**:
```python
import torch
from diffusers import AutoModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```

To load a serialized quantized model, use the `from_pretrained()` method.

**Using Transformers**:
```python
# Reload the quantized model
reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

**Using Diffusers**:
```python
import torch
from diffusers import FluxPipeline, AutoModel

transformer = AutoModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

(safetensors-support)=
### SafeTensors Support

**Current Status**: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use `safe_serialization=False`.

```python
# Don't serialize the model with safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```

**Workaround**: For production use, save models with `safe_serialization=False` when pushing to the HuggingFace Hub.

**Future Work**: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress [here](https://github.com/pytorch/ao/issues/2338) and [here](https://github.com/pytorch/ao/pull/2881).