
Commit 8e1ea00

start nunchaku.
1 parent 91a151b commit 8e1ea00

File tree: 7 files changed, +324 -2 lines changed


src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
     def check_if_quantized_param(
         self,
         model: "ModelMixin",
-        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_value: "torch.Tensor",
         param_name: str,
         state_dict: Dict[str, Any],
         **kwargs,
New file (path not shown in this capture; a nunchaku quantizer module under src/diffusers/quantizers/, judging by its relative imports)

Lines changed: 182 additions & 0 deletions

@@ -0,0 +1,182 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from diffusers.utils.import_utils import is_nunchaku_version

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nunchaku_available,
    is_torch_available,
    logging,
)
from ...utils.torch_utils import is_fp8_available
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_nunchaku_available():
    from .utils import replace_with_nunchaku_linear

logger = logging.get_logger(__name__)


class NunchakuQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for nunchaku (SVDQuant)
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nunchaku", "accelerate"]

    # Maps the config precision to the dtype used for memory estimation.
    dtype_map = {"int4": torch.int8}
    if is_fp8_available():
        # Only register the nvfp4 mapping when this torch build exposes float8 dtypes.
        dtype_map["nvfp4"] = torch.float8_e4m3fn

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")

        if not is_nunchaku_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the nunchaku library "
                "(follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
            )
        if not is_nunchaku_version(">=", "0.3.1"):
            raise ImportError(
                "Loading a nunchaku quantized model requires `nunchaku>=0.3.1`. "
                "Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the accelerate library (`pip install accelerate`)"
            )

        # TODO: check
        # device_map = kwargs.get("device_map", None)
        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        #     raise ValueError(
        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
        #     )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # Local import: nunchaku imports diffusers internally, so a module-level
        # import would be circular.
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
            return True

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create a quantized parameter.
        """
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

        if self.pre_quantized:
            if tensor_name in module._parameters:
                module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if tensor_name in module._buffers:
                # Buffers are plain tensors; wrapping them in nn.Parameter would
                # register them as trainable weights.
                module._buffers[tensor_name] = param_value.to(device=target_device)

        elif isinstance(module, torch.nn.Linear):
            if tensor_name in module._parameters:
                module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if tensor_name in module._buffers:
                module._buffers[tensor_name] = param_value.to(device=target_device)

            # Swap the dense layer for its quantized counterpart. `param_name` is the
            # dotted parameter path (e.g. "blocks.0.proj.weight"), so strip the tensor
            # name to locate the parent module that owns the nn.Linear.
            new_module = SVDQW4A4Linear.from_linear(module)
            parent_module, linear_name = get_module_from_name(model, param_name.rsplit(".", 1)[0])
            setattr(parent_module, linear_name, new_module)

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # Reserve 10% headroom for quantization scales and buffers.
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        precision = self.quantization_config.precision
        expected_target_dtypes = [torch.int8]
        if is_fp8_available():
            expected_target_dtypes.append(torch.float8_e4m3fn)
        if target_dtype not in expected_target_dtypes:
            new_target_dtype = self.dtype_map[precision]

            logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
            return new_target_dtype
        else:
            raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            # We force the `dtype` to be bfloat16; this is a requirement of `nunchaku`
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                "requirements of `nunchaku` to enable model loading in 4-bit. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers, or pass"
                " torch_dtype=torch.bfloat16 to remove this warning.",
                torch_dtype,
            )
            torch_dtype = torch.bfloat16
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # TODO: deal with `device_map`
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if self.modules_to_not_convert is None:
            self.modules_to_not_convert = []
        elif not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model = replace_with_nunchaku_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    # @property
    # def is_serializable(self):
    #     return True
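For orientation, here is a minimal sketch of the layer swap that `create_quantized_param` performs on a fresh (non pre-quantized) checkpoint, using only the nunchaku call exercised above (`SVDQW4A4Linear.from_linear` applied to a bare `nn.Linear`). The toy model and parameter name are hypothetical, and a CUDA build of nunchaku is assumed to be installed.

import torch.nn as nn

from nunchaku.models.linear import SVDQW4A4Linear

# Hypothetical toy model standing in for a diffusers ModelMixin.
model = nn.Sequential(nn.Linear(64, 64))

# `create_quantized_param` receives dotted parameter paths such as "0.weight";
# stripping the trailing tensor name yields the module that owns the layer.
param_name = "0.weight"
owner_name = param_name.rsplit(".", 1)[0]

# Swap the dense layer for its quantized counterpart, as the quantizer does.
model[int(owner_name)] = SVDQW4A4Linear.from_linear(model[int(owner_name)])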
New file (path not shown in this capture; the `.utils` module imported by the quantizer above)

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
import torch.nn as nn

from ...utils import is_accelerate_available, is_nunchaku_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights

if is_nunchaku_available():
    from nunchaku.models.linear import SVDQW4A4Linear


logger = logging.get_logger(__name__)


def _replace_with_nunchaku_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    # Guard against `None` so the membership checks below cannot fail.
    modules_to_not_convert = modules_to_not_convert or []

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check that the fully qualified key is not in `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    if quantization_config.precision in ["int4", "nvfp4"]:
                        model._modules[name] = SVDQW4A4Linear(
                            in_features,
                            out_features,
                            rank=quantization_config.rank,
                            bias=module.bias is not None,
                            # Use the dtype of the layer being replaced; submodules
                            # reached through the recursion do not expose `.dtype`.
                            dtype=module.weight.dtype,
                        )
                        has_been_replaced = True
                        # Store the module class in case we need to transpose the weight later
                        model._modules[name].source_cls = type(module)
                        # Force requires_grad to False to avoid unexpected errors
                        model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_nunchaku_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_nunchaku_linear(
    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
    # `pre_quantized` is accepted to match the quantizer call site; it is not used yet.
    model, _ = _replace_with_nunchaku_linear(model, modules_to_not_convert, current_key_name, quantization_config)

    has_been_replaced = any(
        isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model with the SVDQuant method but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
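A short usage sketch of the helper on a toy model. `_ToyConfig` is a hypothetical stand-in for the `NunchakuConfig` introduced below, carrying only the two fields the helper reads (`precision`, `rank`); nunchaku and accelerate are assumed to be installed, and the replaced layers come out on the meta device because of `init_empty_weights`.

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class _ToyConfig:
    # Hypothetical stand-in for NunchakuConfig.
    precision: str = "int4"
    rank: int = 32


model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
model = replace_with_nunchaku_linear(
    model,
    modules_to_not_convert=["2"],  # leave the final projection in full precision
    quantization_config=_ToyConfig(),
)
# model[0] is now an SVDQW4A4Linear; model[2] is still a plain nn.Linear.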

src/diffusers/quantizers/quantization_config.py

Lines changed: 38 additions & 0 deletions
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
     GGUF = "gguf"
     TORCHAO = "torchao"
     QUANTO = "quanto"
+    NUNCHAKU = "nunchaku"


 if is_torchao_available():
@@ -724,3 +725,40 @@ def post_init(self):
         accepted_weights = ["float8", "int8", "int4", "int2"]
         if self.weights_dtype not in accepted_weights:
             raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
+
+
+class NunchakuConfig(QuantizationConfigMixin):
+    """
+    A wrapper class covering the attributes and features of a model that has been loaded with `nunchaku`
+    quantization.
+
+    Args:
+        TODO
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules not to quantize, useful for quantizing models that explicitly require some
+            modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate
+            layers).
+    """
+
+    # Per-precision quantization group sizes expected by the nunchaku kernels.
+    group_size_map = {"int4": 64, "nvfp4": 16}
+
+    def __init__(
+        self,
+        precision: str = "int4",
+        rank: int = 32,
+        modules_to_not_convert: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.NUNCHAKU
+        self.precision = precision
+        self.rank = rank  # rank of the SVD low-rank branch, read by replace_with_nunchaku_linear
+        self.modules_to_not_convert = modules_to_not_convert
+
+        self.post_init()
+        # Safe to index only after post_init has validated `precision`.
+        self.group_size = self.group_size_map[precision]
+
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct
+        """
+        accepted_precision = ["int4", "nvfp4"]
+        if self.precision not in accepted_precision:
+            raise ValueError(f"Only precisions in {accepted_precision} are supported, but found {self.precision}")
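A quick sketch of constructing the config. The import path assumes `NunchakuConfig` is reachable like the other configs in this file; the pipeline-level wiring (`from_pretrained(..., quantization_config=...)`) does not exist yet in this commit.

from diffusers.quantizers.quantization_config import NunchakuConfig

config = NunchakuConfig(precision="int4", rank=32)
assert config.group_size == 64  # looked up in group_size_map["int4"]

nvfp4_config = NunchakuConfig(precision="nvfp4")
assert nvfp4_config.group_size == 16

# NunchakuConfig(precision="int8") raises ValueError in post_init().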

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@
     is_matplotlib_available,
     is_nltk_available,
     is_note_seq_available,
+    is_nunchaku_available,
     is_onnx_available,
     is_opencv_available,
     is_optimum_quanto_available,

src/diffusers/utils/import_utils.py

Lines changed: 21 additions & 1 deletion
@@ -217,6 +217,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _torchao_available, _torchao_version = _is_package_available("torchao")
 _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
 _optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
+_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True)
 _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")

@@ -363,6 +364,10 @@ def is_optimum_quanto_available():
     return _optimum_quanto_available


+def is_nunchaku_available():
+    return _nunchaku_available
+
+
 def is_timm_available():
     return _timm_available

@@ -816,7 +821,7 @@ def is_k_diffusion_version(operation: str, version: str):

 def is_optimum_quanto_version(operation: str, version: str):
     """
-    Compares the current Accelerate version to a given reference with an operation.
+    Compares the current quanto version to a given reference with an operation.

     Args:
         operation (`str`):

@@ -829,6 +834,21 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)


+def is_nunchaku_version(operation: str, version: str):
+    """
+    Compares the current nunchaku version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _nunchaku_available:
+        return False
+    return compare_versions(parse(_nunchaku_version), operation, version)
+
+
 def is_xformers_version(operation: str, version: str):
     """
     Compares the current xformers version to a given reference with an operation.
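The two new helpers compose like the existing availability/version pairs; the quantizer's `validate_environment` above effectively does the following.

from diffusers.utils.import_utils import is_nunchaku_available, is_nunchaku_version

if not is_nunchaku_available():
    raise ImportError("Loading a nunchaku quantized model requires the nunchaku library")
if not is_nunchaku_version(">=", "0.3.1"):
    # is_nunchaku_version returns False (rather than raising) when nunchaku is absent.
    raise ImportError("Loading a nunchaku quantized model requires `nunchaku>=0.3.1`")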

src/diffusers/utils/torch_utils.py

Lines changed: 4 additions & 0 deletions
@@ -197,3 +197,7 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def is_fp8_available():
+    return getattr(torch, "float8_e4m3fn", None) is not None
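A minimal guard showing why this helper exists: older torch builds do not define the float8 dtypes, so any reference to `torch.float8_e4m3fn` must be gated. The fallback value here is illustrative.

import torch

from diffusers.utils.torch_utils import is_fp8_available

if is_fp8_available():
    fp8_dtype = torch.float8_e4m3fn  # safe: this build defines the float8 dtypes
else:
    fp8_dtype = None  # older torch: callers fall back to the int8 path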
