Skip to content

Commit ea6c364

Browse files
committed
start higgs
1 parent 7392c8f commit ea6c364

File tree

5 files changed

+948
-0
lines changed

5 files changed

+948
-0
lines changed

src/diffusers/quantizers/auto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .quantization_config import (
2525
BitsAndBytesConfig,
2626
GGUFQuantizationConfig,
27+
HiggsConfig,
2728
QuantizationConfigMixin,
2829
QuantizationMethod,
2930
QuantoConfig,
@@ -39,6 +40,7 @@
3940
"gguf": GGUFQuantizer,
4041
"quanto": QuantoQuantizer,
4142
"torchao": TorchAoHfQuantizer,
43+
"higgs": 1,
4244
}
4345

4446
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +49,7 @@
4749
"gguf": GGUFQuantizationConfig,
4850
"quanto": QuantoConfig,
4951
"torchao": TorchAoConfig,
52+
"higgs": HiggsConfig,
5053
}
5154

5255

src/diffusers/quantizers/higgs/__init__.py

Whitespace-only changes.
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Adapted from
16+
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/quantizers/quantizer_higgs.py
17+
"""
18+
19+
from typing import TYPE_CHECKING, Any, Optional
20+
21+
from ...utils import get_module_from_name
22+
from ..base import DiffusersQuantizer
23+
24+
25+
if TYPE_CHECKING:
26+
from ...models.modeling_utils import ModelMixin
27+
28+
from ...utils import is_accelerate_available, is_torch_available, logging
29+
from ...utils.logging import tqdm
30+
31+
32+
if is_torch_available():
33+
import torch
34+
35+
logger = logging.get_logger(__name__)
36+
37+
38+
class HiggsHfQuantizer(DiffusersQuantizer):
39+
"""
40+
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of
41+
full-precision models.
42+
"""
43+
44+
requires_calibration = False
45+
requires_parameters_quantization = True
46+
required_packages = ["flute-kernel", "fast_hadamard_transform"]
47+
48+
def __init__(self, quantization_config, **kwargs):
49+
super().__init__(quantization_config, **kwargs)
50+
self.quantization_config = quantization_config
51+
52+
def validate_environment(self, device_map, **kwargs):
53+
if not torch.cuda.is_available():
54+
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
55+
56+
if not is_accelerate_available():
57+
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
58+
59+
# TODO: enable this.
60+
# if not is_flute_available():
61+
# raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
62+
63+
# if not is_hadamard_available():
64+
# raise ImportError(
65+
# "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
66+
# )
67+
68+
if device_map is None:
69+
raise ValueError(
70+
"You are attempting to load a HIGGS model without setting device_map."
71+
" Please set device_map comprised of 'cuda' devices."
72+
)
73+
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
74+
raise ValueError(
75+
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
76+
" This is not supported. Please remove the CPU or disk device from the device_map."
77+
)
78+
79+
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
80+
if torch_dtype is None:
81+
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
82+
torch_dtype = torch.float16
83+
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
84+
raise ValueError(
85+
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
86+
)
87+
88+
return torch_dtype
89+
90+
def create_quantized_param(
91+
self,
92+
model: "ModelMixin",
93+
param_value: "torch.Tensor",
94+
param_name: str,
95+
target_device: "torch.device",
96+
state_dict: dict[str, Any],
97+
unexpected_keys: Optional[list[str]] = None,
98+
):
99+
from .utils import quantize_with_higgs
100+
101+
"""
102+
Quantizes weights into weight and weight_scale
103+
"""
104+
flute_dict = quantize_with_higgs(
105+
param_value.to(target_device),
106+
self.quantization_config.bits,
107+
self.quantization_config.p,
108+
self.quantization_config.group_size,
109+
self.quantization_config.hadamard_size,
110+
)
111+
del param_value
112+
113+
module, _ = get_module_from_name(model, param_name)
114+
module_name = ".".join(param_name.split(".")[:-1])
115+
for key, value in flute_dict.items():
116+
if key in module._parameters:
117+
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
118+
elif key in module._buffers:
119+
module._buffers[key] = torch.nn.Buffer(value)
120+
elif key == "tune_metadata":
121+
module.tune_metadata = value
122+
self.quantization_config.tune_metadata[module_name] = value.to_dict()
123+
else:
124+
raise ValueError(f"Unexpected key {key} in module {module}")
125+
126+
if unexpected_keys is not None and param_name in unexpected_keys:
127+
unexpected_keys.remove(param_name)
128+
129+
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
130+
from .utils import HiggsLinear
131+
132+
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
133+
134+
def should_update(key: str) -> bool:
135+
if key.endswith(".weight") or key.endswith(".bias"):
136+
return False
137+
full_key = f"{prefix}.{key}"
138+
return any(name in key or name in full_key for name in higgs_names)
139+
140+
return [key for key in missing_keys if not should_update(key)]
141+
142+
@property
143+
def is_trainable(self):
144+
return False
145+
146+
def is_serializable(self):
147+
return True
148+
149+
def check_quantized_param(
150+
self,
151+
model: "ModelMixin",
152+
param_value: "torch.Tensor",
153+
param_name: str,
154+
state_dict: dict[str, Any],
155+
**kwargs,
156+
) -> bool:
157+
from .utils import HiggsLinear
158+
159+
module, tensor_name = get_module_from_name(model, param_name)
160+
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
161+
# Only quantize weights of HiggsLinear modules that are not already quantized
162+
return True
163+
else:
164+
return False
165+
166+
def _process_model_before_weight_loading(
167+
self,
168+
model: "ModelMixin",
169+
**kwargs,
170+
):
171+
from .utils import replace_with_higgs_linear
172+
173+
replace_with_higgs_linear(model, quantization_config=self.quantization_config)
174+
model.config.quantization_config = self.quantization_config
175+
176+
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
177+
from flute.tune import TuneMetaData, maybe_tune_and_repack
178+
from flute.utils import make_workspace_streamk
179+
180+
from .utils import HiggsLinear
181+
182+
flute_workspaces = {}
183+
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
184+
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
185+
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
186+
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
187+
if module.weight.device not in flute_workspaces:
188+
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
189+
module.workspace = flute_workspaces[module.weight.device]
190+
191+
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
192+
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
193+
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
194+
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
195+
weight=module.weight.data,
196+
scales=module.scales.data,
197+
metadata=module.tune_metadata,
198+
)
199+
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
200+
201+
def _dequantize(self, model):
202+
from .utils import dequantize_higgs
203+
204+
model = dequantize_higgs(model)
205+
return model

0 commit comments

Comments
 (0)