 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import nullcontext
 from typing import Dict
 
 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..utils import is_accelerate_available, is_torch_version, logging
+
+
+logger = logging.get_logger(__name__)
 
 
 class SD3Transformer2DLoadersMixin:
     """Load IP-Adapters and LoRA layers into a [`SD3Transformer2DModel`]."""
 
-    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
-        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+    def _convert_ip_adapter_attn_to_diffusers(
+        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+    ) -> Dict:
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
 
-        Args:
-            state_dict (`Dict`):
-                State dict with keys "ip_adapter", which contains parameters for attention processors, and
-                "image_proj", which contains parameters for the image projection net.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
-                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
-                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
-                this argument to `True` will raise an error.
-        """
         # IP-Adapter cross attention parameters
         hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
         ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
-        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
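+        # in_features (shape[1]) of the first processor's AdaLN linear gives the timestep embedding width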
+        timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
 
         # Dict where key is transformer layer index, value is attention processor's state dict
         # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
         layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
-        for key, weights in state_dict["ip_adapter"].items():
+        for key, weights in state_dict.items():
             idx, name = key.split(".", maxsplit=1)
             layer_state_dict[int(idx)][name] = weights
 
-        # Create IP-Adapter attention processor
+        # Create IP-Adapter attention processor & load state_dict
         attn_procs = {}
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
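+        # init_empty_weights builds each processor on the meta device; load_model_dict_into_meta materializes it below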
         for idx, name in enumerate(self.attn_processors.keys()):
-            attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
-                hidden_size=hidden_size,
-                ip_hidden_states_dim=ip_hidden_states_dim,
-                head_dim=self.config.attention_head_dim,
-                timesteps_emb_dim=timesteps_emb_dim,
-            ).to(self.device, dtype=self.dtype)
+            with init_context():
+                attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+                    hidden_size=hidden_size,
+                    ip_hidden_states_dim=ip_hidden_states_dim,
+                    head_dim=self.config.attention_head_dim,
+                    timesteps_emb_dim=timesteps_emb_dim,
+                )
 
             if not low_cpu_mem_usage:
                 attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
@@ -63,27 +79,90 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
                     attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
                 )
 
-        self.set_attn_processor(attn_procs)
+        return attn_procs
+
+    def _convert_ip_adapter_image_proj_to_diffusers(
+        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+    ) -> IPAdapterTimeImageProjection:
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+        # Convert to diffusers
+        updated_state_dict = {}
+        for key, value in state_dict.items():
+            # InstantX/SD3.5-Large-IP-Adapter
+            if key.startswith("layers."):
+                idx = key.split(".")[1]
+                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+            updated_state_dict[key] = value
 
         # Image projection parameters
-        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
-        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
-        hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
-        heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
-        num_queries = state_dict["image_proj"]["latents"].shape[1]
-        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
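+        # Hyperparameters are inferred from weight shapes; `heads` assumes a fixed per-head dim of 64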
+        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+        output_dim = updated_state_dict["proj_out.weight"].shape[0]
+        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+        heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+        num_queries = updated_state_dict["latents"].shape[1]
+        timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
 
         # Image projection
-        self.image_proj = IPAdapterTimeImageProjection(
-            embed_dim=embed_dim,
-            output_dim=output_dim,
-            hidden_dim=hidden_dim,
-            heads=heads,
-            num_queries=num_queries,
-            timestep_in_dim=timestep_in_dim,
-        ).to(device=self.device, dtype=self.dtype)
+        with init_context():
+            image_proj = IPAdapterTimeImageProjection(
+                embed_dim=embed_dim,
+                output_dim=output_dim,
+                hidden_dim=hidden_dim,
+                heads=heads,
+                num_queries=num_queries,
+                timestep_in_dim=timestep_in_dim,
+            )
 
         if not low_cpu_mem_usage:
-            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+            image_proj.load_state_dict(updated_state_dict, strict=True)
         else:
-            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+            load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)
+
+        return image_proj
+
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+        Args:
+            state_dict (`Dict`):
+                State dict with keys "ip_adapter", which contains parameters for attention processors, and
+                "image_proj", which contains parameters for the image projection net.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
+        """
+
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
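
For reference, a minimal usage sketch of the relocated entry point. The checkpoint filename below is illustrative, not part of this change; the loader only requires a state dict exposing the two top-level keys named in the docstring ("ip_adapter" and "image_proj"):

    import torch
    from diffusers import SD3Transformer2DModel

    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # hypothetical local file holding {"ip_adapter": {...}, "image_proj": {...}}
    state_dict = torch.load("ip-adapter.bin", map_location="cpu")
    transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)

Splitting the monolithic loader into `_convert_ip_adapter_attn_to_diffusers` and `_convert_ip_adapter_image_proj_to_diffusers` mirrors the UNet IP-Adapter loader, which exposes the same pair of conversion helpers.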