|  | 
| 23 | 23 | 
 | 
| 24 | 24 | from .. import __version__ | 
| 25 | 25 | from ..quantizers import DiffusersAutoQuantizer | 
| 26 |  | -from ..utils import deprecate, is_accelerate_available, logging | 
|  | 26 | +from ..quantizers.quantization_config import QuantizationMethod | 
|  | 27 | +from ..utils import deprecate, is_accelerate_available, is_nunchaku_available, logging | 
| 27 | 28 | from ..utils.torch_utils import empty_device_cache | 
| 28 | 29 | from .single_file_utils import ( | 
| 29 | 30 |     SingleFileComponentError, | 
| @@ -243,6 +244,32 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = | 
| 243 | 244 |         >>> model = StableCascadeUNet.from_single_file(ckpt_path) | 
| 244 | 245 |         ``` | 
| 245 | 246 |         """ | 
|  | 247 | +        quantization_config = kwargs.get("quantization_config") | 
|  | 248 | +        if quantization_config is not None and quantization_config.quant_method == QuantizationMethod.SVDQUANT: | 
|  | 249 | +            if not is_nunchaku_available(): | 
|  | 250 | +                raise ImportError("Loading SVDQuant models requires the `nunchaku` package. Please install it.") | 
|  | 251 | + | 
|  | 252 | +            if isinstance(pretrained_model_link_or_path_or_dict, dict): | 
|  | 253 | +                raise ValueError( | 
|  | 254 | +                    "Loading a nunchaku model from a state_dict is not supported directly via from_single_file. Please provide a path." | 
|  | 255 | +                ) | 
|  | 256 | + | 
|  | 257 | +            if "FluxTransformer2DModel" in cls.__name__: | 
|  | 258 | +                from nunchaku import NunchakuFluxTransformer2dModel | 
|  | 259 | + | 
|  | 260 | +                kwargs.pop("quantization_config", None) | 
|  | 261 | +                return NunchakuFluxTransformer2dModel.from_pretrained( | 
|  | 262 | +                    pretrained_model_link_or_path_or_dict, **kwargs | 
|  | 263 | +                ) | 
|  | 264 | +            elif "SanaTransformer2DModel" in cls.__name__: | 
|  | 265 | +                from nunchaku import NunchakuSanaTransformer2DModel | 
|  | 266 | + | 
|  | 267 | +                kwargs.pop("quantization_config", None) | 
|  | 268 | +                return NunchakuSanaTransformer2DModel.from_pretrained( | 
|  | 269 | +                    pretrained_model_link_or_path_or_dict, **kwargs | 
|  | 270 | +                ) | 
|  | 271 | +            else: | 
|  | 272 | +                raise NotImplementedError(f"SVDQuant loading is not implemented for {cls.__name__}") | 
| 246 | 273 | 
 | 
| 247 | 274 |         mapping_class_name = _get_single_file_loadable_mapping_class(cls) | 
| 248 | 275 |         # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: | 
|  | 
0 commit comments