
Commit 4b12864

supports loading LoRA with parallel Wan model (#30)
* fix cfg parallel
* ParallelModel supports load/unload LoRA
* fix param name
* fix split tensor
* update assert message
* add example
1 parent ea6fde6 commit 4b12864
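
A minimal usage sketch of the pipeline-level LoRA API this commit introduces. The checkpoint path, LoRA path, and the from_pretrained arguments below are placeholders; the Wan-specific example added by this commit is not shown on this page.

import torch
from diffsynth_engine.pipelines import WanVideoPipeline

# Placeholder path and constructor arguments; the real Wan example lives in the
# repository's examples directory, which is not part of this diff view.
pipe = WanVideoPipeline.from_pretrained("models/wan_video.safetensors", device="cuda:0", dtype=torch.bfloat16)

# load_lora() converts the LoRA state dict with the pipeline's lora_converter,
# then hands per-module (rank, alpha, up, down) args to each target model.
# fused=True folds the update into the frozen weights via add_frozen_lora.
pipe.load_lora("loras/my_style.safetensors", scale=1.0, fused=True)

# ... run inference ...

# Restores the original weights on every LoRA-wrapped layer
# (assuming WanVideoPipeline overrides unload_loras like the other pipelines).
pipe.unload_loras()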

File tree: 10 files changed (+188, -188 lines)

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from .base import PreTrainedModel, StateDictConverter
+
+
+__all__ = [
+    "PreTrainedModel",
+    "StateDictConverter",
+]

diffsynth_engine/models/base.py

Lines changed: 19 additions & 10 deletions

@@ -1,22 +1,15 @@
 import os
 import torch
 import torch.nn as nn
-from typing import Dict, Union
+from typing import Dict, List, Union
 from safetensors.torch import load_file
 
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
 from diffsynth_engine.models.utils import no_init_weights
 
 
-class LoRAStateDictConverter:
-    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-        return {"lora": lora_state_dict}
-
-
-StateDictType = Dict[str, torch.Tensor]
-
-
 class StateDictConverter:
-    def convert(self, state_dict: StateDictType) -> StateDictType:
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         return state_dict
 
 
@@ -40,6 +33,22 @@ def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
+    def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
+        for args in lora_args:
+            key = args["name"]
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d)):
+                raise ValueError(f"Unsupported lora key: {key}")
+            if fused:
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+    def unload_loras(self):
+        for module in self.modules():
+            if isinstance(module, (LoRALinear, LoRAConv2d)):
+                module.clear()
+
 
 def split_suffix(name: str):
     suffix_list = [
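
For context, a self-contained sketch of the LoRA arithmetic these new helpers drive. This is standard LoRA math, not the actual LoRALinear implementation in diffsynth_engine/models/basic/lora.py, and the shapes are toy values.

import torch

# Toy shapes; real values come from the converted LoRA state dict that
# BasePipeline.load_loras passes down as {"rank", "alpha", "up", "down", ...}.
out_features, in_features, rank, alpha, scale = 64, 64, 4, 4.0, 1.0
weight = torch.randn(out_features, in_features)   # frozen base weight
up = torch.randn(out_features, rank)              # "up" factor
down = torch.randn(rank, in_features)             # "down" factor

# fused path (add_frozen_lora-style): merge the low-rank delta into the weight once
fused_weight = weight + scale * (alpha / rank) * (up @ down)

# unfused path (add_lora-style): keep the delta as an extra branch at forward time
x = torch.randn(2, in_features)
y = x @ weight.T + scale * (alpha / rank) * (x @ down.T @ up.T)

# unload_loras() above simply walks self.modules() and calls clear() on every
# LoRALinear / LoRAConv2d, dropping these deltas again.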

diffsynth_engine/pipelines/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,11 +1,12 @@
-from .base import BasePipeline
+from .base import BasePipeline, LoRAStateDictConverter
 from .flux_image import FluxImagePipeline, FluxModelConfig
 from .sdxl_image import SDXLImagePipeline, SDXLModelConfig
 from .sd_image import SDImagePipeline, SDModelConfig
 from .wan_video import WanVideoPipeline, WanModelConfig
 
 __all__ = [
     "BasePipeline",
+    "LoRAStateDictConverter",
     "FluxImagePipeline",
     "FluxModelConfig",
     "SDXLImagePipeline",

diffsynth_engine/pipelines/base.py

Lines changed: 38 additions & 1 deletion

@@ -1,7 +1,7 @@
 import os
 import torch
 import numpy as np
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from PIL import Image, ImageOps
 from einops import repeat
 from dataclasses import dataclass
@@ -19,7 +19,14 @@ class ModelConfig:
     pass
 
 
+class LoRAStateDictConverter:
+    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        return {"lora": lora_state_dict}
+
+
 class BasePipeline:
+    lora_converter = LoRAStateDictConverter()
+
     def __init__(self, device="cuda:0", dtype=torch.float16):
         super().__init__()
         self.device = device
@@ -43,6 +50,36 @@ def from_state_dict(
     ) -> "BasePipeline":
         raise NotImplementedError()
 
+    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+        for lora_path, lora_scale in lora_list:
+            logger.info(f"loading lora from {lora_path} with scale {lora_scale}")
+            state_dict = load_file(lora_path, device="cpu")
+            lora_state_dict = self.lora_converter.convert(state_dict)
+            for model_name, state_dict in lora_state_dict.items():
+                model = getattr(self, model_name)
+                lora_args = []
+                for key, param in state_dict.items():
+                    lora_args.append(
+                        {
+                            "name": key,
+                            "scale": lora_scale,
+                            "rank": param["rank"],
+                            "alpha": param["alpha"],
+                            "up": param["up"],
+                            "down": param["down"],
+                            "device": self.device,
+                            "dtype": self.dtype,
+                            "save_original_weight": save_original_weight,
+                        }
+                    )
+                model.load_loras(lora_args, fused=fused)
+
+    def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
+        self.load_loras([(path, scale)], fused, save_original_weight)
+
+    def unload_loras(self):
+        raise NotImplementedError()
+
     @staticmethod
     def load_model_checkpoint(
         checkpoint_path: str, device: str = "cpu", dtype: torch.dtype = torch.float16
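
The refactor moves LoRAStateDictConverter next to BasePipeline so each pipeline can override lora_converter and route checkpoint keys to its own submodels. Below is a toy converter sketch, assuming a "dit."-prefixed key layout and ".lora_up"/".lora_down" tensor naming; the real Wan/Flux converters are not shown in this diff.

from typing import Dict

import torch

from diffsynth_engine.pipelines import LoRAStateDictConverter


class ToyLoRAConverter(LoRAStateDictConverter):
    # Maps raw checkpoint tensors to {model_name: {module_name: lora_params}},
    # the shape BasePipeline.load_loras expects before calling model.load_loras().
    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, dict]]:
        dit_params = {}
        for key, up in lora_state_dict.items():
            if not (key.startswith("dit.") and key.endswith(".lora_up.weight")):
                continue
            down = lora_state_dict[key.replace(".lora_up.", ".lora_down.")]
            name = key[len("dit."):-len(".lora_up.weight")]
            dit_params[name] = {
                "rank": down.shape[0],
                "alpha": float(down.shape[0]),  # assume alpha == rank when not stored
                "up": up,
                "down": down,
            }
        return {"dit": dit_params}


# A pipeline subclass would then set:  lora_converter = ToyLoRAConverter()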

diffsynth_engine/pipelines/flux_image.py

Lines changed: 6 additions & 40 deletions

@@ -2,8 +2,7 @@
 import os
 import torch
 import math
-from typing import Callable, Dict, List, Tuple, Optional
-from safetensors.torch import load_file
+from typing import Callable, Dict, Optional
 from tqdm import tqdm
 from PIL import Image
 from dataclasses import dataclass
@@ -16,9 +15,8 @@
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d
-from diffsynth_engine.models.base import LoRAStateDictConverter
-from diffsynth_engine.pipelines import BasePipeline
+from diffsynth_engine.models.basic.lora import LoRAContext
+from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
@@ -298,42 +296,10 @@ def from_pretrained(
         pipe.enable_sequential_cpu_offload()
         return pipe
 
-    def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True):
-        self.load_loras([(path, scale)], fused, save_original_weight)
-
-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True):
-        for lora_path, lora_scale in lora_list:
-            state_dict = load_file(lora_path, device="cpu")
-            lora_state_dict = self.lora_converter.convert(state_dict)
-            for model_name, state_dict in lora_state_dict.items():
-                model = getattr(self, model_name)
-                for key, param in state_dict.items():
-                    module = model.get_submodule(key)
-                    if not isinstance(module, (LoRALinear, LoRAConv2d)):
-                        raise ValueError(f"Unsupported lora key: {key}")
-                    lora_args = {
-                        "name": key,
-                        "scale": lora_scale,
-                        "rank": param["rank"],
-                        "alpha": param["alpha"],
-                        "up": param["up"],
-                        "down": param["down"],
-                        "device": self.device,
-                        "dtype": self.dtype,
-                        "save_original_weight": save_original_weight,
-                    }
-                    if fused:
-                        module.add_frozen_lora(**lora_args)
-                    else:
-                        module.add_lora(**lora_args)
-
     def unload_loras(self):
-        for key, module in self.dit.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
-        for key, module in self.text_encoder_1.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
+        self.dit.unload_loras()
+        self.text_encoder_1.unload_loras()
+        self.text_encoder_2.unload_loras()
 
     @classmethod
     def from_state_dict(

diffsynth_engine/pipelines/sd_image.py

Lines changed: 6 additions & 40 deletions

@@ -2,15 +2,14 @@
 import os
 import torch
 from dataclasses import dataclass
-from typing import Callable, Dict, Optional, List, Tuple
-from safetensors.torch import load_file
+from typing import Callable, Dict, Optional
 from tqdm import tqdm
 from PIL import Image
 
-from diffsynth_engine.models.base import LoRAStateDictConverter, split_suffix
-from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d
+from diffsynth_engine.models.base import split_suffix
+from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
-from diffsynth_engine.pipelines import BasePipeline
+from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.tokenizers import CLIPTokenizer
 from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
 from diffsynth_engine.algorithm.sampler import EulerSampler
@@ -275,42 +274,9 @@ def predict_noise(self, latents, timestep, prompt_emb):
         )
         return noise_pred
 
-    def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True):
-        self.load_loras([(path, scale)], fused, save_original_weight)
-
-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True):
-        for lora_path, lora_scale in lora_list:
-            state_dict = load_file(lora_path, device="cpu")
-            lora_state_dict = self.lora_converter.convert(state_dict)
-            for model_name, state_dict in lora_state_dict.items():
-                model = getattr(self, model_name)
-                for key, param in state_dict.items():
-                    module = model.get_submodule(key)
-                    if not isinstance(module, (LoRALinear, LoRAConv2d)):
-                        raise ValueError(f"Unsupported lora key: {key}")
-                    lora_args = {
-                        "name": key,
-                        "scale": lora_scale,
-                        "rank": param["rank"],
-                        "alpha": param["alpha"],
-                        "up": param["up"],
-                        "down": param["down"],
-                        "device": self.device,
-                        "dtype": self.dtype,
-                        "save_original_weight": save_original_weight,
-                    }
-                    if fused:
-                        module.add_frozen_lora(**lora_args)
-                    else:
-                        module.add_lora(**lora_args)
-
     def unload_loras(self):
-        for key, module in self.unet.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
-        for key, module in self.text_encoder.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
+        self.unet.unload_loras()
+        self.text_encoder.unload_loras()
 
     @torch.no_grad()
     def __call__(

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 8 additions & 43 deletions

@@ -1,13 +1,13 @@
 import os
 import re
 import torch
-from typing import Callable, Dict, List, Tuple, Optional
-from safetensors.torch import load_file
+from typing import Callable, Dict, Optional
 from tqdm import tqdm
 from PIL import Image
 from dataclasses import dataclass
-from diffsynth_engine.models.base import LoRAStateDictConverter, split_suffix
-from diffsynth_engine.models.basic.lora import LoRAContext, LoRALinear, LoRAConv2d
+
+from diffsynth_engine.models.base import split_suffix
+from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.models.basic.timestep import TemporalTimesteps
 from diffsynth_engine.models.sdxl import (
     SDXLTextEncoder,
@@ -17,7 +17,7 @@
     SDXLUNet,
     sdxl_unet_config,
 )
-from diffsynth_engine.pipelines import BasePipeline
+from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.tokenizers import CLIPTokenizer
 from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
 from diffsynth_engine.algorithm.sampler import EulerSampler
@@ -305,45 +305,10 @@ def predict_noise(self, latents, timestep, prompt_emb, add_text_embeds, add_time
         )
         return noise_pred
 
-    def load_lora(self, path: str, scale: float, fused: bool = False, save_original_weight: bool = True):
-        self.load_loras([(path, scale)], fused, save_original_weight)
-
-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = False, save_original_weight: bool = True):
-        for lora_path, lora_scale in lora_list:
-            state_dict = load_file(lora_path, device="cpu")
-            lora_state_dict = self.lora_converter.convert(state_dict)
-            for model_name, state_dict in lora_state_dict.items():
-                model = getattr(self, model_name)
-                for key, param in state_dict.items():
-                    module = model.get_submodule(key)
-                    if not isinstance(module, (LoRALinear, LoRAConv2d)):
-                        raise ValueError(f"Unsupported lora key: {key}")
-                    lora_args = {
-                        "name": key,
-                        "scale": lora_scale,
-                        "rank": param["rank"],
-                        "alpha": param["alpha"],
-                        "up": param["up"],
-                        "down": param["down"],
-                        "device": self.device,
-                        "dtype": self.dtype,
-                        "save_original_weight": save_original_weight,
-                    }
-                    if fused:
-                        module.add_frozen_lora(**lora_args)
-                    else:
-                        module.add_lora(**lora_args)
-
     def unload_loras(self):
-        for key, module in self.unet.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
-        for key, module in self.text_encoder.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
-        for key, module in self.text_encoder_2.named_modules():
-            if isinstance(module, (LoRALinear, LoRAConv2d)):
-                module.clear()
+        self.unet.unload_loras()
+        self.text_encoder.unload_loras()
+        self.text_encoder_2.unload_loras()
 
     @torch.no_grad()
     def __call__(
