Commit d561f65
frst commit
1 parent c36f848

File tree

12 files changed: +707 -2 lines changed

.github/workflows/nightly_tests.yml

Lines changed: 2 additions & 0 deletions
@@ -473,6 +473,8 @@ jobs:
             additional_deps: []
           - backend: "optimum_quanto"
             test_location: "quanto"
+          - backend: "finegrained_fp8"
+            test_location: "finegrained_fp8"
             additional_deps: []
     runs-on:
       group: aws-g6e-xlarge-plus

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@
     title: torchao
   - local: quantization/quanto
     title: quanto
+  - local: quantization/finegrained_fp8
+    title: finegrained_fp8
   title: Quantization Methods
 - sections:
   - local: optimization/fp16

docs/source/en/api/quantization.md

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 
 [[autodoc]] TorchAoConfig
 
+## FinegrainedFP8Config
+
+[[autodoc]] FinegrainedFP8Config
+
 ## DiffusersQuantizer
 
 [[autodoc]] quantizers.base.DiffusersQuantizer
+
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# FinegrainedFP8
+
+## Overview
+
+## Usage
+

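The new documentation page above is added with empty Overview and Usage sections. As a hedged sketch of what usage might eventually look like, assuming FinegrainedFP8Config is passed to from_pretrained like the other quantization configs in this commit and accepts the weight_block_size value read by the quantizer later in this diff (the model class, checkpoint, and block size below are illustrative only):

import torch
from diffusers import FinegrainedFP8Config, SD3Transformer2DModel  # model class chosen only for illustration

# Quantize the transformer's linear weights to FP8 on the fly while loading.
# weight_block_size is the (block_m, block_n) tile used for per-block scales; (128, 128) is an assumed default.
quant_config = FinegrainedFP8Config(weight_block_size=(128, 128))

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)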
docs/source/en/quantization/overview.md

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ Diffusers currently supports the following quantization methods.
 - [TorchAO](./torchao)
 - [GGUF](./gguf)
 - [Quanto](./quanto.md)
-
+- [FinegrainedFP8](./finegrained_fp8.md)
 [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
 
 ## Pipeline-level quantization

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,8 @@
 else:
     _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
 
+_import_structure["quantizers.quantization_config"].append("FinegrainedFP8Config")
+
 try:
     if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
         raise OptionalDependencyNotAvailable()
@@ -724,6 +726,8 @@
     else:
         from .quantizers.quantization_config import QuantoConfig
 
+    from .quantizers.quantization_config import FinegrainedFP8Config
+
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1238,6 +1238,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         }
 
         # Dispatch model with hooks on all devices if necessary
+        print(model.transformer_blocks[0].attn.to_q.weight)
+        print(model.transformer_blocks[0].attn.to_q.weight_scale_inv)
         if device_map is not None:
             device_map_kwargs = {
                 "device_map": device_map,

src/diffusers/quantizers/auto.py

Lines changed: 4 additions & 0 deletions
@@ -28,9 +28,11 @@
     QuantizationMethod,
     QuantoConfig,
     TorchAoConfig,
+    FinegrainedFP8Config,
 )
 from .quanto import QuantoQuantizer
 from .torchao import TorchAoHfQuantizer
+from .finegrained_fp8 import FinegrainedFP8Quantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
@@ -39,6 +41,7 @@
     "gguf": GGUFQuantizer,
     "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
+    "finegrained_fp8": FinegrainedFP8Quantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
     "gguf": GGUFQuantizationConfig,
     "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
+    "finegrained_fp8": FinegrainedFP8Config,
 }
 

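With the two mapping entries added above, the backend can be resolved from its "finegrained_fp8" quant_method string. A minimal sketch, assuming DiffusersAutoQuantizer.from_config dispatches through AUTO_QUANTIZER_MAPPING as the existing backends do:

from diffusers import FinegrainedFP8Config
from diffusers.quantizers.auto import DiffusersAutoQuantizer

# A config whose quant_method is "finegrained_fp8" should resolve to the new quantizer class.
config = FinegrainedFP8Config()
quantizer = DiffusersAutoQuantizer.from_config(config)
print(type(quantizer).__name__)  # expected: FinegrainedFP8Quantizer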
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from ...utils import is_accelerate_available, is_torch_available, logging
+from ..base import DiffusersQuantizer
+from ...utils import get_module_from_name
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+class FinegrainedFP8Quantizer(DiffusersQuantizer):
+    """
+    FP8 quantization implementation supporting both standard and MoE models.
+    Supports both e4m3fn formats based on platform.
+    """
+
+    requires_parameters_quantization = True
+    requires_calibration = False
+    required_packages = ["accelerate"]
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.quantization_config = quantization_config
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_torch_available():
+            raise ImportError(
+                "Using fp8 quantization requires torch >= 2.1.0"
+                "Please install the latest version of torch ( pip install --upgrade torch )"
+            )
+
+        if not is_accelerate_available():
+            raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
+
+        if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
+            raise ValueError(
+                "Converting into FP8 weights from tf/flax weights is currently not supported, "
+                "please make sure the weights are in PyTorch format."
+            )
+
+        if torch.cuda.is_available():
+            compute_capability = torch.cuda.get_device_capability()
+            major, minor = compute_capability
+            if (major < 8) or (major == 8 and minor < 9):
+                raise ValueError(
+                    "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
+                    f", actual = `{major}.{minor}`"
+                )
+
+        device_map = kwargs.get("device_map", None)
+        if device_map is None:
+            logger.warning_once(
+                "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
+                "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
+            )
+        elif device_map is not None:
+            if (
+                not self.pre_quantized
+                and isinstance(device_map, dict)
+                and ("cpu" in device_map.values() or "disk" in device_map.values())
+            ):
+                raise ValueError(
+                    "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
+                    "This is not supported when the model is quantized on the fly. "
+                    "Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
+                )
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
+            torch_dtype = torch.float32
+        return torch_dtype
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        target_device: "torch.device",
+        state_dict: Dict[str, Any],
+        unexpected_keys: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        """
+        Quantizes weights to FP8 format using Block-wise quantization
+        """
+        # print("############ create quantized param ########")
+        from accelerate.utils import set_module_tensor_to_device
+
+        set_module_tensor_to_device(model, param_name, target_device, param_value)
+
+        module, tensor_name = get_module_from_name(model, param_name)
+
+        # Get FP8 min/max values
+        fp8_min = torch.finfo(torch.float8_e4m3fn).min
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+        block_size_m, block_size_n = self.quantization_config.weight_block_size
+
+        rows, cols = param_value.shape[-2:]
+
+        if rows % block_size_m != 0 or cols % block_size_n != 0:
+            raise ValueError(
+                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
+            )
+        param_value_orig_shape = param_value.shape
+
+        param_value = param_value.reshape(
+            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
+        ).permute(0, 1, 3, 2, 4)
+
+        # Calculate scaling factor for each block
+        max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+        scale = fp8_max / max_abs
+        scale_orig_shape = scale.shape
+        scale = scale.unsqueeze(-1).unsqueeze(-1)
+
+        # Quantize the weights
+        quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
+        # Reshape back to matrix shape
+        quantized_param = quantized_param.reshape(param_value_orig_shape)
+
+        # Reshape scale to match the number of blocks
+        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+
+        # Load into the model
+        module._buffers[tensor_name] = quantized_param.to(target_device)
+        module._buffers["weight_scale_inv"] = scale.to(target_device)
+        # print("_buffers[0]", module._buffers["weight_scale_inv"])
+
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ):
+        from .utils import FP8Linear
+
+        module, tensor_name = get_module_from_name(model, param_name)
+        if isinstance(module, FP8Linear):
+            if self.pre_quantized or tensor_name == "bias":
+                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
+                    raise ValueError("Expect quantized weights but got an unquantized weight")
+                return False
+            else:
+                if tensor_name == "weight_scale_inv":
+                    raise ValueError("Expect unquantized weights but got a quantized weight_scale")
+                return True
+        return False
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        keep_in_fp32_modules: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        from .utils import replace_with_fp8_linear
+
+        if self.quantization_config.modules_to_not_convert is not None:
+            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+
+        model = replace_with_fp8_linear(
+            model,
+            modules_to_not_convert=self.modules_to_not_convert,
+            quantization_config=self.quantization_config,
+        )
+
+        model.config.quantization_config = self.quantization_config
+
+    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+        return model
+
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        from .utils import FP8Linear
+
+        not_missing_keys = []
+        for name, module in model.named_modules():
+            if isinstance(module, FP8Linear):
+                for missing in missing_keys:
+                    if (
+                        (name in missing or name in f"{prefix}.{missing}")
+                        and not missing.endswith(".weight")
+                        and not missing.endswith(".bias")
+                    ):
+                        not_missing_keys.append(missing)
+        return [k for k in missing_keys if k not in not_missing_keys]
+
+    def is_serializable(self, safe_serialization=None):
+        return True
+
+    @property
+    def is_trainable(self) -> bool:
+        return False
+
+    def get_cuda_warm_up_factor(self):
+        # Pre-processing is done cleanly, so we can allocate everything here
+        return 2

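For reference, a standalone sketch of the block-wise scheme that create_quantized_param implements above: the weight is tiled into weight_block_size blocks, each block is scaled so its largest absolute value maps to the float8_e4m3fn maximum, and the reciprocal per-block scale is kept as weight_scale_inv. The helper name and the round-trip check are illustrative only, and torch >= 2.1 is assumed for the float8_e4m3fn dtype:

import torch


def quantize_blockwise_fp8(weight: torch.Tensor, block_m: int = 128, block_n: int = 128):
    # Tile the matrix into (block_m, block_n) blocks: shape (rows // block_m, cols // block_n, block_m, block_n).
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    rows, cols = weight.shape
    assert rows % block_m == 0 and cols % block_n == 0, "dimensions must be divisible by the block size"
    blocks = weight.reshape(rows // block_m, block_m, cols // block_n, block_n).permute(0, 2, 1, 3)

    # One scale per block, chosen so the block's max |value| hits the FP8 e4m3 maximum.
    scale = fp8_info.max / blocks.abs().amax(dim=(-1, -2), keepdim=True)
    quantized = torch.clamp(blocks * scale, fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)

    # Back to the original layout, plus the inverse scale stored alongside the weight.
    quantized = quantized.permute(0, 2, 1, 3).reshape(rows, cols)
    weight_scale_inv = scale.squeeze(-1).squeeze(-1).reciprocal()  # shape (rows // block_m, cols // block_n)
    return quantized, weight_scale_inv


# Quick round-trip check on random data (illustrative only; zero-valued blocks are not handled, as in the diff).
w = torch.randn(256, 512)
q, s_inv = quantize_blockwise_fp8(w)
blocks = q.to(torch.float32).reshape(2, 128, 4, 128).permute(0, 2, 1, 3) * s_inv[:, :, None, None]
w_hat = blocks.permute(0, 2, 1, 3).reshape(256, 512)
print((w - w_hat).abs().max())  # small quantization error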