Commit 9471bff

Merge pull request #107 from modelscope/Artiprocher-dev
reduce VRAM requirements in Kolors LoRA
2 parents 9c6607f + 3f8eea4 commit 9471bff

3 files changed: +79 −36 lines

diffsynth/models/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -174,6 +174,9 @@
         ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
         ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
     ],
+    "SDXL-vae-fp16-fix": [
+        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+    ],
 }
 Preset_model_id: TypeAlias = Literal[
     "HunyuanDiT",
@@ -201,6 +204,7 @@
     "StableDiffusion3",
     "StableDiffusion3_without_T5",
     "Kolors",
+    "SDXL-vae-fp16-fix",
 ]
 Preset_model_website: TypeAlias = Literal[
     "HuggingFace",

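For orientation: each preset entry maps a preset id to (repo id, file path inside the repo, local target directory) tuples, which is how the downloader knows where to place the fp16-fix VAE. A minimal sketch of consuming such a table, assuming a hypothetical `plan_downloads` helper (not DiffSynth API):

```python
from typing import Dict, List, Tuple

# Shape of the preset table this hunk extends:
# preset id -> list of (repo id, file path in repo, local target directory).
PRESETS: Dict[str, List[Tuple[str, str, str]]] = {
    "SDXL-vae-fp16-fix": [
        ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix"),
    ],
}

def plan_downloads(preset_ids: List[str]) -> List[str]:
    # Hypothetical helper: flatten preset ids into "repo/file -> dir" jobs.
    jobs = []
    for preset_id in preset_ids:
        for repo_id, file_path, local_dir in PRESETS[preset_id]:
            jobs.append(f"{repo_id}/{file_path} -> {local_dir}")
    return jobs

print(plan_downloads(["SDXL-vae-fp16-fix"]))
```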
examples/train/kolors/README.md

Lines changed: 31 additions & 21 deletions

@@ -4,23 +4,27 @@ Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffus
 
 ## Download models
 
-The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
+The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)).
 
 ```
-models/kolors/Kolors
-├── text_encoder
-│   ├── config.json
-│   ├── pytorch_model-00001-of-00007.bin
-│   ├── pytorch_model-00002-of-00007.bin
-│   ├── pytorch_model-00003-of-00007.bin
-│   ├── pytorch_model-00004-of-00007.bin
-│   ├── pytorch_model-00005-of-00007.bin
-│   ├── pytorch_model-00006-of-00007.bin
-│   ├── pytorch_model-00007-of-00007.bin
-│   └── pytorch_model.bin.index.json
-├── unet
-│   └── diffusion_pytorch_model.safetensors
-└── vae
+models
+├── kolors
+│   └── Kolors
+│       ├── text_encoder
+│       │   ├── config.json
+│       │   ├── pytorch_model-00001-of-00007.bin
+│       │   ├── pytorch_model-00002-of-00007.bin
+│       │   ├── pytorch_model-00003-of-00007.bin
+│       │   ├── pytorch_model-00004-of-00007.bin
+│       │   ├── pytorch_model-00005-of-00007.bin
+│       │   ├── pytorch_model-00006-of-00007.bin
+│       │   ├── pytorch_model-00007-of-00007.bin
+│       │   └── pytorch_model.bin.index.json
+│       ├── unet
+│       │   └── diffusion_pytorch_model.safetensors
+│       └── vae
+│           └── diffusion_pytorch_model.safetensors
+└── sdxl-vae-fp16-fix
     └── diffusion_pytorch_model.safetensors
 ```

@@ -29,7 +33,7 @@ You can use the following code to download these files:
 ```python
 from diffsynth import download_models
 
-download_models(["Kolors"])
+download_models(["Kolors", "SDXL-vae-fp16-fix"])
 ```
 
 ## Train
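Aside: once both presets are downloaded, a quick way to confirm the files load is the `ModelManager` route used by the pre-change training script further down this page. A hedged smoke test, assuming the paths from the tree above:

```python
import torch
from diffsynth import ModelManager, KolorsImagePipeline

# These ModelManager/pipeline calls appear verbatim in the pre-change
# training script shown later on this page.
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/kolors/Kolors/text_encoder",
    "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
    "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
])
pipe = KolorsImagePipeline.from_model_manager(model_manager)
```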
@@ -70,24 +74,30 @@ file_name,text
 
 We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
 
-The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
+The following settings are recommended. 22GB VRAM is required.
 
 ```
 CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
-  --pretrained_path models/kolors/Kolors \
+  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
+  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
+  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
   --dataset_path data/dog \
   --output_path ./models \
  --max_epochs 10 \
   --center_crop \
   --use_gradient_checkpointing \
-  --precision 32
+  --precision "16-mixed"
 ```
 
 Optional arguments:
 ```
   -h, --help            show this help message and exit
-  --pretrained_path PRETRAINED_PATH
-                        Path to pretrained model. For example, `models/kolors/Kolors`.
+  --pretrained_unet_path PRETRAINED_UNET_PATH
+                        Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.
+  --pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH
+                        Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.
+  --pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH
+                        Path to pretrained model (VAE). For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.
   --dataset_path DATASET_PATH
                         The path of the Dataset.
   --output_path OUTPUT_PATH
examples/train/kolors/train_kolors_lora.py

Lines changed: 44 additions & 15 deletions

@@ -1,4 +1,4 @@
-from diffsynth import ModelManager, KolorsImagePipeline
+from diffsynth import KolorsImagePipeline, load_state_dict, ChatGLMModel, SDXLUNet, SDXLVAEEncoder
 from peft import LoraConfig, inject_adapter_in_model
 from torchvision import transforms
 from PIL import Image
@@ -40,23 +40,40 @@ def __len__(self):
 
 
 
+def load_model_from_diffsynth(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
+    model = ModelClass(**model_kwargs).to(dtype=torch_dtype, device=device)
+    state_dict = load_state_dict(state_dict_path, torch_dtype=torch_dtype)
+    model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
+    return model
+
+
+def load_model_from_transformers(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
+    model = ModelClass.from_pretrained(state_dict_path, torch_dtype=torch_dtype)
+    model = model.to(dtype=torch_dtype, device=device)
+    return model
+
+
 class LightningModel(pl.LightningModule):
-    def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
+    def __init__(
+        self,
+        pretrained_unet_path, pretrained_text_encoder_path, pretrained_fp16_vae_path,
+        torch_dtype=torch.float16, learning_rate=1e-4, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True
+    ):
         super().__init__()
 
         # Load models
-        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
-        model_manager.load_models(pretrained_weights)
-        self.pipe = KolorsImagePipeline.from_model_manager(model_manager)
+        self.pipe = KolorsImagePipeline(device=self.device, torch_dtype=torch_dtype)
+        self.pipe.text_encoder = load_model_from_transformers(ChatGLMModel, {}, pretrained_text_encoder_path, torch_dtype, self.device)
+        self.pipe.unet = load_model_from_diffsynth(SDXLUNet, {"is_kolors": True}, pretrained_unet_path, torch_dtype, self.device)
+        self.pipe.vae_encoder = load_model_from_diffsynth(SDXLVAEEncoder, {}, pretrained_fp16_vae_path, torch_dtype, self.device)
 
         # Freeze parameters
         self.pipe.text_encoder.requires_grad_(False)
         self.pipe.unet.requires_grad_(False)
-        self.pipe.vae_decoder.requires_grad_(False)
         self.pipe.vae_encoder.requires_grad_(False)
         self.pipe.text_encoder.eval()
         self.pipe.unet.train()
-        self.pipe.vae_decoder.eval()
         self.pipe.vae_encoder.eval()
 
         # Add LoRA to UNet
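The hunk ends where LoRA is attached to the UNet; that code is not part of this diff. Below is a hedged sketch of what `inject_adapter_in_model` from peft (imported at the top of the script) typically looks like with the script's `lora_rank`/`lora_alpha` defaults; the `target_modules` names are illustrative assumptions, not taken from the script:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

# Toy stand-in for the UNet; the real script injects into self.pipe.unet.
model = torch.nn.Sequential()
model.add_module("to_q", torch.nn.Linear(16, 16))
model.add_module("to_k", torch.nn.Linear(16, 16))

lora_config = LoraConfig(
    r=4,           # matches the script's lora_rank default
    lora_alpha=4,  # matches the script's lora_alpha default
    target_modules=["to_q", "to_k"],  # illustrative; the script's real targets are not in this diff
)
model = inject_adapter_in_model(lora_config, model)

# After injection, only the LoRA A/B matrices are left trainable.
print([name for name, p in model.named_parameters() if p.requires_grad])
```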
@@ -88,7 +105,7 @@ def training_step(self, batch, batch_idx):
             self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
         )
         height, width = image.shape[-2:]
-        latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
+        latents = self.pipe.vae_encoder(image.to(self.device))
         noise = torch.randn_like(latents)
         timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
         add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
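Note on this hunk: the deleted line upcast pixels to fp32 because the stock SDXL VAE overflows in float16; the fp16-fix VAE loaded in `__init__` is modified to stay in range at half precision, so the encode now runs directly in the pipeline dtype and the extra fp32 activations (part of the old 40GB peak) go away.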
@@ -126,11 +143,25 @@ def on_save_checkpoint(self, checkpoint):
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
-        "--pretrained_path",
+        "--pretrained_unet_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.",
+    )
+    parser.add_argument(
+        "--pretrained_text_encoder_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.",
+    )
+    parser.add_argument(
+        "--pretrained_fp16_vae_path",
         type=str,
         default=None,
         required=True,
-        help="Path to pretrained model. For example, `models/kolors/Kolors`.",
+        help="Path to pretrained model (VAE). For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.",
     )
     parser.add_argument(
         "--dataset_path",
@@ -267,11 +298,9 @@ def parse_args():
 
 # model
 model = LightningModel(
-    pretrained_weights=[
-        os.path.join(args.pretrained_path, "text_encoder"),
-        os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
-        os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
-    ],
+    args.pretrained_unet_path,
+    args.pretrained_text_encoder_path,
+    args.pretrained_fp16_vae_path,
     torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
     learning_rate=args.learning_rate,
     lora_rank=args.lora_rank,
