diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
index f272346aa2e2..118511b75d50 100644
--- a/docs/source/en/quantization/bitsandbytes.md
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -59,19 +59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained(
 model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
 ```
 
-Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
-
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
-
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-model_8bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    subfolder="transformer",
-    quantization_config=quantization_config
-)
-```
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
 
 You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
 
@@ -131,7 +119,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
 quantization_config = BitsAndBytesConfig(load_in_4bit=True)
 
 model_4bit = FluxTransformer2DModel.from_pretrained(
-    "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
 )
 ```
 
@@ -264,4 +252,9 @@ double_quant_model = SD3Transformer2DModel.from_pretrained(
     quantization_config=double_quant_config,
 )
 model.dequantize()
-```
\ No newline at end of file
+```
+
+## Resources
+
+* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
+* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
\ No newline at end of file
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 5277ad2f9389..932a94571107 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -211,21 +211,28 @@ def load_model_dict_into_meta(
             set_module_kwargs["dtype"] = dtype
 
         # bnb params are flattened.
-        if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
+        if empty_state_dict[param_name].shape != param.shape:
+            if (
+                is_quant_method_bnb
+                and hf_quantizer.pre_quantized
+                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            ):
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+            elif not is_quant_method_bnb:
+                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+                raise ValueError(
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+                )
 
-        if not is_quantized or (
-            not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
         ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
             if accepts_dtype:
                 set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
-        else:
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
 
     return unexpected_keys
 
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
index 93852d29ef59..4c8483a3d6ee 100644
--- a/src/diffusers/quantizers/__init__.py
+++ b/src/diffusers/quantizers/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
+from .auto import DiffusersAutoQuantizer
 from .base import DiffusersQuantizer
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index f231f279e13a..97cbcdc0e53f 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -33,10 +33,10 @@
 }
 
 
-class DiffusersAutoQuantizationConfig:
+class DiffusersAutoQuantizer:
     """
-    The auto diffusers quantization config class that takes care of automatically dispatching to the correct
-    quantization config given a quantization config stored in a dictionary.
+    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
+    `DiffusersQuantizer` given the `QuantizationConfig`.
     """
 
     @classmethod
@@ -60,31 +60,11 @@ def from_dict(cls, quantization_config_dict: Dict):
         target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
         return target_cls.from_dict(quantization_config_dict)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
-        if getattr(model_config, "quantization_config", None) is None:
-            raise ValueError(
-                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
-            )
-        quantization_config_dict = model_config.quantization_config
-        quantization_config = cls.from_dict(quantization_config_dict)
-        # Update with potential kwargs that are passed through from_pretrained.
-        quantization_config.update(kwargs)
-        return quantization_config
-
-
-class DiffusersAutoQuantizer:
-    """
-    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
-    `DiffusersQuantizer` given the `QuantizationConfig`.
-    """
-
     @classmethod
     def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
         # Convert it to a QuantizationConfig if the q_config is a dict
         if isinstance(quantization_config, dict):
-            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+            quantization_config = cls.from_dict(quantization_config)
 
         quant_method = quantization_config.quant_method
 
@@ -107,7 +87,16 @@ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict],
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+        if getattr(model_config, "quantization_config", None) is None:
+            raise ValueError(
+                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+            )
+        quantization_config_dict = model_config.quantization_config
+        quantization_config = cls.from_dict(quantization_config_dict)
+        # Update with potential kwargs that are passed through from_pretrained.
+        quantization_config.update(kwargs)
+
         return cls.from_config(quantization_config)
 
     @classmethod
@@ -129,7 +118,7 @@ def merge_quantization_configs(
         warning_msg = ""
 
         if isinstance(quantization_config, dict):
-            quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+            quantization_config = cls.from_dict(quantization_config)
 
         if warning_msg != "":
             warnings.warn(warning_msg)
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 017136a98854..6ec3885fe373 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -134,7 +134,7 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
         """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
         return max_memory
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",
@@ -152,10 +152,13 @@ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
         """
        takes needed components from state_dict and creates quantized param.
         """
-        if not hasattr(self, "check_quantized_param"):
-            raise AttributeError(
-                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
-            )
+        return
+
+    def check_quantized_param_shape(self, *args, **kwargs):
+        """
+        checks if the quantized param has expected shape.
+        """
+        return True
 
     def validate_environment(self, *args, **kwargs):
         """
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
index e3041aba60ae..d5ac1611a571 100644
--- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -106,7 +106,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         else:
             raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",
@@ -204,6 +204,16 @@ def create_quantized_param(
 
         module._parameters[tensor_name] = new_value
 
+    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+        n = current_param_shape.numel()
+        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+        if loaded_param_shape != inferred_shape:
+            raise ValueError(
+                f"Expected the flattened shape of the current param ({param_name}) to be {inferred_shape} but is {loaded_param_shape}."
+            )
+        else:
+            return True
+
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         # need more space for buffers that are created during quantization
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -330,7 +340,6 @@ def __init__(self, quantization_config, **kwargs):
         if self.quantization_config.llm_int8_skip_modules is not None:
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
 
-    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
     def validate_environment(self, *args, **kwargs):
         if not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed for quantization.")
@@ -404,7 +413,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
         logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
         return torch.int8
 
-    def check_quantized_param(
+    def check_if_quantized_param(
         self,
         model: "ModelMixin",
         param_value: "torch.Tensor",
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 6c1b24e31e2a..7b553434fbe9 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -13,10 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
+import os
 import tempfile
 import unittest
 
 import numpy as np
+import safetensors.torch
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import logging
@@ -118,6 +120,9 @@ def get_dummy_inputs(self):
 
 class BnB4BitBasicTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -232,7 +237,7 @@ def test_linear_are_4bit(self):
 
     def test_config_from_pretrained(self):
         transformer_4bit = FluxTransformer2DModel.from_pretrained(
-            "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+            "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
         )
         linear = get_some_linear_layer(transformer_4bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
@@ -312,9 +317,42 @@ def test_bnb_4bit_wrong_config(self):
         with self.assertRaises(ValueError):
             _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
 
+    def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
+        r"""
+        Test if loading with an incorrect state dict raises an error.
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            nf4_config = BitsAndBytesConfig(load_in_4bit=True)
+            model_4bit = SD3Transformer2DModel.from_pretrained(
+                self.model_name, subfolder="transformer", quantization_config=nf4_config
+            )
+            model_4bit.save_pretrained(tmpdirname)
+            del model_4bit
+
+            with self.assertRaises(ValueError) as err_context:
+                state_dict = safetensors.torch.load_file(
+                    os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                # corrupt the state dict
+                key_to_target = "context_embedder.weight"  # can be other keys too.
+                compatible_param = state_dict[key_to_target]
+                corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
+                state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
+                safetensors.torch.save_file(
+                    state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+                )
+
+                _ = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+            assert key_to_target in str(err_context.exception)
+
 
 class BnB4BitTrainingTests(Base4bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -360,6 +398,9 @@ def test_training(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitTests(Base4bitTests):
     def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
         nf4_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
@@ -447,8 +488,10 @@ def test_moving_to_cpu_throws_warning(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
     def setUp(self) -> None:
-        # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
-        model_id = "sayakpaul/flux.1-dev-nf4-pkg"
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
         t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
         transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
         self.pipeline_4bit = DiffusionPipeline.from_pretrained(
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 2e4aec39b427..ba2402461c87 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -117,6 +117,9 @@ def get_dummy_inputs(self):
 
 class BnB8bitBasicTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Models
         self.model_fp16 = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -238,7 +241,7 @@ def test_llm_skip(self):
 
     def test_config_from_pretrained(self):
         transformer_8bit = FluxTransformer2DModel.from_pretrained(
-            "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
+            "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
         )
         linear = get_some_linear_layer(transformer_8bit)
         self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
@@ -296,6 +299,9 @@ def test_device_and_dtype_assignment(self):
 
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         self.model_8bit = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -337,6 +343,9 @@ def test_training(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitTests(Base8bitTests):
     def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
         mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
         model_8bit = SD3Transformer2DModel.from_pretrained(
             self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -427,8 +436,10 @@ def test_generate_quality_dequantize(self):
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):
     def setUp(self) -> None:
-        # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
-        model_id = "sayakpaul/flux.1-dev-int8-pkg"
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
         t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
         transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
         self.pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -466,6 +477,9 @@ def test_quality(self):
 @slow
 class BaseBnb8bitSerializationTests(Base8bitTests):
     def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
         quantization_config = BitsAndBytesConfig(
             load_in_8bit=True,
         )
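
The documentation hunk above drops the duplicated 8-bit loading snippet and relies on the surrounding prose for serialization. For quick reference while reviewing, here is a minimal sketch of the save/push flow that paragraph describes; the local folder name and the `your-username/...` repo id are placeholders, not part of this diff.

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)

# Serialize the quantized weights locally, or push them to the Hub
# (the quantization config.json is uploaded first, then the weights).
model_4bit.save_pretrained("flux.1-dev-transformer-nf4")
model_4bit.push_to_hub("your-username/flux.1-dev-transformer-nf4")
```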
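The 4-bit shape rule that the new `check_quantized_param_shape` enforces in `bnb_quantizer.py` can be reproduced in a few lines of plain PyTorch. This is a minimal illustrative sketch (the helper name below is made up, not part of the diff): bitsandbytes serializes 4-bit weights flattened, two values per byte, as an `(n, 1)` column, while biases stay 1D, and a checkpoint tensor that does not match this packing is now rejected with a `ValueError` naming the offending parameter.

```py
import torch

def expected_bnb_4bit_shape(param_name: str, current_param_shape: torch.Size):
    # Same inference as BnB4BitDiffusersQuantizer.check_quantized_param_shape:
    # biases stay flat, weights are packed two 4-bit values per byte into an (n, 1) column.
    n = current_param_shape.numel()
    return (n,) if "bias" in param_name else ((n + 1) // 2, 1)

# A 64x32 weight has 2048 elements, so the serialized 4-bit tensor should be (1024, 1);
# a 64-element bias stays (64,). Any other loaded shape now fails the load-time check.
print(expected_bnb_4bit_shape("context_embedder.weight", torch.Size([64, 32])))  # (1024, 1)
print(expected_bnb_4bit_shape("context_embedder.bias", torch.Size([64])))        # (64,)
```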