diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..d0b98d91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +diffsynth.egg-info/ +diffsynth/__pycache__/ +diffsynth/configs/__pycache__/ +diffsynth/controlnets/__pycache__/ +diffsynth/data/__pycache__/ +diffsynth/extensions/ESRGAN/__pycache__/ +diffsynth/extensions/RIFE/__pycache__/ +diffsynth/extensions/__pycache__/ +diffsynth/models/__pycache__/ +diffsynth/pipelines/__pycache__/ +diffsynth/processors/__pycache__/ +diffsynth/prompters/__pycache__/ +diffsynth/schedulers/__pycache__/ +diffsynth/vram_management/__pycache__/ +.vs \ No newline at end of file diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 7d4f52db..3db3cbba 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -1,4 +1,11 @@ import torch +import time +from tqdm import tqdm +import psutil +import gc +import os +import platform +import multiprocessing from .sd_unet import SDUNet from .sdxl_unet import SDXLUNet from .sd_text_encoder import SDTextEncoder @@ -10,7 +17,67 @@ from .hunyuan_video_dit import HunyuanVideoDiT from .wan_video_dit import WanModel +# Global debug variable: when set to False, only minimal info is printed. +DEBUG = False +def debug_print(*args, **kwargs): + """Print debug messages only if DEBUG is True.""" + if DEBUG: + print(*args, **kwargs) + +def timing_decorator(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + elapsed_time = end_time - start_time + if DEBUG: + print(f"⏱️ {func.__name__} took {elapsed_time:.4f} seconds") + return result + return wrapper + +def memory_usage(): + """Get current memory usage of the process""" + process = psutil.Process(os.getpid()) + memory_info = process.memory_info() + return f"{memory_info.rss / (1024 * 1024):.1f} MB" + +def optimize_cpu_threading(): + """Set optimal thread configuration for the current CPU""" + cpu_count = multiprocessing.cpu_count() + + # Get processor information + processor = platform.processor().lower() + + if "amd" in processor: + optimal_threads = max(1, cpu_count) + else: # Intel or other + optimal_threads = max(1, cpu_count // 2) + + os.environ["OMP_NUM_THREADS"] = str(optimal_threads) + os.environ["MKL_NUM_THREADS"] = str(optimal_threads) + + blas_info = "unknown" + try: + import torch.__config__ + config_info = torch.__config__.show() + if "mkl" in config_info.lower(): + blas_info = "MKL" + elif "openblas" in config_info.lower(): + blas_info = "OpenBLAS" + except: + pass + + if DEBUG: + print(f"CPU Optimization: {platform.processor()}") + print(f"Physical cores: {cpu_count // 2}, Total threads: {cpu_count}") + print(f"Using {optimal_threads} threads for computation") + print(f"BLAS backend: {blas_info}") + else: + print(f"CPU threading optimized: using {optimal_threads} threads") + + torch.set_num_threads(optimal_threads) + return optimal_threads class LoRAFromCivitai: def __init__(self): @@ -18,110 +85,327 @@ def __init__(self): self.lora_prefix = [] self.renamed_lora_prefix = {} self.special_keys = {} + self.stats = { + "tensor_movements_to_gpu": 0, + "tensor_movements_to_cpu": 0, + "lora_weights_processed": 0, + "format_conversions": 0, + } + # Set optimal thread count for CPU operations + self.optimal_threads = optimize_cpu_threading() + self.use_gpu = torch.cuda.is_available() + + # Enable tensor cores for matrix operations if available + if self.use_gpu and hasattr(torch.backends, 'cudnn'): + torch.backends.cudnn.benchmark = True - + @timing_decorator def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): + if DEBUG: + print(f"Converting state dict with prefix {lora_prefix}, memory usage: {memory_usage()}") + # Detect format for key in state_dict: if ".lora_up" in key: + if DEBUG: + print(f"Detected up/down format, keys: {len(state_dict)}") return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha) + if DEBUG: + print(f"Detected A/B format, keys: {len(state_dict)}") return self.convert_state_dict_AB(state_dict, lora_prefix, alpha) - + @timing_decorator def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "") state_dict_ = {} + if DEBUG: + print(f"Processing up/down conversion for {len(state_dict)} tensors...") + + # Determine optimal processing device + device = "cuda" if self.use_gpu else "cpu" + torch_dtype = torch.float16 if self.use_gpu else torch.float32 + + # Count applicable keys first + applicable_keys = [] for key in state_dict: - if ".lora_up" not in key: - continue - if not key.startswith(lora_prefix): - continue - weight_up = state_dict[key].to(device="cuda", dtype=torch.float16) - weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16) - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) - weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) - lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - lora_weight = alpha * torch.mm(weight_up, weight_down) - target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight" - for special_key in self.special_keys: - target_name = target_name.replace(special_key, self.special_keys[special_key]) - state_dict_[target_name] = lora_weight.cpu() + if ".lora_up" in key and key.startswith(lora_prefix): + applicable_keys.append(key) + + # Prepare batches for processing + BATCH_SIZE = 16 # Adjust based on memory constraints + if DEBUG: + print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...") + + with tqdm(total=len(applicable_keys), desc="Converting up/down weights") as pbar: + for i in range(0, len(applicable_keys), BATCH_SIZE): + batch_keys = applicable_keys[i:i+BATCH_SIZE] + for key in batch_keys: + # Track GPU tensor movements + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch_dtype) + self.stats["tensor_movements_to_gpu"] += 2 + + # Matrix multiplication - faster on GPU, or optimized CPU + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + + target_key = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight" + state_dict_[target_key] = lora_weight.cpu() + self.stats["tensor_movements_to_cpu"] += 1 + self.stats["lora_weights_processed"] += 1 + + # Apply special key replacements + for special_key in self.special_keys: + if special_key in target_key: + state_dict_[target_key] = state_dict_[target_key].replace(special_key, self.special_keys[special_key]) + + pbar.update(1) + + # Clear memory after each batch + del weight_up, weight_down, lora_weight + if self.use_gpu: + torch.cuda.empty_cache() + + if DEBUG: + print(f"Up/down conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}") + else: + print(f"LoRA conversion complete: {len(state_dict_)} tensors processed") return state_dict_ - - def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16): + @timing_decorator + def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0): state_dict_ = {} + # Determine optimal processing device + device = "cuda" if self.use_gpu else "cpu" + torch_dtype = torch.float16 if self.use_gpu else torch.float32 + + # Collect applicable keys first + applicable_keys = [] for key in state_dict: - if ".lora_B." not in key: - continue - if not key.startswith(lora_prefix): - continue - weight_up = state_dict[key].to(device=device, dtype=torch_dtype) - weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2) - weight_down = weight_down.squeeze(3).squeeze(2) - lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - lora_weight = alpha * torch.mm(weight_up, weight_down) - keys = key.split(".") - keys.pop(keys.index("lora_B")) - target_name = ".".join(keys) - target_name = target_name[len(lora_prefix):] - state_dict_[target_name] = lora_weight.cpu() + if ".lora_B." in key and key.startswith(lora_prefix): + applicable_keys.append(key) + + # Prepare batches for processing + BATCH_SIZE = 16 # Adjust based on memory constraints + if DEBUG: + print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...") + + with tqdm(total=len(applicable_keys), desc="Converting A/B weights") as pbar: + for i in range(0, len(applicable_keys), BATCH_SIZE): + batch_keys = applicable_keys[i:i+BATCH_SIZE] + for key in batch_keys: + # Load and process tensors + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + self.stats["tensor_movements_to_gpu"] += 2 + + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + + # Extract target name + keys = key.split(".") + keys.pop(keys.index("lora_B")) + target_name = ".".join(keys) + target_name = target_name[len(lora_prefix):] + + # Store result + state_dict_[target_name] = lora_weight.cpu() + self.stats["tensor_movements_to_cpu"] += 1 + self.stats["lora_weights_processed"] += 1 + pbar.update(1) + + # Clear memory after each batch + del weight_up, weight_down, lora_weight + if self.use_gpu: + torch.cuda.empty_cache() + + if DEBUG: + print(f"A/B conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}") + else: + print(f"LoRA conversion complete: {len(state_dict_)} tensors processed") return state_dict_ - + @timing_decorator def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None): - state_dict_model = model.state_dict() + print(f"Starting LoRA loading process for {model.__class__.__name__}...") + + # Measure state dict loading time - use direct parameter access + start_state_dict = time.time() + state_dict_model = {} + for name, param in model.named_parameters(): + state_dict_model[name] = param + end_state_dict = time.time() + if DEBUG: + print(f"⏱️ Loading model parameters took {end_state_dict - start_state_dict:.4f} seconds, size: {len(state_dict_model)} tensors") + else: + print(f"Model parameters mapped: {len(state_dict_model)} parameters") + + # Measure LoRA conversion time + start_convert = time.time() state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha) - if model_resource == "diffusers": - state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora) - elif model_resource == "civitai": - state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora) + self.stats["format_conversions"] += 1 + end_convert = time.time() + if DEBUG: + print(f"⏱️ LoRA conversion took {end_convert - start_convert:.4f} seconds") + + # Measure format conversion time if applicable + if model_resource: + if DEBUG: + print(f"Converting format from {model_resource}...") + start_format = time.time() + if model_resource == "diffusers": + state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora) + elif model_resource == "civitai": + state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora) + self.stats["format_conversions"] += 1 + end_format = time.time() + if DEBUG: + print(f"⏱️ Format conversion took {end_format - start_format:.4f} seconds") + if isinstance(state_dict_lora, tuple): state_dict_lora = state_dict_lora[0] + if len(state_dict_lora) > 0: - print(f" {len(state_dict_lora)} tensors are updated.") - for name in state_dict_lora: - fp8=False - if state_dict_model[name].dtype == torch.float8_e4m3fn: - state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype) - fp8=True - state_dict_model[name] += state_dict_lora[name].to( - dtype=state_dict_model[name].dtype, device=state_dict_model[name].device) - if fp8: - state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn) - model.load_state_dict(state_dict_model) + if DEBUG: + print(f"Applying {len(state_dict_lora)} LoRA tensors to model weights...") + else: + print("Applying LoRA weights...") + + # Process in batches + BATCH_SIZE = 32 + lora_keys = list(state_dict_lora.keys()) + + start_update = time.time() + with tqdm(total=len(lora_keys), desc="Applying LoRA weights") as pbar: + for i in range(0, len(lora_keys), BATCH_SIZE): + batch_keys = lora_keys[i:i+BATCH_SIZE] + for name in batch_keys: + if name not in state_dict_model: + pbar.update(1) + continue + + param = state_dict_model[name] + + # Handle FP8 tensors + fp8 = False + if param.dtype == torch.float8_e4m3fn: + param_data = param.to(state_dict_lora[name].dtype) + fp8 = True + else: + param_data = param.data + + # Apply direct update (avoids load_state_dict overhead) + param.data = param_data + state_dict_lora[name].to( + dtype=param_data.dtype, device=param_data.device) + + if fp8: + param.data = param.data.to(torch.float8_e4m3fn) + + pbar.update(1) + + # Clear memory after each batch + if self.use_gpu: + torch.cuda.empty_cache() + + end_update = time.time() + if DEBUG: + print(f"⏱️ Weight update took {end_update - start_update:.4f} seconds") + else: + print("Weight update complete.") + else: + print("No LoRA tensors to apply!") + + if DEBUG: + print("\n==== LoRA LOADING STATISTICS ====") + print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}") + print(f"Total tensor movements to CPU: {self.stats['tensor_movements_to_cpu']}") + print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}") + print(f"Total format conversions: {self.stats['format_conversions']}") + print(f"Final memory usage: {memory_usage()}") + print("================================") + else: + print(f"LoRA load complete: {self.stats['lora_weights_processed']} weights processed, GPU moves: {self.stats['tensor_movements_to_gpu']}, CPU moves: {self.stats['tensor_movements_to_cpu']}.") + + # Clear temporary data and run garbage collection + del state_dict_lora + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - + @timing_decorator def match(self, model, state_dict_lora): - for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes): + if DEBUG: + print(f"Trying to match LoRA format for {model.__class__.__name__}, memory usage: {memory_usage()}") + match_results = [] + + for i, (lora_prefix, model_class) in enumerate(zip(self.lora_prefix, self.supported_model_classes)): if not isinstance(model, model_class): continue - state_dict_model = model.state_dict() + + if DEBUG: + print(f"Checking prefix '{lora_prefix}' for model class {model_class.__name__}") + + # Get parameter names + param_names = set(name for name, _ in model.named_parameters()) + for model_resource in ["diffusers", "civitai"]: try: + if DEBUG: + print(f" Attempting {model_resource} format...") + start_time = time.time() + + # Try conversion state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0) converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \ else model.__class__.state_dict_converter().from_civitai state_dict_lora_ = converter_fn(state_dict_lora_) + if isinstance(state_dict_lora_, tuple): state_dict_lora_ = state_dict_lora_[0] + if len(state_dict_lora_) == 0: + if DEBUG: + print(f" ❌ No matching tensors found for {model_resource} format") continue - for name in state_dict_lora_: - if name not in state_dict_model: + + # Verify the keys actually match the model (sample check) + valid_keys = 0 + for name in list(state_dict_lora_.keys())[:10]: + if name in param_names: + valid_keys += 1 + else: + if DEBUG: + print(f" ⚠️ Key not found in model: {name}") break - else: + + end_time = time.time() + + if valid_keys > 0: + if DEBUG: + print(f" ✅ Match found! Prefix: {lora_prefix}, Format: {model_resource}, Valid keys: {valid_keys}") + print(f" ⏱️ Match verification took {end_time - start_time:.4f} seconds") + else: + print("Matching format found.") return lora_prefix, model_resource - except: - pass + else: + if DEBUG: + print(f" ❌ No valid keys found for this format") + + except Exception as e: + if DEBUG: + print(f" ❌ Error during matching: {str(e)}") + if DEBUG: + print("❌ No match found for any format or prefix") return None - - +# Specialized classes derived from LoRAFromCivitai class SDLoRAFromCivitai(LoRAFromCivitai): def __init__(self): super().__init__() @@ -148,7 +432,6 @@ def __init__(self): "output.blocks": "model.diffusion_model.output_blocks", } - class SDXLLoRAFromCivitai(LoRAFromCivitai): def __init__(self): super().__init__() @@ -176,7 +459,6 @@ def __init__(self): "output.blocks": "model.diffusion_model.output_blocks", "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers" } - class FluxLoRAFromCivitai(LoRAFromCivitai): def __init__(self): @@ -195,89 +477,245 @@ def __init__(self): "txt.mod": "txt_mod", } - - +class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): + def __init__(self): + super().__init__() + self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT] + self.lora_prefix = ["diffusion_model.", "transformer."] + self.special_keys = {} + class GeneralLoRAFromPeft: def __init__(self): self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel] - - - def get_name_dict(self, lora_state_dict): - lora_name_dict = {} - for key in lora_state_dict: - if ".lora_B." not in key: - continue - keys = key.split(".") - if len(keys) > keys.index("lora_B") + 2: - keys.pop(keys.index("lora_B") + 1) - keys.pop(keys.index("lora_B")) - if keys[0] == "diffusion_model": - keys.pop(0) - target_name = ".".join(keys) - lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) - return lora_name_dict - - - def match(self, model: torch.nn.Module, state_dict_lora): - lora_name_dict = self.get_name_dict(state_dict_lora) - model_name_dict = {name: None for name, _ in model.named_parameters()} - matched_num = sum([i in model_name_dict for i in lora_name_dict]) - if matched_num == len(lora_name_dict): - return "", "" - else: - return None - - - def fetch_device_and_dtype(self, state_dict): - device, dtype = None, None - for name, param in state_dict.items(): - device, dtype = param.device, param.dtype - break - computation_device = device - computation_dtype = dtype - if computation_device == torch.device("cpu"): - if torch.cuda.is_available(): - computation_device = torch.device("cuda") - if computation_dtype == torch.float8_e4m3fn: - computation_dtype = torch.float32 - return device, dtype, computation_device, computation_dtype + self.stats = { + "tensor_movements_to_gpu": 0, + "tensor_movements_to_cpu": 0, + "lora_weights_processed": 0, + "parameter_updates": 0, + } + # Set optimal thread count for CPU operations + self.optimal_threads = optimize_cpu_threading() + self.use_gpu = torch.cuda.is_available() + + # Enable tensor cores for matrix operations if available + if self.use_gpu and hasattr(torch.backends, 'cudnn'): + torch.backends.cudnn.benchmark = True + def _get_target_name(self, key): + """Extract target parameter name from LoRA key""" + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + target_name = ".".join(keys) + if target_name.startswith("diffusion_model."): + target_name = target_name[len("diffusion_model."):] + return target_name + + @timing_decorator + def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}): + if DEBUG: + print(f"Converting state dict with GeneralLoRAFromPeft, memory: {memory_usage()}") + device = "cuda" if self.use_gpu else "cpu" + torch_dtype = torch.float16 if self.use_gpu else torch.float32 + + state_dict_ = {} + + # Count applicable keys + applicable_keys = [key for key in state_dict if ".lora_B." in key] + + # Process in batches + BATCH_SIZE = 16 + if DEBUG: + print(f"Processing {len(applicable_keys)} tensors in batches of {BATCH_SIZE}...") + + with tqdm(total=len(applicable_keys), desc="Converting LoRA weights") as pbar: + for i in range(0, len(applicable_keys), BATCH_SIZE): + batch_keys = applicable_keys[i:i+BATCH_SIZE] + for key in batch_keys: + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + self.stats["tensor_movements_to_gpu"] += 2 + + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + + target_name = self._get_target_name(key) + + if target_state_dict and target_name not in target_state_dict: + pbar.update(1) + continue + + state_dict_[target_name] = lora_weight.cpu() + self.stats["tensor_movements_to_cpu"] += 1 + self.stats["lora_weights_processed"] += 1 + pbar.update(1) + + # Clear memory after each batch + del weight_up, weight_down, lora_weight + if self.use_gpu: + torch.cuda.empty_cache() + + if DEBUG: + print(f"Conversion complete, resulting in {len(state_dict_)} tensors, memory: {memory_usage()}") + else: + print(f"General LoRA conversion complete: {len(state_dict_)} weights processed") + return state_dict_ + @timing_decorator def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): - state_dict_model = model.state_dict() - device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model) - lora_name_dict = self.get_name_dict(state_dict_lora) - for name in lora_name_dict: - weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype) - weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype) - if len(weight_up.shape) == 4: - weight_up = weight_up.squeeze(3).squeeze(2) - weight_down = weight_down.squeeze(3).squeeze(2) - weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - weight_lora = alpha * torch.mm(weight_up, weight_down) - weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype) - weight_patched = weight_model + weight_lora - state_dict_model[name] = weight_patched.to(device=device, dtype=dtype) - print(f" {len(lora_name_dict)} tensors are updated.") - model.load_state_dict(state_dict_model) - - + """Apply LoRA weights directly to model parameters with batched processing""" + print(f"Starting optimized LoRA loading for {model.__class__.__name__}...") + + # Create parameter lookup dict + start_map = time.time() + param_dict = {name: param for name, param in model.named_parameters()} + end_map = time.time() + if DEBUG: + print(f"⏱️ Parameter mapping took {end_map - start_map:.4f} seconds, found {len(param_dict)} parameters") + else: + print(f"Mapped {len(param_dict)} model parameters") + + # Count applicable LoRA parameters + lora_b_keys = [key for key in state_dict_lora if ".lora_B." in key] + print(f"Found {len(lora_b_keys)} LoRA parameter pairs to process") + + # Group parameters by shape for better memory access patterns + shape_groups = {} + for key in lora_b_keys: + target_name = self._get_target_name(key) + if target_name not in param_dict: + continue + shape = state_dict_lora[key].shape + if shape not in shape_groups: + shape_groups[shape] = [] + shape_groups[shape].append((key, target_name)) + + if DEBUG: + print(f"Organized into {len(shape_groups)} shape groups for efficient processing") + + # Process each shape group in batches + BATCH_SIZE = 32 + modified_count = 0 + + for shape, key_pairs in shape_groups.items(): + if DEBUG: + print(f"Processing {len(key_pairs)} parameters with shape {shape}") + for i in range(0, len(key_pairs), BATCH_SIZE): + batch = key_pairs[i:i+BATCH_SIZE] + for lora_key, target_name in batch: + param = param_dict[target_name] + dtype_for_calc = torch.float32 if param.dtype == torch.float8_e4m3fn else param.dtype + + # Load weights and compute LoRA update + weight_b = state_dict_lora[lora_key].to(device="cuda" if self.use_gpu else "cpu", dtype=dtype_for_calc) + weight_a = state_dict_lora[lora_key.replace(".lora_B.", ".lora_A.")].to(device="cuda" if self.use_gpu else "cpu", dtype=dtype_for_calc) + self.stats["tensor_movements_to_gpu"] += 2 + + if len(weight_b.shape) == 4: + weight_b = weight_b.squeeze(3).squeeze(2) + weight_a = weight_a.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_b, weight_a) + + # Apply update directly to parameter + if param.dtype == torch.float8_e4m3fn: + param_float = param.to(torch.float32) + # Ensure both tensors are on the same device before addition + lora_weight_device = lora_weight.to(device=param_float.device) + param.data = (param_float + lora_weight_device).to(param.dtype) + del param_float, lora_weight_device + else: + param.data += lora_weight.to(dtype=param.dtype, device=param.device) + + del weight_a, weight_b, lora_weight + self.stats["parameter_updates"] += 1 + modified_count += 1 + + if self.use_gpu: + torch.cuda.empty_cache() + + if DEBUG: + print("\n==== OPTIMIZED LORA LOADING STATISTICS ====") + print(f"Total tensor movements to GPU: {self.stats['tensor_movements_to_gpu']}") + print(f"Total LoRA weights processed: {self.stats['lora_weights_processed']}") + print(f"Total parameters updated: {self.stats['parameter_updates']}") + print(f"Final memory usage: {memory_usage()}") + print("==========================================") + else: + print(f"Optimized LoRA load complete: updated {modified_count} tensors") + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + print(f"⏱️ {modified_count} tensors were updated successfully") -class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): - def __init__(self): - super().__init__() - self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT] - self.lora_prefix = ["diffusion_model.", "transformer."] - self.special_keys = {} - + @timing_decorator + def match(self, model, state_dict_lora): + """Check if LoRA parameters match model parameters""" + if DEBUG: + print(f"Checking General LoRA compatibility for {model.__class__.__name__}...") + + for model_class in self.supported_model_classes: + if not isinstance(model, model_class): + continue + + # Create set of parameter names + start_param = time.time() + param_names = set(name for name, _ in model.named_parameters()) + end_param = time.time() + if DEBUG: + print(f"⏱️ Parameter name collection took {end_param - start_param:.4f} seconds, found {len(param_names)} names") + + # Check if a sample of LoRA keys map to model parameters + matched_count = 0 + checked_count = 0 + + start_check = time.time() + for key in state_dict_lora: + if ".lora_B." not in key: + continue + + target_name = self._get_target_name(key) + if target_name in param_names: + matched_count += 1 + + checked_count += 1 + if matched_count >= 5: # Found enough matches + break + if checked_count >= 50 and matched_count == 0: # Checked enough without matches + break + end_check = time.time() + + if DEBUG: + print(f"⏱️ Match check took {end_check - start_check:.4f} seconds") + print(f"Matched {matched_count}/{checked_count} checked parameters") + + if matched_count > 0: + if DEBUG: + print(f"✅ Compatible with GeneralLoRAFromPeft") + else: + print("LoRA compatibility check: PASS") + return "", "" + if DEBUG: + print("❌ Not compatible with GeneralLoRAFromPeft") + return None class FluxLoRAConverter: def __init__(self): pass @staticmethod + @timing_decorator def align_to_opensource_format(state_dict, alpha=1.0): + if DEBUG: + print(f"Converting Flux LoRA to opensource format, input keys: {len(state_dict)}") prefix_rename_dict = { "single_blocks": "lora_unet_single_blocks", "blocks": "lora_unet_double_blocks", @@ -286,7 +724,6 @@ def align_to_opensource_format(state_dict, alpha=1.0): "norm.linear": "modulation_lin", "to_qkv_mlp": "linear1", "proj_out": "linear2", - "norm1_a.linear": "img_mod_lin", "norm1_b.linear": "txt_mod_lin", "attn.a_to_qkv": "img_attn_qkv", @@ -303,7 +740,7 @@ def align_to_opensource_format(state_dict, alpha=1.0): "lora_A.weight": "lora_down.weight", } state_dict_ = {} - for name, param in state_dict.items(): + for name, param in tqdm(state_dict.items(), desc="Aligning to opensource format"): names = name.split(".") if names[-2] != "lora_A" and names[-2] != "lora_B": names.pop(-2) @@ -317,10 +754,17 @@ def align_to_opensource_format(state_dict, alpha=1.0): state_dict_[rename] = param if rename.endswith("lora_up.weight"): state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0] + if DEBUG: + print(f"Conversion complete, output keys: {len(state_dict_)}") + else: + print(f"Flux LoRA conversion complete: {len(state_dict_)} keys") return state_dict_ @staticmethod + @timing_decorator def align_to_diffsynth_format(state_dict): + if DEBUG: + print(f"Converting to diffsynth format, input keys: {len(state_dict)}") rename_dict = { "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", @@ -356,7 +800,7 @@ def guess_block_id(name): return i, name.replace(f"_{i}_", "_blockid_") return None, None state_dict_ = {} - for name, param in state_dict.items(): + for name, param in tqdm(state_dict.items(), desc="Aligning to diffsynth format"): block_id, source_name = guess_block_id(name) if source_name in rename_dict: target_name = rename_dict[source_name] @@ -364,8 +808,11 @@ def guess_block_id(name): state_dict_[target_name] = param else: state_dict_[name] = param + if DEBUG: + print(f"Conversion complete, output keys: {len(state_dict_)}") + else: + print(f"Diffsynth conversion complete: {len(state_dict_)} keys") return state_dict_ - def get_lora_loaders(): - return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] + return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] \ No newline at end of file diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7ae3c50a..5675b44c 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -1,4 +1,4 @@ -import os, torch, json, importlib +import os, torch, json, importlib, gc from typing import List from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website @@ -109,6 +109,8 @@ def load_single_patch_model_from_single_file(state_dict, model_name, model_class return model + + def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device): loaded_model_names, loaded_models = [], [] for model_name, model_class in zip(model_names, model_classes): @@ -337,6 +339,19 @@ def __init__( self.load_models(downloaded_files + file_path_list) + def clear_models(self): + """Explicitly release loaded models.""" + print("[ModelManager] Clearing loaded models.") + for model in self.model: + del model + self.model = [] + self.model_path = [] + self.model_name = [] + gc.collect() # Force garbage collection after deleting models + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print("[ModelManager] Models cleared.") + def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None): print(f"Loading models from file: {file_path}") if len(state_dict) == 0: diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 439d3119..d862fb72 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -148,7 +148,7 @@ def denoising_model(self): def encode_prompt(self, prompt, positive=True): - prompt_emb = self.prompter.encode_prompt(prompt, positive=positive) + prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device) return {"context": prompt_emb} @@ -214,6 +214,7 @@ def __call__( tea_cache_model_id="", progress_bar_cmd=tqdm, progress_bar_st=None, + cancel_fn=None # new cancel_fn parameter ): # Parameter check height, width = self.check_resize_height_width(height, width) @@ -262,6 +263,11 @@ def __call__( # Denoise self.load_models_to_device(["dit"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Check for cancellation on each timestep + if cancel_fn is not None and cancel_fn(): + print("[CMD] Video generation cancelled by user mid-run.") + self.load_models_to_device([]) + return [] timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference