diff --git a/examples/qwen_image/qwen_2511_fp8.py b/examples/qwen_image/qwen_2511_fp8.py
index b81618f9..907ddcd8 100755
--- a/examples/qwen_image/qwen_2511_fp8.py
+++ b/examples/qwen_image/qwen_2511_fp8.py
@@ -22,7 +22,7 @@
 # Suitable for RTX 30/40/50 consumer GPUs
 # pipe.enable_offload(
 #     cpu_offload=True,
-#     offload_granularity="block",
+#     offload_granularity="block",  #["block", "phase"]
 #     text_encoder_offload=True,
 #     vae_offload=False,
 # )
diff --git a/examples/qwen_image/qwen_2511_with_distill_lora.py b/examples/qwen_image/qwen_2511_with_distill_lora.py
index 0e8736df..b1955f27 100755
--- a/examples/qwen_image/qwen_2511_with_distill_lora.py
+++ b/examples/qwen_image/qwen_2511_with_distill_lora.py
@@ -22,7 +22,7 @@
 # Suitable for RTX 30/40/50 consumer GPUs
 # pipe.enable_offload(
 #     cpu_offload=True,
-#     offload_granularity="block",
+#     offload_granularity="block",  #["block", "phase"]
 #     text_encoder_offload=True,
 #     vae_offload=False,
 # )
diff --git a/lightx2v/models/networks/wan/lora_adapter.py b/lightx2v/models/networks/wan/lora_adapter.py
index 93a46923..0fd31887 100755
--- a/lightx2v/models/networks/wan/lora_adapter.py
+++ b/lightx2v/models/networks/wan/lora_adapter.py
@@ -1,17 +1,17 @@
-import gc
 import os
 
-import torch
 from loguru import logger
 from safetensors import safe_open
 
 from lightx2v.utils.envs import *
+from lightx2v.utils.lora_loader import LoRALoader
 
 
 class WanLoraWrapper:
     def __init__(self, wan_model):
         self.model = wan_model
         self.lora_metadata = {}
+        self.lora_loader = LoRALoader()
         self.override_dict = {}  # On CPU
 
     def load_lora(self, lora_path, lora_name=None):
@@ -41,91 +41,16 @@ def apply_lora(self, lora_name, alpha=1.0):
             return False
 
         lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
+
         weight_dict = self.model.original_weight_dict
-        self._apply_lora_weights(weight_dict, lora_weights, alpha)
+        self.lora_loader.apply_lora(
+            weight_dict=weight_dict,
+            lora_weights=lora_weights,
+            alpha=alpha,
+            strength=alpha,
+        )
         self.model._apply_weights(weight_dict)
 
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         del lora_weights
         return True
-
-    @torch.no_grad()
-    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
-        lora_pairs = {}
-        lora_diffs = {}
-
-        def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
-            if key.endswith(suffix_a):
-                base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
-                pair_key = key.replace(suffix_a, suffix_b)
-                if pair_key in lora_weights:
-                    lora_pairs[base_name] = (key, pair_key)
-
-        def try_lora_diff(key, prefix, suffix, target_suffix):
-            if key.endswith(suffix):
-                base_name = key[len(prefix) :].replace(suffix, target_suffix)
-                lora_diffs[base_name] = key
-
-        prefixs = [
-            "",  # empty prefix
-            "diffusion_model.",
-        ]
-        for prefix in prefixs:
-            for key in lora_weights.keys():
-                if not key.startswith(prefix):
-                    continue
-
-                try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
-                try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
-                try_lora_diff(key, prefix, "diff", "weight")
-                try_lora_diff(key, prefix, "diff_b", "bias")
-                try_lora_diff(key, prefix, "diff_m", "modulation")
-
-        applied_count = 0
-        for name, param in weight_dict.items():
-            if name in lora_pairs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-                name_lora_A, name_lora_B = lora_pairs[name]
-                lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
-                lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-                if param.shape == (lora_B.shape[0], lora_A.shape[1]):
-                    param += torch.matmul(lora_B, lora_A) * alpha
-                    applied_count += 1
-            elif name in lora_diffs:
-                if name not in self.override_dict:
-                    self.override_dict[name] = param.clone().cpu()
-
-                name_diff = lora_diffs[name]
-                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-                if param.shape == lora_diff.shape:
-                    param += lora_diff * alpha
-                    applied_count += 1
-
-        logger.info(f"Applied {applied_count} LoRA weight adjustments")
-        if applied_count == 0:
-            logger.info(
-                "Warning: No LoRA weights were applied. Expected naming conventions: 'diffusion_model.*.lora_A.weight' and 'diffusion_model.*.lora_B.weight'. Please verify the LoRA weight file."
-            )
-
-    @torch.no_grad()
-    def remove_lora(self):
-        logger.info(f"Removing LoRA ...")
-
-        restored_count = 0
-        for k, v in self.override_dict.items():
-            self.model.original_weight_dict[k] = v.to(self.model.device)
-            restored_count += 1
-
-        logger.info(f"LoRA removed, restored {restored_count} weights")
-
-        self.model._apply_weights(self.model.original_weight_dict)
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        self.lora_metadata = {}
-        self.override_dict = {}
-
-    def list_loaded_loras(self):
-        return list(self.lora_metadata.keys())
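
A minimal usage sketch of the refactored wrapper, for orientation only and not part of the patch: wan_model and the LoRA file path are placeholders, and it assumes load_lora() registers the file under the given name before apply_lora() hands the weight merge off to the shared LoRALoader, as in the hunk above.

    # Hypothetical driver code; wan_model and the LoRA path are placeholders, not repo code.
    from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper

    wrapper = WanLoraWrapper(wan_model)  # wan_model: an already-constructed Wan model
    wrapper.load_lora("/path/to/distill_lora.safetensors", lora_name="distill")  # register metadata
    wrapper.apply_lora("distill", alpha=1.0)  # merging now delegated to LoRALoader.apply_lora()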