Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/diffusers/loaders/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
Expand All @@ -36,7 +36,7 @@ class FluxTransformer2DLoadersMixin:
Load layers into a [`FluxTransformer2DModel`].
"""

def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
Expand Down Expand Up @@ -82,11 +82,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

return image_projection

def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
Expand Down Expand Up @@ -151,15 +152,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = self.device
device_map = {"": self.device}
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

key_id += 1

return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]

Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/loaders/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def _convert_ip_adapter_attn_to_diffusers(
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)

return attn_procs
Expand Down Expand Up @@ -144,7 +145,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
if not low_cpu_mem_usage:
image_proj.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)

return image_proj

Expand Down
16 changes: 9 additions & 7 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -143,7 +143,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
allow_pickle = False

if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
Expand Down Expand Up @@ -540,7 +540,7 @@ def _get_custom_diffusion_state_dict(self):

return state_dict

def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
Expand Down Expand Up @@ -753,11 +753,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

return image_projection

def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Expand Down Expand Up @@ -846,13 +847,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
device_map = {"": device}
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

key_id += 2

return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def is_peft_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _peft_version:
if not _peft_available:
return False
return compare_versions(parse(_peft_version), operation, version)

Expand All @@ -829,7 +829,7 @@ def is_bitsandbytes_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _bitsandbytes_version:
if not _bitsandbytes_available:
return False
return compare_versions(parse(_bitsandbytes_version), operation, version)

Expand Down
Loading