Skip to content

Commit 9e4873b

Browse files
committed
update
1 parent 8385f45 commit 9e4873b

File tree

8 files changed

+99
-57
lines changed

8 files changed

+99
-57
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .. import __version__
2525
from ..quantizers import DiffusersAutoQuantizer
2626
from ..utils import deprecate, is_accelerate_available, logging
27+
from ..utils.torch_utils import device_synchronize, empty_device_cache
2728
from .single_file_utils import (
2829
SingleFileComponentError,
2930
convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
430431
keep_in_fp32_modules=keep_in_fp32_modules,
431432
unexpected_keys=unexpected_keys,
432433
)
434+
empty_device_cache()
435+
device_synchronize()
433436
else:
434437
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
435438

src/diffusers/loaders/single_file_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
4848
from ..utils.hub_utils import _get_model_file
49+
from ..utils.torch_utils import device_synchronize, empty_device_cache
4950

5051

5152
if is_transformers_available():
@@ -1689,6 +1690,8 @@ def create_diffusers_clip_model_from_ldm(
16891690

16901691
if is_accelerate_available():
16911692
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1693+
empty_device_cache()
1694+
device_synchronize()
16921695
else:
16931696
model.load_state_dict(diffusers_format_checkpoint, strict=False)
16941697

@@ -2148,6 +2151,8 @@ def create_diffusers_t5_model_from_checkpoint(
21482151

21492152
if is_accelerate_available():
21502153
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
2154+
empty_device_cache()
2155+
device_synchronize()
21512156
else:
21522157
model.load_state_dict(diffusers_format_checkpoint)
21532158

src/diffusers/loaders/transformer_flux.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818
MultiIPAdapterImageProjection,
1919
)
2020
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
21-
from ..utils import (
22-
is_accelerate_available,
23-
is_torch_version,
24-
logging,
25-
)
21+
from ..utils import is_accelerate_available, is_torch_version, logging
22+
from ..utils.torch_utils import device_synchronize, empty_device_cache
2623

2724

2825
if is_accelerate_available():
@@ -84,6 +81,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
8481
else:
8582
device_map = {"": self.device}
8683
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
84+
empty_device_cache()
85+
device_synchronize()
8786

8887
return image_projection
8988

@@ -158,6 +157,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
158157

159158
key_id += 1
160159

160+
empty_device_cache()
161+
device_synchronize()
162+
161163
return attn_procs
162164

163165
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/loaders/transformer_sd3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..models.embeddings import IPAdapterTimeImageProjection
1919
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
2020
from ..utils import is_accelerate_available, is_torch_version, logging
21+
from ..utils.torch_utils import device_synchronize, empty_device_cache
2122

2223

2324
logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ def _convert_ip_adapter_attn_to_diffusers(
8081
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
8182
)
8283

84+
empty_device_cache()
85+
device_synchronize()
86+
8387
return attn_procs
8488

8589
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
147151
else:
148152
device_map = {"": self.device}
149153
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
154+
empty_device_cache()
155+
device_synchronize()
150156

151157
return image_proj
152158

src/diffusers/loaders/unet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
is_torch_version,
4545
logging,
4646
)
47+
from ..utils.torch_utils import device_synchronize, empty_device_cache
4748
from .lora_base import _func_optionally_disable_offloading
4849
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
4950
from .utils import AttnProcsLayers
@@ -752,6 +753,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
752753
else:
753754
device_map = {"": self.device}
754755
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
756+
empty_device_cache()
757+
device_synchronize()
755758

756759
return image_projection
757760

@@ -849,6 +852,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
849852

850853
key_id += 2
851854

855+
empty_device_cache()
856+
device_synchronize()
857+
852858
return attn_procs
853859

