|
16 | 16 | from typing import Callable, Dict, List, Optional, Union |
17 | 17 |
|
18 | 18 | import torch |
| 19 | +import re |
19 | 20 | from huggingface_hub.utils import validate_hf_hub_args |
20 | 21 |
|
21 | 22 | from ..utils import ( |
@@ -4805,50 +4806,152 @@ def lora_state_dict( |
4805 | 4806 | return state_dict |
4806 | 4807 |
|
4807 | 4808 | @classmethod |
4808 | | - def _maybe_expand_t2v_lora_for_i2v( |
4809 | | - cls, |
4810 | | - transformer: torch.nn.Module, |
4811 | | - state_dict, |
4812 | | - ): |
4813 | | - if transformer.config.image_dim is None: |
4814 | | - return state_dict |
| 4809 | + def _modified_maybe_expand_t2v_lora(  # renamed from _maybe_expand_t2v_lora_for_i2v for clarity
| 4810 | + cls,  # kept: the @classmethod decorator above is retained
| 4811 | + transformer: torch.nn.Module,
| 4812 | + state_dict: Dict[str, torch.Tensor],
| 4813 | + lora_filename_for_rank_inference: Optional[str] = None,  # optional filename hint for rank inference
| 4814 | + ) -> Dict[str, torch.Tensor]:
4815 | 4815 |
|
4816 | 4816 | target_device = transformer.device |
| 4817 | + # Default dtype from transformer, can be refined if LoRA weights have a different one |
| 4818 | + lora_weights_dtype = next(iter(transformer.parameters())).dtype |
| 4819 | + |
| 4820 | + # --- Infer LoRA rank and potentially refine dtype from existing LoRA weights --- |
| 4821 | + inferred_rank = None |
| 4822 | + if state_dict: # If LoRA state_dict already has entries from the T2V LoRA |
| 4823 | + for k, v_tensor in state_dict.items(): |
| 4824 | + if k.endswith(".lora_A.weight"): # Standard LoRA weight key part |
| 4825 | + inferred_rank = v_tensor.shape[0] # rank is the output dim of lora_A |
| 4826 | + lora_weights_dtype = v_tensor.dtype # Use dtype of existing LoRA weights |
| 4827 | + break # Found rank and dtype |
| 4828 | + |
| 4829 | + if inferred_rank is None and lora_filename_for_rank_inference: |
| 4830 | + match = re.search(r"rank(\d+)", lora_filename_for_rank_inference, re.IGNORECASE) |
| 4831 | + if match: |
| 4832 | + inferred_rank = int(match.group(1)) |
| 4833 | + logger.info(f"Inferred LoRA rank {inferred_rank} from filename for padding.")
| 4834 | + |
| 4835 | + # Determine if the original LoRA format (the T2V part) uses biases for lora_B |
| 4836 | + lora_format_has_bias = any(".lora_B.bias" in k for k in state_dict.keys()) |
| 4837 | + |
| 4838 | + # --- Part 1: Original I2V expansion for standard transformer.blocks --- |
| 4839 | + # (Assuming transformer.config and transformer.blocks structure for this part) |
| 4840 | + if hasattr(transformer, 'config') and hasattr(transformer.config, 'image_dim') and \ |
| 4841 | + transformer.config.image_dim is not None and hasattr(transformer, 'blocks'): |
| 4842 | + |
| 4843 | + standard_block_keys_present = any(k.startswith("transformer.blocks.") for k in state_dict) |
| 4844 | + |
| 4845 | + if standard_block_keys_present and inferred_rank is not None: |
| 4846 | + num_blocks_in_lora = 0 |
| 4847 | + block_indices = set() |
| 4848 | + for k_lora in state_dict: |
| 4849 | + if "transformer.blocks." in k_lora: |
| 4850 | + try: |
| 4851 | + block_idx_str = k_lora.split("transformer.blocks.")[1].split(".")[0] |
| 4852 | + if block_idx_str.isdigit(): |
| 4853 | + block_indices.add(int(block_idx_str)) |
| 4854 | + except IndexError: |
| 4855 | + pass |
| 4856 | + if block_indices: |
| 4857 | + num_blocks_in_lora = max(block_indices) + 1 |
| 4858 | + |
| 4859 | + is_i2v_lora_standard_blocks = any( |
| 4860 | + k.startswith("transformer.blocks.") and "add_k_proj" in k for k in state_dict |
| 4861 | + ) and any( |
| 4862 | + k.startswith("transformer.blocks.") and "add_v_proj" in k for k in state_dict |
| 4863 | + ) |
4817 | 4864 |
|
4818 | | - if any(k.startswith("transformer.blocks.") for k in state_dict): |
4819 | | - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) |
4820 | | - is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) |
4821 | | - has_bias = any(".lora_B.bias" in k for k in state_dict) |
4822 | | - |
4823 | | - if is_i2v_lora: |
4824 | | - return state_dict |
4825 | | - |
4826 | | - for i in range(num_blocks): |
4827 | | - for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): |
4828 | | - # These keys should exist if the block `i` was part of the T2V LoRA. |
4829 | | - ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" |
4830 | | - ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" |
4831 | | - |
4832 | | - if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: |
4833 | | - continue |
4834 | | - |
4835 | | - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( |
4836 | | - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device |
4837 | | - ) |
4838 | | - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( |
4839 | | - state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device |
4840 | | - ) |
4841 | | - |
4842 | | - # If the original LoRA had biases (indicated by has_bias) |
4843 | | - # AND the specific reference bias key exists for this block. |
4844 | | - |
4845 | | - ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" |
4846 | | - if has_bias and ref_key_lora_B_bias in state_dict: |
4847 | | - ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] |
4848 | | - state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( |
4849 | | - ref_lora_B_bias_tensor, |
4850 | | - device=target_device, |
4851 | | - ) |
| 4865 | + if not is_i2v_lora_standard_blocks and num_blocks_in_lora > 0: |
| 4866 | + logger.info(f"Expanding T2V LoRA for I2V compatibility (standard blocks). Rank: {inferred_rank}")
| 4867 | + for i in range(num_blocks_in_lora): |
| 4868 | + # Check if block 'i' relevant parts are in the T2V LoRA |
| 4869 | + ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" |
| 4870 | + if ref_key_lora_A not in state_dict: |
| 4871 | + continue # This block's specific part wasn't in the LoRA. |
| 4872 | + |
| 4873 | + try: |
| 4874 | + model_block = transformer.blocks[i] |
| 4875 | + # Ensure these target layers exist in the model's standard block |
| 4876 | + if not (hasattr(model_block, 'attn2') and \ |
| 4877 | + hasattr(model_block.attn2, 'add_k_proj') and \ |
| 4878 | + hasattr(model_block.attn2, 'add_v_proj')): |
| 4879 | + continue |
| 4880 | + add_k_proj_layer = model_block.attn2.add_k_proj |
| 4881 | + add_v_proj_layer = model_block.attn2.add_v_proj |
| 4882 | + except (AttributeError, IndexError): |
| 4883 | + logger.warning(f"Cannot access standard block {i} or its I2V layers for expansion.")
| 4884 | + continue |
| 4885 | + |
| 4886 | + for proj_name_suffix, model_linear_layer in [("add_k_proj", add_k_proj_layer), |
| 4887 | + ("add_v_proj", add_v_proj_layer)]: |
| 4888 | + if not isinstance(model_linear_layer, torch.nn.Linear): continue
| 4889 | + |
| 4890 | + lora_A_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_A.weight" |
| 4891 | + lora_B_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.weight" |
| 4892 | + |
| 4893 | + if lora_A_key not in state_dict: |
| 4894 | + state_dict[lora_A_key] = torch.zeros( |
| 4895 | + (inferred_rank, model_linear_layer.in_features), |
| 4896 | + device=target_device, dtype=lora_weights_dtype |
| 4897 | + ) |
| 4898 | + if lora_B_key not in state_dict: |
| 4899 | + state_dict[lora_B_key] = torch.zeros( |
| 4900 | + (model_linear_layer.out_features, inferred_rank), |
| 4901 | + device=target_device, dtype=lora_weights_dtype |
| 4902 | + ) |
| 4903 | + |
| 4904 | + if lora_format_has_bias and model_linear_layer.bias is not None: |
| 4905 | + lora_B_bias_key = f"transformer.blocks.{i}.attn2.{proj_name_suffix}.lora_B.bias" |
| 4906 | + if lora_B_bias_key not in state_dict: |
| 4907 | + state_dict[lora_B_bias_key] = torch.zeros_like( |
| 4908 | + model_linear_layer.bias, device=target_device, |
| 4909 | + dtype=model_linear_layer.bias.dtype |
| 4910 | + ) |
| 4911 | + elif inferred_rank is None: |
| 4912 | + logger.info("LoRA rank not inferred; skipping I2V expansion for standard blocks.")
| 4913 | + # else: not standard_block_keys_present or no I2V capability. |
| 4914 | + |
| 4915 | + # --- Part 2: Pad LoRA for WanVACETransformer3DModel vace_blocks.X.proj_out --- |
| 4916 | + # Dynamically check for WanVACETransformer3DModel availability for isinstance |
| 4917 | + VACEModelClass = globals().get("WanVACETransformer3DModel") |
| 4918 | + |
| 4919 | + if VACEModelClass and isinstance(transformer, VACEModelClass) and hasattr(transformer, 'vace_blocks'): |
| 4920 | + if inferred_rank is None: |
| 4921 | + logger.warning("LoRA rank not determined; skipping VACE block padding for proj_out.")
| 4922 | + else:
| 4923 | + logger.info(f"Transformer is WanVACE; padding LoRA for vace_blocks.X.proj_out. Rank: {inferred_rank}")
| 4924 | + for i, vace_block_module in enumerate(transformer.vace_blocks): |
| 4925 | + if hasattr(vace_block_module, 'proj_out') and isinstance(vace_block_module.proj_out, torch.nn.Linear):
| 4926 | + proj_out_layer = vace_block_module.proj_out |
| 4927 | + |
| 4928 | + # Keys for the vace_block's proj_out LoRA layers |
| 4929 | + # These are the keys PEFT expects in the state_dict *before* adding adapter name context |
| 4930 | + lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight" |
| 4931 | + lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight" |
| 4932 | + |
| 4933 | + if lora_A_key not in state_dict: |
| 4934 | + state_dict[lora_A_key] = torch.zeros( |
| 4935 | + (inferred_rank, proj_out_layer.in_features), |
| 4936 | + device=target_device, dtype=lora_weights_dtype |
| 4937 | + ) |
| 4938 | + # print(f"Padded: {lora_A_key}") |
| 4939 | + |
| 4940 | + if lora_B_key not in state_dict: |
| 4941 | + state_dict[lora_B_key] = torch.zeros( |
| 4942 | + (proj_out_layer.out_features, inferred_rank), |
| 4943 | + device=target_device, dtype=lora_weights_dtype |
| 4944 | + ) |
| 4945 | + # print(f"Padded: {lora_B_key}") |
| 4946 | + |
| 4947 | + if lora_format_has_bias and proj_out_layer.bias is not None: |
| 4948 | + lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias" |
| 4949 | + if lora_B_bias_key not in state_dict: |
| 4950 | + state_dict[lora_B_bias_key] = torch.zeros_like( |
| 4951 | + proj_out_layer.bias, device=target_device, dtype=proj_out_layer.bias.dtype |
| 4952 | + ) |
| 4953 | + # print(f"Padded: {lora_B_bias_key}") |
| 4954 | + # else: VACE block 'i' might not have proj_out or it's not Linear. |
4852 | 4955 |
|
4853 | 4956 | return state_dict |
4854 | 4957 |
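
For context, here is a minimal sketch of how the padded helper might be wired in when a T2V-rank LoRA is loaded into a Wan pipeline, assuming the function stays a `@classmethod` on the Wan LoRA loader mixin as in the diff. The checkpoint filename, repository id, and adapter name below are hypothetical, and the already-expanded state dict is handed straight to `load_lora_weights`.

```python
import torch
from diffusers import WanPipeline  # any Wan/WanVACE pipeline that exposes the modified mixin

lora_file = "wan_t2v_rank32_lora.safetensors"  # hypothetical T2V LoRA checkpoint
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16)

# Load the raw LoRA weights, then zero-pad the entries the T2V checkpoint lacks
# (attn2.add_k_proj / add_v_proj and, for WanVACE, vace_blocks.X.proj_out).
# The filename is passed so the rank can still be recovered from "rank32" if no
# lora_A weight is available to read it from.
state_dict = WanPipeline.lora_state_dict(lora_file)
state_dict = WanPipeline._modified_maybe_expand_t2v_lora(
    transformer=pipe.transformer,
    state_dict=state_dict,
    lora_filename_for_rank_inference=lora_file,
)
pipe.load_lora_weights(state_dict, adapter_name="t2v_lora")
```

Because the padded entries are zeros, the expanded adapter behaves identically to the original T2V LoRA; the extra keys only exist so that PEFT finds a weight for every targeted module.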
|
|