2 changes: 1 addition & 1 deletion examples/qwen_image/qwen_2511_fp8.py
@@ -22,7 +22,7 @@
# Suitable for RTX 30/40/50 consumer GPUs
# pipe.enable_offload(
# cpu_offload=True,
-# offload_granularity="block",
+# offload_granularity="block", #["block", "phase"]
# text_encoder_offload=True,
# vae_offload=False,
# )
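
For context, a minimal sketch of what this call might look like once uncommented, using the other granularity value; the parameter names come from the commented example above, while the exact behaviour of "phase" versus "block" is an assumption based on the added comment.

pipe.enable_offload(
    cpu_offload=True,
    offload_granularity="phase",  # assumed: finer-grained offloading than "block"
    text_encoder_offload=True,
    vae_offload=False,
)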
2 changes: 1 addition & 1 deletion examples/qwen_image/qwen_2511_with_distill_lora.py
@@ -22,7 +22,7 @@
# Suitable for RTX 30/40/50 consumer GPUs
# pipe.enable_offload(
# cpu_offload=True,
-# offload_granularity="block",
+# offload_granularity="block", #["block", "phase"]
# text_encoder_offload=True,
# vae_offload=False,
# )
93 changes: 9 additions & 84 deletions lightx2v/models/networks/wan/lora_adapter.py
@@ -1,17 +1,17 @@
import gc
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.lora_loader import LoRALoader


class WanLoraWrapper:
    def __init__(self, wan_model):
        self.model = wan_model
        self.lora_metadata = {}
        self.lora_loader = LoRALoader()
        self.override_dict = {}  # On CPU

    def load_lora(self, lora_path, lora_name=None):
@@ -41,91 +41,16 @@ def apply_lora(self, lora_name, alpha=1.0):
            return False

        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])

        weight_dict = self.model.original_weight_dict
        self._apply_lora_weights(weight_dict, lora_weights, alpha)
        self.lora_loader.apply_lora(
            weight_dict=weight_dict,
            lora_weights=lora_weights,
            alpha=alpha,
            strength=alpha,
        )
Comment on lines +46 to +51
Contributor comment (severity: high):

The call to self.lora_loader.apply_lora passes the alpha value to both the alpha and strength parameters. Based on the implementation of LoRALoader, this will result in the LoRA weights being scaled by alpha twice (e.g., proportional to alpha^2 for diff-based LoRAs), which is likely not the intended behavior and differs from the original implementation.

The strength parameter seems to be intended for an additional, separate strength multiplier, and it defaults to 1.0. To match the expected behavior of applying the LoRA strength once, you should probably omit the strength parameter in this call.

        self.lora_loader.apply_lora(
            weight_dict=weight_dict,
            lora_weights=lora_weights,
            alpha=alpha,
        )
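
To make the double-scaling concern concrete, here is a minimal, self-contained sketch; the `delta * alpha * strength` formula is an assumption about LoRALoader's internals, not code from this PR.

import torch

# Assumed scaling applied to a diff-based LoRA entry inside LoRALoader:
def scaled_delta(lora_diff, alpha, strength=1.0):
    return lora_diff * alpha * strength

diff = torch.ones(2, 2)
print(scaled_delta(diff, alpha=0.5, strength=0.5))  # 0.25 * diff: alpha applied twice
print(scaled_delta(diff, alpha=0.5))                # 0.50 * diff: alpha applied once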

        self.model._apply_weights(weight_dict)

        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        del lora_weights
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        lora_pairs = {}
        lora_diffs = {}

        def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
            if key.endswith(suffix_a):
                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
                pair_key = key.replace(suffix_a, suffix_b)
                if pair_key in lora_weights:
                    lora_pairs[base_name] = (key, pair_key)

        def try_lora_diff(key, prefix, suffix, target_suffix):
            if key.endswith(suffix):
                base_name = key[len(prefix) :].replace(suffix, target_suffix)
                lora_diffs[base_name] = key

        prefixs = [
            "",  # empty prefix
            "diffusion_model.",
        ]
        for prefix in prefixs:
            for key in lora_weights.keys():
                if not key.startswith(prefix):
                    continue

                try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
                try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
                try_lora_diff(key, prefix, "diff", "weight")
                try_lora_diff(key, prefix, "diff_b", "bias")
                try_lora_diff(key, prefix, "diff_m", "modulation")

        applied_count = 0
        for name, param in weight_dict.items():
            if name in lora_pairs:
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()
                name_lora_A, name_lora_B = lora_pairs[name]
                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
                    param += torch.matmul(lora_B, lora_A) * alpha
                    applied_count += 1
            elif name in lora_diffs:
                if name not in self.override_dict:
                    self.override_dict[name] = param.clone().cpu()

                name_diff = lora_diffs[name]
                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
                if param.shape == lora_diff.shape:
                    param += lora_diff * alpha
                    applied_count += 1

        logger.info(f"Applied {applied_count} LoRA weight adjustments")
        if applied_count == 0:
            logger.info(
                "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.<layer_name>.lora_A.weight' and 'diffusion_model.<layer_name>.lora_B.weight'. Please verify the LoRA weight file."
            )

    @torch.no_grad()
    def remove_lora(self):
        logger.info(f"Removing LoRA ...")

        restored_count = 0
        for k, v in self.override_dict.items():
            self.model.original_weight_dict[k] = v.to(self.model.device)
            restored_count += 1

        logger.info(f"LoRA removed, restored {restored_count} weights")

        self.model._apply_weights(self.model.original_weight_dict)

        torch.cuda.empty_cache()
        gc.collect()

        self.lora_metadata = {}
        self.override_dict = {}

    def list_loaded_loras(self):
        return list(self.lora_metadata.keys())
Comment on lines 52 to -131
Contributor comment (severity: high):

This refactoring removes the remove_lora and list_loaded_loras methods. The removal of remove_lora is a significant functionality regression, as it's no longer possible to unload a LoRA after it has been applied.

The underlying mechanism that supported this feature, which involved storing original weights in self.override_dict, has also been removed from the LoRA application logic. As a result, the self.override_dict attribute in WanLoraWrapper is now unused.

If the ability to remove LoRAs is a required feature, this functionality needs to be reinstated. If it's intentionally being removed, the now-dead code (self.override_dict) should also be removed from the __init__ method for clarity.
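
If the feature is to be kept, a rough sketch of reinstating remove_lora on top of the refactored wrapper (hypothetical, not part of this PR; it assumes the affected weights are snapshotted into self.override_dict before lora_loader.apply_lora mutates them, as the deleted code did):

    @torch.no_grad()
    def remove_lora(self):
        # Restore the pre-LoRA weights captured in self.override_dict.
        restored_count = 0
        for name, original in self.override_dict.items():
            self.model.original_weight_dict[name] = original.to(self.model.device)
            restored_count += 1
        self.model._apply_weights(self.model.original_weight_dict)
        logger.info(f"LoRA removed, restored {restored_count} weights")
        torch.cuda.empty_cache()
        self.lora_metadata = {}
        self.override_dict = {}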