Skip to content

Commit b5bf4cf

Browse files
committed
init
1 parent 18c8f10 commit b5bf4cf

File tree

7 files changed

+97
-1
lines changed

7 files changed

+97
-1
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from .. import __version__
2525
from ..quantizers import DiffusersAutoQuantizer
26-
from ..utils import deprecate, is_accelerate_available, logging
26+
from ..quantizers.quantization_config import QuantizationMethod
27+
from ..utils import deprecate, is_accelerate_available, is_nunchaku_available, logging
2728
from ..utils.torch_utils import empty_device_cache
2829
from .single_file_utils import (
2930
SingleFileComponentError,
@@ -243,6 +244,32 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
243244
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
244245
```
245246
"""
247+
quantization_config = kwargs.get("quantization_config")
248+
if quantization_config is not None and quantization_config.quant_method == QuantizationMethod.SVDQUANT:
249+
if not is_nunchaku_available():
250+
raise ImportError("Loading SVDQuant models requires the `nunchaku` package. Please install it.")
251+
252+
if isinstance(pretrained_model_link_or_path_or_dict, dict):
253+
raise ValueError(
254+
"Loading a nunchaku model from a state_dict is not supported directly via from_single_file. Please provide a path."
255+
)
256+
257+
if "FluxTransformer2DModel" in cls.__name__:
258+
from nunchaku import NunchakuFluxTransformer2dModel
259+
260+
kwargs.pop("quantization_config", None)
261+
return NunchakuFluxTransformer2dModel.from_pretrained(
262+
pretrained_model_link_or_path_or_dict, **kwargs
263+
)
264+
elif "SanaTransformer2DModel" in cls.__name__:
265+
from nunchaku import NunchakuSanaTransformer2DModel
266+
267+
kwargs.pop("quantization_config", None)
268+
return NunchakuSanaTransformer2DModel.from_pretrained(
269+
pretrained_model_link_or_path_or_dict, **kwargs
270+
)
271+
else:
272+
raise NotImplementedError(f"SVDQuant loading is not implemented for {cls.__name__}")
246273

247274
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
248275
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:

src/diffusers/quantizers/auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
QuantizationConfigMixin,
2828
QuantizationMethod,
2929
QuantoConfig,
30+
SVDQuantConfig,
3031
TorchAoConfig,
3132
)
3233
from .quanto import QuantoQuantizer
34+
from .svdquant import SVDQuantizer
3335
from .torchao import TorchAoHfQuantizer
3436

3537

@@ -39,6 +41,7 @@
3941
"gguf": GGUFQuantizer,
4042
"quanto": QuantoQuantizer,
4143
"torchao": TorchAoHfQuantizer,
44+
"svdquant": SVDQuantizer,
4245
}
4346

4447
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
4750
"gguf": GGUFQuantizationConfig,
4851
"quanto": QuantoConfig,
4952
"torchao": TorchAoConfig,
53+
"svdquant": SVDQuantConfig,
5054
}
5155

5256

src/diffusers/quantizers/quantization_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
4646
GGUF = "gguf"
4747
TORCHAO = "torchao"
4848
QUANTO = "quanto"
49+
SVDQUANT = "svdquant"
4950

5051

5152
if is_torchao_available():
@@ -724,3 +725,12 @@ def post_init(self):
724725
accepted_weights = ["float8", "int8", "int4", "int2"]
725726
if self.weights_dtype not in accepted_weights:
726727
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
728+
729+
class SVDQuantConfig(QuantizationConfigMixin):
    """Configuration for pre-quantized SVDQuant (nunchaku) checkpoints.

    This is a passthrough placeholder: the actual quantization was performed
    offline by the `nunchaku` library, so this config only tags the model with
    `quant_method == QuantizationMethod.SVDQUANT` for quantizer dispatch and
    stores any extra keyword arguments verbatim as attributes.

    NOTE: the former `@dataclass` decorator was removed — with a hand-written
    `__init__` and no annotated fields it generated nothing useful, but it did
    inject a field-less `__eq__` that made *all* instances compare equal.
    """

    def __init__(self, **kwargs):
        # Marks this config for the "svdquant" entry in the auto-quantizer maps.
        self.quant_method = QuantizationMethod.SVDQUANT
        # Preserve any nunchaku-specific settings as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .svdquant_quantizer import SVDQuantizer
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ..base import DiffusersQuantizer
16+
17+
18+
class SVDQuantizer(DiffusersQuantizer):
    """
    Placeholder quantizer for loading pre-quantized SVDQuant models via the
    nunchaku library.

    All real work (module replacement, weight loading) happens inside
    nunchaku's own `from_pretrained`, so every hook here is a deliberate
    no-op. The redundant `__init__` that only forwarded to `super()` was
    removed; the inherited constructor is used directly.
    """

    # Nunchaku manages its own precision layout; nothing to keep in fp32.
    use_keep_in_fp32_modules = False
    # Checkpoints are already quantized offline; no calibration pass needed.
    requires_calibration = False

    def _process_model_before_weight_loading(self, model, **kwargs):
        # No-op: nunchaku fully materializes the quantized model itself.
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # No-op: nothing to post-process after nunchaku loading.
        return model

    @property
    def is_serializable(self):
        # The model is serialized in nunchaku's own on-disk format.
        return True

    @property
    def is_trainable(self):
        # Pre-quantized SVDQuant weights are inference-only here.
        return False

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
is_matplotlib_available,
8686
is_nltk_available,
8787
is_note_seq_available,
88+
is_nunchaku_available,
8889
is_onnx_available,
8990
is_opencv_available,
9091
is_optimum_quanto_available,

src/diffusers/utils/import_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
223223
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
224224
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
225225
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
226+
_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku")
226227

227228

228229
def is_torch_available():
@@ -393,6 +394,10 @@ def is_flash_attn_3_available():
393394
return _flash_attn_3_available
394395

395396

397+
def is_nunchaku_available():
    """Return whether the `nunchaku` package is importable in this environment."""
    return _nunchaku_available
399+
400+
396401
# docstyle-ignore
397402
FLAX_IMPORT_ERROR = """
398403
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -556,6 +561,10 @@ def is_flash_attn_3_available():
556561
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
557562
"""
558563

564+
NUNCHAKU_IMPORT_ERROR = """
565+
{0} requires the nunchaku library but it was not found in your environment. You can install it with pip: `pip install nunchaku`
566+
"""
567+
559568

560569
BACKENDS_MAPPING = OrderedDict(
561570
[
@@ -588,6 +597,7 @@ def is_flash_attn_3_available():
588597
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
589598
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
590599
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
600+
("nunchaku", (is_nunchaku_available, NUNCHAKU_IMPORT_ERROR)),
591601
]
592602
)
593603

0 commit comments

Comments
 (0)