|
@@ -1,17 +1,17 @@
-import gc
 import os
 
-import torch
 from loguru import logger
 from safetensors import safe_open
 
 from lightx2v.utils.envs import *
+from lightx2v.utils.lora_loader import LoRALoader
 
 
 class WanLoraWrapper:
     def __init__(self, wan_model):
         self.model = wan_model
         self.lora_metadata = {}
+        self.lora_loader = LoRALoader()
         self.override_dict = {}  # On CPU
 
     def load_lora(self, lora_path, lora_name=None):
@@ -41,91 +41,16 @@ def apply_lora(self, lora_name, alpha=1.0): |
             return False
 
         lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
+
         weight_dict = self.model.original_weight_dict
-        self._apply_lora_weights(weight_dict, lora_weights, alpha)
+        self.lora_loader.apply_lora(
+            weight_dict=weight_dict,
+            lora_weights=lora_weights,
+            alpha=alpha,
+            strength=alpha,
+        )
         self.model._apply_weights(weight_dict)
 
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         del lora_weights
         return True
-
-    @torch.no_grad()
-    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
-        lora_pairs = {}
-        lora_diffs = {}
-
-        def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
-            if key.endswith(suffix_a):
-                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
-                pair_key = key.replace(suffix_a, suffix_b)
-                if pair_key in lora_weights:
-                    lora_pairs[base_name] = (key, pair_key)
-
-        def try_lora_diff(key, prefix, suffix, target_suffix):
-            if key.endswith(suffix):
-                base_name = key[len(prefix) :].replace(suffix, target_suffix)
-                lora_diffs[base_name] = key
-
-        prefixs = [
-            "",  # empty prefix
-            "diffusion_model.",
-        ]
-        for prefix in prefixs:
-            for key in lora_weights.keys():
-                if not key.startswith(prefix):
-                    continue
-
-                try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
-                try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
-                try_lora_diff(key, prefix, "diff", "weight")
-                try_lora_diff(key, prefix, "diff_b", "bias")
-                try_lora_diff(key, prefix, "diff_m", "modulation")
-
-        applied_count = 0
-        for name, param in weight_dict.items():
-            if name in lora_pairs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-                name_lora_A, name_lora_B = lora_pairs[name]
-                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
-                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
-                    param += torch.matmul(lora_B, lora_A) * alpha
-                    applied_count += 1
-            elif name in lora_diffs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-
-                name_diff = lora_diffs[name]
-                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                if param.shape == lora_diff.shape:
-                    param += lora_diff * alpha
-                    applied_count += 1
-
-        logger.info(f"Applied {applied_count} LoRA weight adjustments")
-        if applied_count == 0:
-            logger.info(
-                "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
-            )
-
-    @torch.no_grad()
-    def remove_lora(self):
-        logger.info(f"Removing LoRA ...")
-
-        restored_count = 0
-        for k, v in self.override_dict.items():
-            self.model.original_weight_dict[k] = v.to(self.model.device)
-            restored_count += 1
-
-        logger.info(f"LoRA removed, restored {restored_count} weights")
-
-        self.model._apply_weights(self.model.original_weight_dict)
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        self.lora_metadata = {}
-        self.override_dict = {}
-
-    def list_loaded_loras(self):
-        return list(self.lora_metadata.keys())
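For reference, the merge logic this commit delegates to LoRALoader used to live in the deleted _apply_lora_weights: for each lora_A/lora_B (or lora_down/lora_up) pair it added lora_B @ lora_A scaled by alpha onto the matching base weight, and for diff/diff_b/diff_m tensors it added the difference tensor directly. The internals of lightx2v.utils.lora_loader.LoRALoader are not shown in this diff, so the snippet below is only a minimal sketch of that core merge step, lifted from the removed code; merge_lora_pair is a hypothetical helper name for illustration, not an actual LoRALoader API.

import torch

@torch.no_grad()
def merge_lora_pair(param, lora_A, lora_B, alpha=1.0):
    # Standard low-rank merge, in place on the base weight: W <- W + (B @ A) * alpha.
    # Expected shapes: param (out, in), lora_B (out, r), lora_A (r, in).
    lora_A = lora_A.to(param.device, param.dtype)
    lora_B = lora_B.to(param.device, param.dtype)
    if param.shape == (lora_B.shape[0], lora_A.shape[1]):
        param += torch.matmul(lora_B, lora_A) * alpha
        return True
    return False  # shape mismatch: skip this weight, as the removed code did

Note that the new call site passes both alpha=alpha and strength=alpha; how LoRALoader.apply_lora combines those two scaling arguments is not visible in this diff.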