|
1 | 1 | import os |
2 | 2 | import re |
3 | 3 | import torch |
4 | | -from typing import Callable, Dict, List, Tuple, Optional |
5 | | -from safetensors.torch import load_file |
| 4 | +from typing import Callable, Dict, Optional |
6 | 5 | from tqdm import tqdm |
7 | 6 | from PIL import Image |
8 | 7 | from dataclasses import dataclass |
9 | | -from diffsynth_engine.models.base import LoRAStateDictConverter, split_suffix |
10 | | -from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d |
| 8 | + |
| 9 | +from diffsynth_engine.models.base import split_suffix |
| 10 | +from diffsynth_engine.models.basic.lora import LoRAContext |
11 | 11 | from diffsynth_engine.models.basic.timestep import TemporalTimesteps |
12 | 12 | from diffsynth_engine.models.sdxl import ( |
13 | 13 | SDXLTextEncoder, |
|
17 | 17 | SDXLUNet, |
18 | 18 | sdxl_unet_config, |
19 | 19 | ) |
20 | | -from diffsynth_engine.pipelines import BasePipeline |
| 20 | +from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter |
21 | 21 | from diffsynth_engine.tokenizers import CLIPTokenizer |
22 | 22 | from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler |
23 | 23 | from diffsynth_engine.algorithm.sampler import EulerSampler |
@@ -305,45 +305,10 @@ def predict_noise(self, latents, timestep, prompt_emb, add_text_embeds, add_time |
305 | 305 | ) |
306 | 306 | return noise_pred |
307 | 307 |
|
308 | | - def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True): |
309 | | - self.load_loras([(path, scale)], fused, save_original_weight) |
310 | | - |
311 | | - def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True): |
312 | | - for lora_path, lora_scale in lora_list: |
313 | | - state_dict = load_file(lora_path, device="cpu") |
314 | | - lora_state_dict = self.lora_converter.convert(state_dict) |
315 | | - for model_name, state_dict in lora_state_dict.items(): |
316 | | - model = getattr(self, model_name) |
317 | | - for key, param in state_dict.items(): |
318 | | - module = model.get_submodule(key) |
319 | | - if not isinstance(module, (LoRALinear, LoRAConv2d)): |
320 | | - raise ValueError(f"Unsupported lora key: {key}") |
321 | | - lora_args = { |
322 | | - "name": key, |
323 | | - "scale": lora_scale, |
324 | | - "rank": param["rank"], |
325 | | - "alpha": param["alpha"], |
326 | | - "up": param["up"], |
327 | | - "down": param["down"], |
328 | | - "device": self.device, |
329 | | - "dtype": self.dtype, |
330 | | - "save_original_weight": save_original_weight, |
331 | | - } |
332 | | - if fused: |
333 | | - module.add_frozen_lora(**lora_args) |
334 | | - else: |
335 | | - module.add_lora(**lora_args) |
336 | | - |
337 | 308 | def unload_loras(self): |
338 | | - for key, module in self.unet.named_modules(): |
339 | | - if isinstance(module, (LoRALinear, LoRAConv2d)): |
340 | | - module.clear() |
341 | | - for key, module in self.text_encoder.named_modules(): |
342 | | - if isinstance(module, (LoRALinear, LoRAConv2d)): |
343 | | - module.clear() |
344 | | - for key, module in self.text_encoder_2.named_modules(): |
345 | | - if isinstance(module, (LoRALinear, LoRAConv2d)): |
346 | | - module.clear() |
| 309 | + self.unet.unload_loras() |
| 310 | + self.text_encoder.unload_loras() |
| 311 | + self.text_encoder_2.unload_loras() |
347 | 312 |
|
348 | 313 | @torch.no_grad() |
349 | 314 | def __call__( |
|
0 commit comments