854860
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/models/model_loading_utils.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,6 @@ def load_model_dict_into_meta(
231231

232232
is_quantized = hf_quantizer is not None
233233
empty_state_dict = model.state_dict()
234-
expanded_device_map = {}
235-
236-
# if device_map is not None:
237-
# for param_name, param in state_dict.items():
238-
# if param_name not in empty_state_dict:
239-
# continue
240-
# param_device = _determine_param_device(param_name, device_map)
241-
# expanded_device_map[param_name] = param_device
242-
# print(expanded_device_map)
243-
# _caching_allocator_warmup(model, expanded_device_map, dtype)
244234

245235
for param_name, param in state_dict.items():
246236
if param_name not in empty_state_dict:
@@ -310,7 +300,15 @@ def load_model_dict_into_meta(
310300
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
311301
)
312302
else:
313-
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
303+
set_module_tensor_to_device(
304+
model,
305+
param_name,
306+
param_device,
307+
value=param,
308+
non_blocking=True,
309+
_empty_cache=False,
310+
**set_module_kwargs,
311+
)
314312

315313
return offload_index, state_dict_index
316314

@@ -533,6 +531,41 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
533531
return parsed_parameters
534532

535533

534+
def _find_mismatched_keys(
535+
state_dict,
536+
model_state_dict,
537+
loaded_keys,
538+
ignore_mismatched_sizes,
539+
):
540+
mismatched_keys = []
541+
if not ignore_mismatched_sizes:
542+
return mismatched_keys
543+
for checkpoint_key in loaded_keys:
544+
model_key = checkpoint_key
545+
# If the checkpoint is sharded, we may not have the key here.
546+
if checkpoint_key not in state_dict:
547+
continue
548+
549+
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
550+
mismatched_keys.append(
551+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
552+
)
553+
del state_dict[checkpoint_key]
554+
return mismatched_keys
555+
556+
557+
def _expand_device_map(device_map, param_names):
558+
"""
559+
Expand a device map to return the correspondence parameter name to device.
560+
"""
561+
new_device_map = {}
562+
for module, device in device_map.items():
563+
new_device_map.update(
564+
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
565+
)
566+
return new_device_map
567+
568+
536569
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
537570
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
538571
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each

src/diffusers/models/modeling_utils.py

Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,14 @@
6262
load_or_create_model_card,
6363
populate_model_card,
6464
)
65+
from ..utils.torch_utils import device_synchronize, empty_device_cache
6566
from .model_loading_utils import (
67+
_caching_allocator_warmup,
6668
_determine_device_map,
69+
_expand_device_map,
6770
_fetch_index_file,
6871
_fetch_index_file_legacy,
72+
_find_mismatched_keys,
6973
_load_state_dict_into_model,
7074
load_model_dict_into_meta,
7175
load_state_dict,
@@ -1469,11 +1473,6 @@ def _load_pretrained_model(
14691473
for pat in cls._keys_to_ignore_on_load_unexpected:
14701474
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
14711475

1472-
mismatched_keys = []
1473-
1474-
assign_to_params_buffers = None
1475-
error_msgs = []
1476-
14771476
# Deal with offload
14781477
if device_map is not None and "disk" in device_map.values():
14791478
if offload_folder is None:
@@ -1482,18 +1481,21 @@ def _load_pretrained_model(
14821481
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
14831482
" offers the weights in this format."
14841483
)
1485-
if offload_folder is not None:
1484+
else:
14861485
os.makedirs(offload_folder, exist_ok=True)
14871486
if offload_state_dict is None:
14881487
offload_state_dict = True
14891488

1489+
# Caching allocator warmup
1490+
if device_map is not None:
1491+
expanded_device_map = _expand_device_map(device_map, expected_keys)
1492+
_caching_allocator_warmup(model, expanded_device_map, dtype)
1493+
14901494
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
1495+
state_dict_folder, state_dict_index = None, None
14911496
if offload_state_dict:
14921497
state_dict_folder = tempfile.mkdtemp()
14931498
state_dict_index = {}
1494-
else:
1495-
state_dict_folder = None
1496-
state_dict_index = None
14971499

14981500
if state_dict is not None:
14991501
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1505,14 @@ def _load_pretrained_model(
15031505
if len(resolved_model_file) > 1:
15041506
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
15051507

1508+
mismatched_keys = []
1509+
assign_to_params_buffers = None
1510+
error_msgs = []
1511+
15061512
for shard_file in resolved_model_file:
15071513
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
1508-
1509-
def _find_mismatched_keys(
1510-
state_dict,
1511-
model_state_dict,
1512-
loaded_keys,
1513-
ignore_mismatched_sizes,
1514-
):
1515-
mismatched_keys = []
1516-
if ignore_mismatched_sizes:
1517-
for checkpoint_key in loaded_keys:
1518-
model_key = checkpoint_key
1519-
# If the checkpoint is sharded, we may not have the key here.
1520-
if checkpoint_key not in state_dict:
1521-
continue
1522-
1523-
if (
1524-
model_key in model_state_dict
1525-
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1526-
):
1527-
mismatched_keys.append(
1528-
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1529-
)
1530-
del state_dict[checkpoint_key]
1531-
return mismatched_keys
1532-
15331514
mismatched_keys += _find_mismatched_keys(
1534-
state_dict,
1535-
model_state_dict,
1536-
loaded_keys,
1537-
ignore_mismatched_sizes,
1515+
state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
15381516
)
15391517

15401518
if low_cpu_mem_usage:
@@ -1554,11 +1532,11 @@ def _find_mismatched_keys(
15541532
else:
15551533
if assign_to_params_buffers is None:
15561534
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
1557-
15581535
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
15591536

1560-
torch.cuda.synchronize()
1561-
1537+
empty_device_cache()
1538+
device_synchronize()
1539+
15621540
if offload_index is not None and len(offload_index) > 0:
15631541
save_offload_index(offload_index, offload_folder)
15641542
offload_index = None

src/diffusers/utils/torch_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,5 +182,14 @@ def get_device():
182182
def empty_device_cache(device_type: Optional[str] = None):
    """
    Call `empty_cache()` on the torch backend module for `device_type`.

    Args:
        device_type: Name of the torch backend attribute to use (e.g. `"cuda"`). When `None`, the
            value returned by `get_device()` is used. `"cpu"` is a no-op.
    """
    resolved_type = get_device() if device_type is None else device_type
    # Nothing to release for the CPU backend.
    if resolved_type == "cpu":
        return
    # Unknown backend names fall back to `torch.cuda`.
    getattr(torch, resolved_type, torch.cuda).empty_cache()
189+
190+
191+
def device_synchronize(device_type: Optional[str] = None):
    """
    Call `synchronize()` on the torch backend module for `device_type`.

    Args:
        device_type: Name of the torch backend attribute to use (e.g. `"cuda"`). When `None`, the
            value returned by `get_device()` is used. `"cpu"` is a no-op.
    """
    if device_type is None:
        device_type = get_device()
    # Mirror `empty_device_cache`: there is nothing to wait for on the CPU, and skipping the call
    # avoids depending on `torch.cpu.synchronize`, which may be missing on older torch releases.
    if device_type in ["cpu"]:
        return
    # Unknown backend names fall back to `torch.cuda`.
    device_mod = getattr(torch, device_type, torch.cuda)
    device_mod.synchronize()

0 commit comments

Comments
 (0)