|
-from diffsynth import ModelManager, KolorsImagePipeline
+from diffsynth import KolorsImagePipeline, load_state_dict, ChatGLMModel, SDXLUNet, SDXLVAEEncoder
 from peft import LoraConfig, inject_adapter_in_model
 from torchvision import transforms
 from PIL import Image
@@ -40,23 +40,40 @@ def __len__(self):



+def load_model_from_diffsynth(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
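+    # Instantiate the DiffSynth model class, then convert the diffusers-format checkpoint
+    # with the model's state_dict_converter before loading it.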
+    model = ModelClass(**model_kwargs).to(dtype=torch_dtype, device=device)
+    state_dict = load_state_dict(state_dict_path, torch_dtype=torch_dtype)
+    model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
+    return model
+
+
+def load_model_from_transformers(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
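+    # from_pretrained reads the config from state_dict_path; model_kwargs is unused here
+    # and kept only for signature parity with load_model_from_diffsynth.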
+    model = ModelClass.from_pretrained(state_dict_path, torch_dtype=torch_dtype)
+    model = model.to(dtype=torch_dtype, device=device)
+    return model
+
+
+
 class LightningModel(pl.LightningModule):
-    def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
+    def __init__(
+        self,
+        pretrained_unet_path, pretrained_text_encoder_path, pretrained_fp16_vae_path,
+        torch_dtype=torch.float16, learning_rate=1e-4, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True
+    ):
         super().__init__()

         # Load models
-        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
-        model_manager.load_models(pretrained_weights)
-        self.pipe = KolorsImagePipeline.from_model_manager(model_manager)
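+        # Assemble the pipeline by hand instead of via ModelManager: attach each
+        # sub-model loaded directly from its own checkpoint.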
+        self.pipe = KolorsImagePipeline(device=self.device, torch_dtype=torch_dtype)
+        self.pipe.text_encoder = load_model_from_transformers(ChatGLMModel, {}, pretrained_text_encoder_path, torch_dtype, self.device)
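+        # is_kolors=True selects the Kolors variant of the SDXL UNet (ChatGLM-based text conditioning).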
+        self.pipe.unet = load_model_from_diffsynth(SDXLUNet, {"is_kolors": True}, pretrained_unet_path, torch_dtype, self.device)
+        self.pipe.vae_encoder = load_model_from_diffsynth(SDXLVAEEncoder, {}, pretrained_fp16_vae_path, torch_dtype, self.device)
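+        # No VAE decoder is attached: training only needs the encoder to map images to latents.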

         # Freeze parameters
         self.pipe.text_encoder.requires_grad_(False)
         self.pipe.unet.requires_grad_(False)
-        self.pipe.vae_decoder.requires_grad_(False)
         self.pipe.vae_encoder.requires_grad_(False)
         self.pipe.text_encoder.eval()
         self.pipe.unet.train()
-        self.pipe.vae_decoder.eval()
         self.pipe.vae_encoder.eval()

         # Add LoRA to UNet
@@ -88,7 +105,7 @@ def training_step(self, batch, batch_idx):
             self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
         )
         height, width = image.shape[-2:]
-        latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
+        latents = self.pipe.vae_encoder(image.to(self.device))
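+        # The fp16-fix VAE encodes stably in half precision, so the former float32 round-trip is unnecessary.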
         noise = torch.randn_like(latents)
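+        # Kolors' scheduler is trained with 1100 diffusion timesteps (SDXL uses 1000), hence the 1100 below.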
         timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
         add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
@@ -126,11 +143,25 @@ def on_save_checkpoint(self, checkpoint):
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
-        "--pretrained_path",
+        "--pretrained_unet_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.",
+    )
+    parser.add_argument(
+        "--pretrained_text_encoder_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.",
+    )
+    parser.add_argument(
+        "--pretrained_fp16_vae_path",
         type=str,
         default=None,
         required=True,
-        help="Path to pretrained model. For example, `models/kolors/Kolors`.",
+        help="Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.",
     )
     parser.add_argument(
         "--dataset_path",
@@ -267,11 +298,9 @@ def parse_args():

     # model
     model = LightningModel(
-        pretrained_weights=[
-            os.path.join(args.pretrained_path, "text_encoder"),
-            os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
-            os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
-        ],
+        args.pretrained_unet_path,
+        args.pretrained_text_encoder_path,
+        args.pretrained_fp16_vae_path,
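+        # Positional order matches LightningModel.__init__: UNet path, text encoder path, fp16 VAE path.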
         torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
         learning_rate=args.learning_rate,
         lora_rank=args.lora_rank,
|