Merged

Commits (119 total; changes shown from 45 commits)
e634ff2
quantization config.
sayakpaul Aug 19, 2024
02a6dff
fix-copies
sayakpaul Aug 19, 2024
c385a2b
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
0355875
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
e41b494
Merge branch 'main' into quantization-config
sayakpaul Aug 20, 2024
dfb33eb
Merge branch 'main' into quantization-config
sayakpaul Aug 21, 2024
e492655
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
6e86cc0
fix
sayakpaul Aug 22, 2024
58a3d15
modules_to_not_convert
sayakpaul Aug 22, 2024
1d477f9
Merge branch 'main' into quantization-config
sayakpaul Aug 22, 2024
bd7f46d
Merge branch 'main' into quantization-config
sayakpaul Aug 23, 2024
d5d7bb6
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
44c8a75
Merge branch 'main' into quantization-config
sayakpaul Aug 28, 2024
6a0fcdc
add bitsandbytes utilities.
sayakpaul Aug 28, 2024
e4590fa
make progress.
sayakpaul Aug 28, 2024
77a1438
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
335ab6b
fixes
sayakpaul Aug 29, 2024
d44ef85
quality
sayakpaul Aug 29, 2024
210fa1e
up
sayakpaul Aug 29, 2024
f4feee1
up
sayakpaul Aug 29, 2024
e8c1722
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
7f86a71
Merge branch 'main' into quantization-config
sayakpaul Aug 29, 2024
ba671b6
minor
sayakpaul Aug 30, 2024
c1a9f13
up
sayakpaul Aug 30, 2024
4489c54
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
f2ca5e2
up
sayakpaul Aug 30, 2024
d6b8954
fix
sayakpaul Aug 30, 2024
45029e2
provide credits where due.
sayakpaul Aug 30, 2024
4eb468a
make configurations work.
sayakpaul Aug 30, 2024
939965d
fixes
sayakpaul Aug 30, 2024
8557166
Merge branch 'main' into quantization-config
sayakpaul Aug 30, 2024
d098d07
fix
sayakpaul Aug 30, 2024
c4a0074
update_missing_keys
sayakpaul Aug 30, 2024
ee45612
fix
sayakpaul Aug 30, 2024
b24c0a7
fix
sayakpaul Aug 31, 2024
473505c
make it work.
sayakpaul Aug 31, 2024
c795c82
fix
sayakpaul Aug 31, 2024
c1d5b96
Merge branch 'main' into quantization-config
sayakpaul Aug 31, 2024
af7caca
provide credits to transformers.
sayakpaul Aug 31, 2024
80967f5
empty commit
sayakpaul Sep 1, 2024
3bdf25a
handle to() better.
sayakpaul Sep 2, 2024
27415cc
tests
sayakpaul Sep 2, 2024
51cac09
change to bnb from bitsandbytes
sayakpaul Sep 2, 2024
15f3032
fix tests
sayakpaul Sep 2, 2024
77c9fdb
better safeguard.
sayakpaul Sep 2, 2024
ddc9f29
change merging status
sayakpaul Sep 2, 2024
44c4109
courtesy to transformers.
sayakpaul Sep 2, 2024
27666a8
move upper.
sayakpaul Sep 2, 2024
3464d83
better
sayakpaul Sep 2, 2024
b106124
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
330fa0a
Merge branch 'main' into quantization-config
sayakpaul Sep 2, 2024
abc8607
make the unused kwargs warning friendlier.
sayakpaul Sep 3, 2024
31725aa
harmonize changes with https://github.com/huggingface/transformers/pu…
sayakpaul Sep 3, 2024
e5938a6
style
sayakpaul Sep 3, 2024
444588f
trainin tests
sayakpaul Sep 3, 2024
d3360ce
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
d8b35f4
Merge branch 'main' into quantization-config
sayakpaul Sep 3, 2024
859f2d7
Merge branch 'main' into quantization-config
sayakpaul Sep 4, 2024
3b2d6e1
feedback part i.
sayakpaul Sep 4, 2024
5799954
Add Flux inpainting and Flux Img2Img (#9135)
Gothos Sep 4, 2024
8e4bd08
Revert "Add Flux inpainting and Flux Img2Img (#9135)"
sayakpaul Sep 6, 2024
835d4ad
tests
sayakpaul Sep 6, 2024
27075fe
don
sayakpaul Sep 6, 2024
5c00c1c
Merge branch 'main' into quantization-config
sayakpaul Sep 6, 2024
5d633a0
Merge branch 'main' into quantization-config
sayakpaul Sep 8, 2024
c381fe0
Apply suggestions from code review
sayakpaul Sep 10, 2024
3c92878
Merge branch 'main' into quantization-config
sayakpaul Sep 10, 2024
acdeb25
contribution guide.
sayakpaul Sep 11, 2024
aa295b7
Merge branch 'main' into quantization-config
sayakpaul Sep 11, 2024
7f7c9ce
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
55f96d8
Merge branch 'main' into quantization-config
sayakpaul Sep 15, 2024
b28cc65
changes
sayakpaul Sep 17, 2024
8328e86
Merge branch 'main' into quantization-config
sayakpaul Sep 17, 2024
9758942
empty
sayakpaul Sep 17, 2024
b1a9878
fix tests
sayakpaul Sep 17, 2024
971305b
harmonize with https://github.com/huggingface/transformers/pull/33546.
sayakpaul Sep 18, 2024
f41adf1
numpy_cosine_distance
sayakpaul Sep 19, 2024
0bcb88b
Merge branch 'main' into quantization-config
sayakpaul Sep 19, 2024
55b3696
Merge branch 'main' into quantization-config
sayakpaul Sep 20, 2024
4cb3a6d
Merge branch 'main' into quantization-config
sayakpaul Sep 23, 2024
8a03eae
Merge branch 'main' into quantization-config
sayakpaul Sep 24, 2024
53f0a92
Merge branch 'main' into quantization-config
sayakpaul Sep 26, 2024
6aab47c
Merge branch 'main' into quantization-config
sayakpaul Sep 27, 2024
9b9a610
resolved conflicts,
sayakpaul Sep 29, 2024
510d57a
Merge branch 'main' into quantization-config
sayakpaul Oct 10, 2024
555a5ae
config_dict modification.
sayakpaul Oct 10, 2024
da10365
remove if config comment.
sayakpaul Oct 10, 2024
71316a6
note for load_state_dict changes.
sayakpaul Oct 10, 2024
12f5c59
float8 check.
sayakpaul Oct 10, 2024
5e722cd
quantizer.
sayakpaul Oct 10, 2024
c78dd0c
raise an error for non-True low_cpu_mem_usage values when using quant.
sayakpaul Oct 10, 2024
af3ecea
low_cpu_mem_usage shenanigans when using fp32 modules.
sayakpaul Oct 10, 2024
a473d28
don't re-assign _pre_quantization_type.
sayakpaul Oct 10, 2024
870d74f
make comments clear.
sayakpaul Oct 10, 2024
3e6cfeb
remove comments.
sayakpaul Oct 10, 2024
673993c
handle mixed types better when moving to cpu.
sayakpaul Oct 10, 2024
0d5f2f7
add tests to check if we're throwing warning rightly.
sayakpaul Oct 10, 2024
3cb20fe
better check.
sayakpaul Oct 10, 2024
10940a9
fix 8bit test_quality.
sayakpaul Oct 10, 2024
c0a88ae
Merge branch 'main' into quantization-config
sayakpaul Oct 10, 2024
dcc5bc5
Merge branch 'main' into quantization-config
sayakpaul Oct 12, 2024
5e0b4eb
Merge branch 'main' into quantization-config
sayakpaul Oct 12, 2024
569dd96
Merge branch 'main' into quantization-config
sayakpaul Oct 13, 2024
8bdc846
Merge branch 'main' into quantization-config
sayakpaul Oct 15, 2024
ff8ddef
handle dtype more robustly.
sayakpaul Oct 15, 2024
de6394a
better message when keep_in_fp32_modules.
sayakpaul Oct 15, 2024
81bb48a
handle dtype casting.
sayakpaul Oct 15, 2024
c5e62ae
Merge branch 'main' into quantization-config
sayakpaul Oct 15, 2024
d023b40
Merge branch 'main' into quantization-config
sayakpaul Oct 16, 2024
a3d2655
Merge branch 'main' into quantization-config
sayakpaul Oct 16, 2024
700b0f3
Merge branch 'main' into quantization-config
sayakpaul Oct 18, 2024
0ae70fe
fix dtype checks in pipeline.
sayakpaul Oct 18, 2024
ecdf1d0
fix warning message.
sayakpaul Oct 18, 2024
aea3398
Update src/diffusers/models/modeling_utils.py
sayakpaul Oct 18, 2024
3a91974
Merge branch 'main' into quantization-config
sayakpaul Oct 18, 2024
5d8e844
Merge branch 'main' into quantization-config
sayakpaul Oct 19, 2024
501a6ba
mitigate the confusing cpu warning
sayakpaul Oct 19, 2024
1a931cb
Merge branch 'main' into quantization-config
sayakpaul Oct 21, 2024
2fa8fb9
Merge branch 'main' into quantization-config
sayakpaul Oct 21, 2024
5 changes: 4 additions & 1 deletion src/diffusers/__init__.py
@@ -31,6 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -123,7 +124,6 @@
"VQModel",
]
)

_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
@@ -155,6 +155,7 @@
"StableDiffusionMixin",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
@@ -526,6 +527,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig

try:
if not is_onnx_available():
@@ -619,6 +621,7 @@
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
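With the __init__.py changes above, BitsAndBytesConfig and DiffusersQuantizer become importable from the top-level diffusers namespace. Below is a minimal, hedged sketch of the loading API this PR builds toward; the model class, checkpoint, and argument values are illustrative assumptions rather than something taken from this diff:

    import torch
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

    # 4-bit NF4 quantization via bitsandbytes; argument names assumed to mirror
    # the transformers BitsAndBytesConfig that this PR adapts.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",          # illustrative checkpoint
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )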
13 changes: 12 additions & 1 deletion src/diffusers/configuration_utils.py
@@ -526,7 +526,8 @@ def extract_init_dict(cls, config_dict, **kwargs):
init_dict[key] = config_dict.pop(key)

# 4. Give nice warning if unexpected values have been passed
if len(config_dict) > 0:
only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
Member Author:

Because quantization_config isn't a part of any model's __init__().

Collaborator:

I think it is better to not add it to config_dict if it is not going into __init__, i.e. at line 511:

 # remove private attributes
 config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# remove quantization_config
 config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

Member Author:

We cannot remove quantization_config from the config of a model as that would prevent loading of the quantized models via from_pretrained().

quantization_config isn't used for initializing a model; it's used to determine what kind of quantization configuration to inject into the given model. This is why it's only used in from_pretrained() of ModelMixin.

LMK if you have a better idea to handle it.

Collaborator:

We do not remove it from the config, we just don't add it to the config_dict inside this extract_init_dict method. Basically, the config_dict in this function goes through these steps:

  1. it is used to create init_dict: the quantisation config will not go there, so it is not affected if we do not add it to config_dict
  2. it is used to throw a warning after we created init_dict; if the quantisation config was not there, we do not need to throw a warning for it
  3. it goes into unused_kwargs - so I think this is the only difference it would make. Do we need the quantisation config to be in the unused_kwargs returned by extract_init_dict? I think unused_kwargs is only used to send additional warnings for unexpected stuff, but since the quantisation config is expected, and we have already decided not to send a warning here inside extract_init_dict, I think it does not need to go into unused_kwargs here?
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
         ...
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
        
+      # remove quantization_config
+      config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
        
        
        ## here we use config_dict to create `init_dict` which will be passed to `__init__` method
        init_dict = {}
        for key in expected_keys:
                 ...
                init_dict[key] = config_dict.pop(key)
-      only_quant_config_remaining = len(config_dict) == 1 and "quantization_config" in config_dict
-      if len(config_dict) > 0 and not only_quant_config_remaining:
+     if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )
       ....
        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        return init_dict, unused_kwargs, hidden_config_dict

Member Author:

Makes sense. Resolved in 555a5ae.

if len(config_dict) > 0 and not only_quant_config_remaining:
logger.warning(
f"The config attributes {config_dict} were passed to {cls.__name__}, "
"but are not expected and will be ignored. Please verify your "
@@ -586,10 +587,20 @@ def to_json_saveable(value):
value = value.as_posix()
return value

# IFWatermarker, for example, doesn't have a `config`.
if hasattr(self, "config") and "quantization_config" in self.config:
config_dict["quantization_config"] = (
self.config.quantization_config.to_dict()
if not isinstance(self.config.quantization_config, dict)
else self.config.quantization_config
)

config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
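As a quick, hedged check of the serialization behavior above (reusing the quantized model from the earlier sketch): after save_pretrained, the written config.json should carry a quantization_config entry as a plain dict, while the non-JSON-serializable _pre_quantization_dtype is dropped.

    import json

    model.save_pretrained("flux-transformer-nf4")        # illustrative output folder
    with open("flux-transformer-nf4/config.json") as f:
        cfg = json.load(f)

    assert "quantization_config" in cfg                  # serialized via to_json_string above
    assert "_pre_quantization_dtype" not in cfg          # torch.dtype values are popped before dumping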

109 changes: 100 additions & 9 deletions src/diffusers/models/model_loading_utils.py
@@ -25,6 +25,7 @@
import torch
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -53,11 +54,36 @@


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
def _determine_device_map(
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
):
if isinstance(device_map, str):
special_dtypes = {}
if hf_quantizer is not None:
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
special_dtypes.update(
{
name: torch.float32
for name, _ in model.named_parameters()
if any(m in name for m in keep_in_fp32_modules)
}
)

target_dtype = torch_dtype
if hf_quantizer is not None:
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}

if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
device_map_kwargs["special_dtypes"] = special_dtypes
elif len(special_dtypes) > 0:
logger.warning(
"This model has some weights that should be kept in higher precision, you need to upgrade "
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
)

if device_map != "sequential":
max_memory = get_balanced_memory(
model,
@@ -69,8 +95,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
else:
max_memory = get_max_memory(max_memory)

if hf_quantizer is not None:
max_memory = hf_quantizer.adjust_max_memory(max_memory)

device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.validate_environment(device_map=device_map)

return device_map
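To make the special_dtypes and target-dtype handling above concrete, here is a standalone, hedged sketch of what ends up being forwarded to accelerate's infer_auto_device_map. The model and keep_in_fp32_modules names are assumed to exist, and torch.int8 stands in for what an 8-bit bitsandbytes quantizer might return from adjust_target_dtype:

    import torch
    from accelerate import infer_auto_device_map

    keep_in_fp32_modules = ["norm_out"]                  # illustrative module list
    special_dtypes = {
        name: torch.float32
        for name, _ in model.named_parameters()
        if any(m in name for m in keep_in_fp32_modules)
    }
    device_map = infer_auto_device_map(
        model,
        dtype=torch.int8,                                # quantizer-adjusted target dtype
        special_dtypes=special_dtypes,                   # needs a recent accelerate release
        no_split_module_classes=model._get_no_split_modules("auto"),
    )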

@@ -99,6 +131,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -136,29 +170,57 @@ def load_model_dict_into_meta(
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
device = device or torch.device("cpu")
device = device or torch.device("cpu") if hf_quantizer is None else device
Member Author:

More on this in the later changes.

Member:

Not specific to this PR, but device = device or torch.device("cpu") is a bit dangerous because, theoretically, 0 is a valid device yet it would be considered falsy. AFAICT it's not problematic for the existing code, but something to keep in mind.

Member Author:

Indeed.

Member Author:

I have added a comment about it too.
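A tiny standalone illustration of the gotcha flagged above (not part of the diff): an integer device index of 0 is falsy, so the or-fallback silently routes it to CPU.

    import torch

    device = 0                                                      # a valid CUDA device index
    print(device or torch.device("cpu"))                            # -> cpu, which is surprising
    print(device if device is not None else torch.device("cpu"))   # -> 0, the intended behavior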

dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

unexpected_keys = []
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue

if empty_state_dict[param_name].shape != param.shape:
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
):
param = param.to(torch.float32)
else:
param = param.to(dtype)

is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
if not is_quantized and not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
Member Author:

Because bnb quantized params are usually flattened.

model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
if (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device))
):
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)

return unexpected_keys
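The dtype handling added above boils down to a small rule that can be sketched in isolation (the helper name _cast_param is made up for illustration): floating-point params are cast to the requested dtype, except float8_e4m3fn tensors, which are left alone, and params that belong to keep_in_fp32_modules, which are upcast to fp32 when loading in fp16.

    import torch

    def _cast_param(param, param_name, dtype, keep_in_fp32_modules=()):
        is_fp8 = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
        if not torch.is_floating_point(param) or is_fp8:
            return param                                   # ints/bools/fp8 stay untouched
        if dtype == torch.float16 and any(
            m in param_name.split(".") for m in keep_in_fp32_modules
        ):
            return param.to(torch.float32)                 # keep sensitive modules in fp32
        return param.to(dtype)

    print(_cast_param(torch.randn(2), "norm_out.weight", torch.float16, ("norm_out",)).dtype)  # torch.float32
    print(_cast_param(torch.randn(2), "proj.weight", torch.float16, ("norm_out",)).dtype)      # torch.float16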


@@ -228,3 +290,32 @@ def _fetch_index_file(
index_file = None

return index_file


# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")

# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}

# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

return merged_state_dict
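A hedged usage sketch for the helper above: merge a sharded safetensors checkpoint back into a single in-memory state dict so it can be consumed in one pass (the local folder and index filename follow the usual Hub convention and are illustrative):

    import json
    import os

    folder = "path/to/sharded-checkpoint"        # illustrative local folder
    index_file = os.path.join(folder, "diffusion_pytorch_model.safetensors.index.json")
    with open(index_file) as f:
        sharded_metadata = json.load(f)

    state_dict = _merge_sharded_checkpoints(folder, sharded_metadata)
    print(f"merged {len(state_dict)} tensors")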