Commit cc65dc5

hf integration doc page (#2899)
1 parent f9bc52d commit cc65dc5

File tree

5 files changed: +130 -32 lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ for an overall introduction to the library and recent highlight and updates.
    finetuning
    serving
    torchao_vllm_integration
+   torchao_hf_integration
    serialization
    static_quantization
    subclass_basic

docs/source/output.png

1.32 MB

docs/source/serving.rst

Lines changed: 1 addition & 32 deletions
@@ -15,38 +15,7 @@ Post-training Quantization with HuggingFace
 -------------------------------------------
 
 HuggingFace Transformers provides seamless integration with torchao quantization. The ``TorchAoConfig`` automatically applies torchao's optimized quantization algorithms during model loading.
-
-.. code-block:: bash
-
-   pip install git+https://github.com/huggingface/transformers@main
-   pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
-   pip install torch
-   pip install accelerate
-
-For this example, we'll use ``Float8DynamicActivationFloat8WeightConfig`` on the Phi-4 mini-instruct model.
-
-.. code-block:: python
-
-   import torch
-   from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
-   from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
-
-   model_id = "microsoft/Phi-4-mini-instruct"
-
-   quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
-   quantization_config = TorchAoConfig(quant_type=quant_config)
-   quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
-   tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-   # Push the model to hub
-   USER_ID = "YOUR_USER_ID"
-   MODEL_NAME = model_id.split("/")[-1]
-   save_to = f"{USER_ID}/{MODEL_NAME}-float8dq"
-   quantized_model.push_to_hub(save_to, safe_serialization=False)
-   tokenizer.push_to_hub(save_to)
-
-.. note::
-   For more information on supported quantization and sparsity configurations, see `HF-Torchao Docs <https://huggingface.co/docs/transformers/main/en/quantization/torchao>`_.
+Please check out our `HF Integration Docs <torchao_hf_integration.html>`_ for examples on how to use quantization and sparsity in Transformers and Diffusers.
 
 Serving and Inference
 --------------------
docs/source/torchao_hf_integration.md

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
(torchao_hf_integration)=
# Hugging Face Integration

```{contents}
:local:
:depth: 2
```

(usage-examples)=
## Quick Start: Usage Example

First, install the required packages.

```bash
pip install git+https://github.com/huggingface/transformers@main
pip install git+https://github.com/huggingface/diffusers@main
pip install torchao
pip install torch
pip install accelerate
```

(quantizing-models-transformers)=
### 1. Quantizing Models with Transformers

Below is an example of using `Float8DynamicActivationInt4WeightConfig` on the Llama-3.2-1B model.

```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationInt4WeightConfig

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Float8DynamicActivationInt4WeightConfig(group_size=128, use_hqq=True)
)

# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)
```
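To sanity-check the quantized model locally, you can run a short generation with it. This is a minimal sketch rather than part of the original example; it assumes the `model` loaded above and uses the standard tokenizer and `generate` API (the prompt and `max_new_tokens` value are arbitrary).

```python
from transformers import AutoTokenizer

# Load the matching tokenizer (assumed: same model id as above)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Run a short generation with the quantized model as a quick sanity check
inputs = tokenizer("Quantization in one sentence:", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```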
```{seealso}
For inference examples and recommended quantization methods for different hardware (e.g. A100 GPU, H100 GPU, CPU), see [HF-Torchao Docs (Quantization Examples)](https://huggingface.co/docs/transformers/main/en/quantization/torchao#quantization-examples).

For inference using vLLM, please see [(Part 3) Serving on vLLM, SGLang, ExecuTorch](https://docs.pytorch.org/ao/main/serving.html) for a full end-to-end tutorial.
```

(quantizing-models-diffusers)=
### 2. Quantizing Models with Diffusers

Below is an example of using torchao quantization with Diffusers, applied to the Flux transformer.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")
```

```{note}
Example Output:
![A cat holding a sign that says hello world](output.png "Model Output")
```

```{seealso}
Please refer to [HF-TorchAO-Diffuser Docs](https://huggingface.co/docs/diffusers/en/quantization/torchao) for more examples and benchmarking results.
```
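The pipeline's T5 text encoder is itself a `transformers` model, so it can be quantized with the `transformers` `TorchAoConfig` and passed into the pipeline alongside the quantized transformer. The sketch below is not part of the original example; the `text_encoder_2` subfolder name and the `"int8_weight_only"` string are assumptions based on the Flux repository layout and the `transformers` string-based config API.

```python
# Sketch (assumptions noted above): also quantize the T5 text encoder,
# then pass both quantized components into the pipeline.
from transformers import T5EncoderModel, TorchAoConfig as TransformersTorchAoConfig

text_encoder_2 = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_2",  # assumed Flux repo layout
    quantization_config=TransformersTorchAoConfig("int8_weight_only"),
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
)
```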
(saving-models)=
## Saving the Model

After we quantize the model, we can save it.

```python
import tempfile

from transformers import AutoTokenizer

# Save quantized model (see below for safe_serialization enablement progress)
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=False)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# optional: push to hub (uncomment the following lines)
# save_to = "your-username/Llama-3.2-1B-int4"
# model.push_to_hub(save_to, safe_serialization=False)
# tokenizer.push_to_hub(save_to)
```

**Current Status of Safetensors support**: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use `safe_serialization=False`.

```python
# Don't serialize the model with safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```

**Workaround**: For production use, save models with `safe_serialization=False` when pushing to Hugging Face Hub.

**Future Work**: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress [here](https://github.com/pytorch/ao/issues/2338) and [here](https://github.com/pytorch/ao/pull/2881).
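Until safetensors support lands, a checkpoint saved with `safe_serialization=False` can be loaded back with the standard `from_pretrained` flow. A minimal sketch, assuming the hypothetical local directory from the snippet above:

```python
from transformers import AutoModelForCausalLM

# Reload the non-safetensors quantized checkpoint saved above
# ("llama3-8b-int4wo-128" is the hypothetical local path used earlier)
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "llama3-8b-int4wo-128",
    device_map="auto",
    torch_dtype="auto",
)
```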
(supported-quantization-types)=
## Supported Quantization Types

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, it may come with a quality tradeoff, so it is recommended to test the quantized model thoroughly.
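The two families map to different config classes in `torchao.quantization`. A short sketch contrasting them; the specific classes shown, `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig`, are illustrative choices and assume a recent torchao version that exposes the config-class API used throughout this page.

```python
from transformers import TorchAoConfig
from torchao.quantization import (
    Int8WeightOnlyConfig,                   # weight-only: int8 weights, bf16 compute
    Int8DynamicActivationInt8WeightConfig,  # dynamic: int8 weights + on-the-fly int8 activations
)

# Weight-only quantization: smaller weights, unchanged activation memory
weight_only_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())

# Dynamic activation quantization: smaller weights and quantized activations
dynamic_config = TorchAoConfig(quant_type=Int8DynamicActivationInt8WeightConfig())
```

Either config can then be passed as `quantization_config` to `from_pretrained`, as in the Transformers example above.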
```{note}
Please refer to the [torchao docs](https://docs.pytorch.org/ao/main/api_ref_quantization.html) for supported quantization types.
```

output.png

1.32 MB
