@@ -19,19 +19,19 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders.transformers_sd3 import SD3Transformer2DLoadersMixin
 from ...models.attention import FeedForward, JointTransformerBlock
 from ...models.attention_processor import (
     Attention,
     AttentionProcessor,
     FusedJointAttnProcessor2_0,
-    IPAdapterJointAttnProcessor2_0,
     JointAttnProcessor2_0,
 )
-from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta
+from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepTextProjEmbeddings, IPAdapterTimeImageProjection, PatchEmbed
+from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
 
@@ -104,7 +104,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): |
         return hidden_states
 
 
-class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class SD3Transformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Stable Diffusion 3.
 
@@ -331,89 +333,6 @@ def _set_gradient_checkpointing(self, module, value=False): |
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool) -> None:
-        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
-
-        Args:
-            state_dict (`Dict`):
-                PyTorch state dict with keys "ip_adapter", which contains parameters for attention processors, and
-                "image_proj", which contains parameters for image projection net.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
-        """
-        # IP-Adapter cross attention parameters
-        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
-        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
-        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
-
-        # Dict where key is transformer layer index, value is attention processor's state dict
-        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
-        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
-        for key, weights in state_dict["ip_adapter"].items():
-            idx, name = key.split(".", maxsplit=1)
-            layer_state_dict[int(idx)][name] = weights
-
-        # Create IP-Adapter attention processor
-        attn_procs = {}
-        for idx, name in enumerate(self.attn_processors.keys()):
-            attn_procs[name] = IPAdapterJointAttnProcessor2_0(
-                hidden_size=hidden_size,
-                ip_hidden_states_dim=ip_hidden_states_dim,
-                head_dim=self.config.attention_head_dim,
-                timesteps_emb_dim=timesteps_emb_dim,
-            ).to(self.device, dtype=self.dtype)
-
-            if not low_cpu_mem_usage:
-                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
-            else:
-                load_model_dict_into_meta(
-                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
-                )
-
-        self.set_attn_processor(attn_procs)
-
-        # Convert image_proj state dict to diffusers
-        image_proj_state_dict = {}
-        for key, value in state_dict["image_proj"].items():
-            if key.startswith("layers."):
-                idx = key.split(".")[1]
-                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
-                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
-                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
-                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
-                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
-                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
-                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
-                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
-                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
-            image_proj_state_dict[key] = value
-
-        # Image projetion parameters
-        embed_dim = image_proj_state_dict["proj_in.weight"].shape[1]
-        output_dim = image_proj_state_dict["proj_out.weight"].shape[0]
-        hidden_dim = image_proj_state_dict["proj_in.weight"].shape[0]
-        heads = image_proj_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
-        num_queries = image_proj_state_dict["latents"].shape[1]
-        timestep_in_dim = image_proj_state_dict["time_embedding.linear_1.weight"].shape[1]
-
-        # Image projection
-        self.image_proj = IPAdapterTimeImageProjection(
-            embed_dim=embed_dim,
-            output_dim=output_dim,
-            hidden_dim=hidden_dim,
-            heads=heads,
-            num_queries=num_queries,
-            timestep_in_dim=timestep_in_dim,
-        ).to(device=self.device, dtype=self.dtype)
-
-        if not low_cpu_mem_usage:
-            self.image_proj.load_state_dict(image_proj_state_dict, strict=True)
-        else:
-            load_model_dict_into_meta(self.image_proj, image_proj_state_dict, device=self.device, dtype=self.dtype)
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
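Note: this commit only relocates the IP-Adapter loading logic. Judging from the new import and the added base class, the removed `_load_ip_adapter_weights` is now provided by `SD3Transformer2DLoadersMixin`, so the transformer keeps the same call surface. A minimal usage sketch under that assumption; the checkpoint id and local filename are illustrative, not taken from this commit:

```python
import torch
from diffusers import SD3Transformer2DModel

# Load the SD3 transformer; after this refactor the class inherits
# SD3Transformer2DLoadersMixin, which is assumed to carry the IP-Adapter loader.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    torch_dtype=torch.float16,
)

# The state dict must expose "ip_adapter" and "image_proj" sub-dicts,
# as described in the docstring of the removed method above.
state_dict = torch.load("ip_adapter.bin", weights_only=True)  # hypothetical local file
transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)
```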