Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
Expand Down Expand Up @@ -431,10 +431,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

Expand Down
8 changes: 1 addition & 7 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache


if is_transformers_available():
Expand Down Expand Up @@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)

Expand Down Expand Up @@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)

Expand Down
4 changes: 1 addition & 3 deletions src/diffusers/loaders/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache


if is_accelerate_available():
Expand Down Expand Up @@ -82,7 +82,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_projection

Expand Down Expand Up @@ -158,7 +157,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
key_id += 1

empty_device_cache()
device_synchronize()

return attn_procs

Expand Down
4 changes: 1 addition & 3 deletions src/diffusers/loaders/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -82,7 +82,6 @@ def _convert_ip_adapter_attn_to_diffusers(
)

empty_device_cache()
device_synchronize()

return attn_procs

Expand Down Expand Up @@ -152,7 +151,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_proj

Expand Down
4 changes: 1 addition & 3 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
is_torch_version,
logging,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
Expand Down Expand Up @@ -755,7 +755,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
empty_device_cache()
device_synchronize()

return image_projection

Expand Down Expand Up @@ -854,7 +853,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
key_id += 2

empty_device_cache()
device_synchronize()

return attn_procs

Expand Down
5 changes: 1 addition & 4 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import device_synchronize, empty_device_cache
from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
Expand Down Expand Up @@ -1540,10 +1540,7 @@ def _load_pretrained_model(
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)

# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
empty_device_cache()
device_synchronize()

if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
Expand Down