diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b9b86cf480a2..76fefc1260d0 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -24,7 +24,7 @@
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -431,10 +431,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 5fafcb02be6f..a804ea80a9ad 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -46,7 +46,7 @@
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 
 
 if is_transformers_available():
@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
 
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
 
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)
 
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index af03d09029c1..0de809594864 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -19,7 +19,7 @@
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 
 
 if is_accelerate_available():
@@ -82,7 +82,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()
 
         return image_projection
 
@@ -158,7 +157,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
             key_id += 1
 
         empty_device_cache()
-        device_synchronize()
 
         return attn_procs
 
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index 4421f46dfc02..1bc3a9c7a851 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -18,7 +18,7 @@
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 
 
 logger = logging.get_logger(__name__)
@@ -82,7 +82,6 @@ def _convert_ip_adapter_attn_to_diffusers(
             )
 
         empty_device_cache()
-        device_synchronize()
 
         return attn_procs
 
@@ -152,7 +151,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(
         device_map = {"": self.device}
         load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()
 
         return image_proj
 
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 89c6449ff5d6..1d698e5a8b53 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -43,7 +43,7 @@
     is_torch_version,
     logging,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
@@ -755,7 +755,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
             device_map = {"": self.device}
             load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
             empty_device_cache()
-            device_synchronize()
 
         return image_projection
 
@@ -854,7 +853,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
             key_id += 2
 
         empty_device_cache()
-        device_synchronize()
 
         return attn_procs
 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index d7b2136b4afc..ec7629bddcb6 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -62,7 +62,7 @@
     load_or_create_model_card,
     populate_model_card,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
     _caching_allocator_warmup,
     _determine_device_map,
@@ -1540,10 +1540,7 @@ def _load_pretrained_model(
                 assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
             error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
 
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
 
        if offload_index is not None and len(offload_index) > 0:
            save_offload_index(offload_index, offload_folder)