
Commit 7927554

refactor wan lora loader (#691)
Co-authored-by: gushiqiao <975033167>
1 parent eec8238 commit 7927554

3 files changed: 11 additions and 86 deletions


examples/qwen_image/qwen_2511_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 # Suitable for RTX 30/40/50 consumer GPUs
 # pipe.enable_offload(
 #     cpu_offload=True,
-#     offload_granularity="block",
+#     offload_granularity="block", #["block", "phase"]
 #     text_encoder_offload=True,
 #     vae_offload=False,
 # )

examples/qwen_image/qwen_2511_with_distill_lora.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 # Suitable for RTX 30/40/50 consumer GPUs
 # pipe.enable_offload(
 #     cpu_offload=True,
-#     offload_granularity="block",
+#     offload_granularity="block", #["block", "phase"]
 #     text_encoder_offload=True,
 #     vae_offload=False,
 # )
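
Both example scripts get the same one-line change: the comment on the offload_granularity option now lists its accepted values. A minimal sketch of the uncommented call, assuming pipe is the pipeline object already constructed in these example scripts and using only the parameters visible in the diff (the "phase" value is chosen here purely for illustration):

pipe.enable_offload(
    cpu_offload=True,
    offload_granularity="phase",  # per the updated comment, accepted values are "block" and "phase"
    text_encoder_offload=True,
    vae_offload=False,
)
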
Lines changed: 9 additions & 84 deletions
@@ -1,17 +1,17 @@
-import gc
 import os
 
-import torch
 from loguru import logger
 from safetensors import safe_open
 
 from lightx2v.utils.envs import *
+from lightx2v.utils.lora_loader import LoRALoader
 
 
 class WanLoraWrapper:
     def __init__(self, wan_model):
         self.model = wan_model
         self.lora_metadata = {}
+        self.lora_loader = LoRALoader()
         self.override_dict = {}  # On CPU
 
     def load_lora(self, lora_path, lora_name=None):
@@ -41,91 +41,16 @@ def apply_lora(self, lora_name, alpha=1.0):
             return False
 
         lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
+
         weight_dict = self.model.original_weight_dict
-        self._apply_lora_weights(weight_dict, lora_weights, alpha)
+        self.lora_loader.apply_lora(
+            weight_dict=weight_dict,
+            lora_weights=lora_weights,
+            alpha=alpha,
+            strength=alpha,
+        )
         self.model._apply_weights(weight_dict)
 
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         del lora_weights
         return True
-
-    @torch.no_grad()
-    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
-        lora_pairs = {}
-        lora_diffs = {}
-
-        def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
-            if key.endswith(suffix_a):
-                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
-                pair_key = key.replace(suffix_a, suffix_b)
-                if pair_key in lora_weights:
-                    lora_pairs[base_name] = (key, pair_key)
-
-        def try_lora_diff(key, prefix, suffix, target_suffix):
-            if key.endswith(suffix):
-                base_name = key[len(prefix) :].replace(suffix, target_suffix)
-                lora_diffs[base_name] = key
-
-        prefixs = [
-            "",  # empty prefix
-            "diffusion_model.",
-        ]
-        for prefix in prefixs:
-            for key in lora_weights.keys():
-                if not key.startswith(prefix):
-                    continue
-
-                try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
-                try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
-                try_lora_diff(key, prefix, "diff", "weight")
-                try_lora_diff(key, prefix, "diff_b", "bias")
-                try_lora_diff(key, prefix, "diff_m", "modulation")
-
-        applied_count = 0
-        for name, param in weight_dict.items():
-            if name in lora_pairs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-                name_lora_A, name_lora_B = lora_pairs[name]
-                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
-                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
-                    param += torch.matmul(lora_B, lora_A) * alpha
-                    applied_count += 1
-            elif name in lora_diffs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-
-                name_diff = lora_diffs[name]
-                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                if param.shape == lora_diff.shape:
-                    param += lora_diff * alpha
-                    applied_count += 1
-
-        logger.info(f"Applied {applied_count} LoRA weight adjustments")
-        if applied_count == 0:
-            logger.info(
-                "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
-            )
-
-    @torch.no_grad()
-    def remove_lora(self):
-        logger.info(f"Removing LoRA ...")
-
-        restored_count = 0
-        for k, v in self.override_dict.items():
-            self.model.original_weight_dict[k] = v.to(self.model.device)
-            restored_count += 1
-
-        logger.info(f"LoRA removed, restored {restored_count} weights")
-
-        self.model._apply_weights(self.model.original_weight_dict)
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        self.lora_metadata = {}
-        self.override_dict = {}
-
-    def list_loaded_loras(self):
-        return list(self.lora_metadata.keys())
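
The refactor leaves the wrapper's public flow intact: load_lora registers a LoRA file and apply_lora merges it into the model weights, but the merge itself (the standard LoRA update W <- W + alpha * (B @ A) for lora_A/lora_B pairs, or W <- W + alpha * diff for diff tensors, as in the removed _apply_lora_weights) is now delegated to the shared LoRALoader.apply_lora. A hypothetical usage sketch; wan_model and the LoRA path are placeholders, and only methods visible in this diff are called:

wrapper = WanLoraWrapper(wan_model)  # wan_model: an already-initialized Wan model
wrapper.load_lora("/path/to/wan_lora.safetensors", lora_name="my_lora")
wrapper.apply_lora("my_lora", alpha=1.0)  # weight merge now handled by LoRALoader.apply_lora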
