diff --git a/examples/research_projects/ip_adapter/README.md b/examples/research_projects/ip_adapter/README.md new file mode 100644 index 000000000000..04a6c86e5305 --- /dev/null +++ b/examples/research_projects/ip_adapter/README.md @@ -0,0 +1,226 @@ +# IP Adapter Training Example + +[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources. + +## Training locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the example folder and run + +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. a notebook + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +Certainly! Below is the documentation in pure Markdown format: + +### Accelerate Launch Command Documentation + +#### Description: +The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations. + +#### Usage Example: + +``` +accelerate launch --mixed_precision "fp16" \ +tutorial_train_ip-adapter.py \ +--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \ +--image_encoder_path="{image_encoder_path}" \ +--data_json_file="{data.json}" \ +--data_root_path="{image_path}" \ +--mixed_precision="fp16" \ +--resolution=512 \ +--train_batch_size=8 \ +--dataloader_num_workers=4 \ +--learning_rate=1e-04 \ +--weight_decay=0.01 \ +--output_dir="{output_dir}" \ +--save_steps=10000 +``` + +### Multi-GPU Script: +``` +accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \ + tutorial_train_ip-adapter.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \ + --image_encoder_path="{image_encoder_path}" \ + --data_json_file="{data.json}" \ + --data_root_path="{image_path}" \ + --mixed_precision="fp16" \ + --resolution=512 \ + --train_batch_size=8 \ + --dataloader_num_workers=4 \ + --learning_rate=1e-04 \ + --weight_decay=0.01 \ + --output_dir="{output_dir}" \ + --save_steps=10000 +``` + +#### Parameters: +- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes). +- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training. +- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision. +- `tutorial_train_ip-adapter.py`: Name of the training script to be executed. +- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model. +- `--image_encoder_path`: Path to the CLIP image encoder. +- `--data_json_file`: Path to the training data in JSON format. +- `--data_root_path`: Root path where training images are located. +- `--resolution`: Resolution of input images (512x512 in this example). +- `--train_batch_size`: Batch size for training data (8 in this example). +- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example). +- `--learning_rate`: Learning rate for training (1e-04 in this example). +- `--weight_decay`: Weight decay for regularization (0.01 in this example). +- `--output_dir`: Directory to save model checkpoints and predictions. +- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example). + +### Inference + +#### Description: +The provided inference code is used to load a trained model checkpoint and extract the components related to image projection and IP (Image Processing) adapter. These components are then saved into a binary file for later use in inference. + +#### Usage Example: +```python +from safetensors.torch import load_file, save_file + +# Load the trained model checkpoint in safetensors format +ckpt = "checkpoint-50000/pytorch_model.safetensors" +sd = load_file(ckpt) # Using safetensors load function + +# Extract image projection and IP adapter components +image_proj_sd = {} +ip_sd = {} + +for k in sd: + if k.startswith("unet"): + pass # Skip unet-related keys + elif k.startswith("image_proj_model"): + image_proj_sd[k.replace("image_proj_model.", "")] = sd[k] + elif k.startswith("adapter_modules"): + ip_sd[k.replace("adapter_modules.", "")] = sd[k] + +# Save the components into separate safetensors files +save_file(image_proj_sd, "image_proj.safetensors") +save_file(ip_sd, "ip_adapter.safetensors") +``` + +### Sample Inference Script using the CLIP Model + +```python + +import torch +from safetensors.torch import load_file +from transformers import CLIPProcessor, CLIPModel # Using the Hugging Face CLIP model + +# Load model components from safetensors +image_proj_ckpt = "image_proj.safetensors" +ip_adapter_ckpt = "ip_adapter.safetensors" + +# Load the saved weights +image_proj_sd = load_file(image_proj_ckpt) +ip_adapter_sd = load_file(ip_adapter_ckpt) + +# Define the model Parameters +class ImageProjectionModel(torch.nn.Module): + def __init__(self, input_dim=768, output_dim=512): # CLIP's default embedding size is 768 + super().__init__() + self.model = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.model(x) + +class IPAdapterModel(torch.nn.Module): + def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes + super().__init__() + self.model = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.model(x) + +# Initialize models +image_proj_model = ImageProjectionModel() +ip_adapter_model = IPAdapterModel() + +# Load weights into models +image_proj_model.load_state_dict(image_proj_sd) +ip_adapter_model.load_state_dict(ip_adapter_sd) + +# Set models to evaluation mode +image_proj_model.eval() +ip_adapter_model.eval() + +#Inference pipeline +def inference(image_tensor): + """ + Run inference using the loaded models. + + Args: + image_tensor: Preprocessed image tensor from CLIPProcessor + + Returns: + Final inference results + """ + with torch.no_grad(): + # Step 1: Project the image features + image_proj = image_proj_model(image_tensor) + + # Step 2: Pass the projected features through the IP Adapter + result = ip_adapter_model(image_proj) + + return result + +# Using CLIP for image preprocessing +processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") +clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + +#Image file path +image_path = "path/to/image.jpg" + +# Preprocess the image +inputs = processor(images=image_path, return_tensors="pt") +image_features = clip_model.get_image_features(inputs["pixel_values"]) + +# Normalize the image features as per CLIP's recommendations +image_features = image_features / image_features.norm(dim=-1, keepdim=True) + +# Run inference +output = inference(image_features) +print("Inference output:", output) +``` + +#### Parameters: +- `ckpt`: Path to the trained model checkpoint file. +- `map_location="cpu"`: Specifies that the model should be loaded onto the CPU. +- `image_proj_sd`: Dictionary to store the components related to image projection. +- `ip_sd`: Dictionary to store the components related to the IP adapter. +- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model. \ No newline at end of file diff --git a/examples/research_projects/ip_adapter/requirements.txt b/examples/research_projects/ip_adapter/requirements.txt new file mode 100644 index 000000000000..749aa795015d --- /dev/null +++ b/examples/research_projects/ip_adapter/requirements.txt @@ -0,0 +1,4 @@ +accelerate +torchvision +transformers>=4.25.1 +ip_adapter diff --git a/examples/research_projects/ip_adapter/tutorial_train_faceid.py b/examples/research_projects/ip_adapter/tutorial_train_faceid.py new file mode 100644 index 000000000000..3e337ec02f7f --- /dev/null +++ b/examples/research_projects/ip_adapter/tutorial_train_faceid.py @@ -0,0 +1,415 @@ +import argparse +import itertools +import json +import os +import random +import time +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from ip_adapter.ip_adapter_faceid import MLPProjModel +from PIL import Image +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + + +# Dataset +class MyDataset(torch.utils.data.Dataset): + def __init__( + self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path="" + ): + super().__init__() + + self.tokenizer = tokenizer + self.size = size + self.i_drop_rate = i_drop_rate + self.t_drop_rate = t_drop_rate + self.ti_drop_rate = ti_drop_rate + self.image_root_path = image_root_path + + self.data = json.load( + open(json_file) + ) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}] + + self.transform = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __getitem__(self, idx): + item = self.data[idx] + text = item["text"] + image_file = item["image_file"] + + # read image + raw_image = Image.open(os.path.join(self.image_root_path, image_file)) + image = self.transform(raw_image.convert("RGB")) + + face_id_embed = torch.load(item["id_embed_file"], map_location="cpu") + face_id_embed = torch.from_numpy(face_id_embed) + + # drop + drop_image_embed = 0 + rand_num = random.random() + if rand_num < self.i_drop_rate: + drop_image_embed = 1 + elif rand_num < (self.i_drop_rate + self.t_drop_rate): + text = "" + elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): + text = "" + drop_image_embed = 1 + if drop_image_embed: + face_id_embed = torch.zeros_like(face_id_embed) + # get text and tokenize + text_input_ids = self.tokenizer( + text, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + + return { + "image": image, + "text_input_ids": text_input_ids, + "face_id_embed": face_id_embed, + "drop_image_embed": drop_image_embed, + } + + def __len__(self): + return len(self.data) + + +def collate_fn(data): + images = torch.stack([example["image"] for example in data]) + text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) + face_id_embed = torch.stack([example["face_id_embed"] for example in data]) + drop_image_embeds = [example["drop_image_embed"] for example in data] + + return { + "images": images, + "text_input_ids": text_input_ids, + "face_id_embed": face_id_embed, + "drop_image_embeds": drop_image_embeds, + } + + +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj_model = image_proj_model + self.adapter_modules = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): + ip_tokens = self.image_proj_model(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + state_dict = torch.load(ckpt_path, map_location="cpu") + + # Load state dict for image_proj_model and adapter_modules + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" + + print(f"Successfully loaded weights from checkpoint {ckpt_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_ip_adapter_path", + type=str, + default=None, + help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", + ) + parser.add_argument( + "--data_json_file", + type=str, + default=None, + required=True, + help="Training data", + ) + parser.add_argument( + "--data_root_path", + type=str, + default="", + required=True, + help="Training data root path", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + required=True, + help="Path to CLIP image encoder", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-ip_adapter", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images"), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate to use.", + ) + parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--save_steps", + type=int, + default=2000, + help=("Save a checkpoint of the training state every X updates"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) + # freeze parameters of models to save more memory + unet.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + # image_encoder.requires_grad_(False) + + # ip-adapter + image_proj_model = MLPProjModel( + cross_attention_dim=unet.config.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=4, + ) + # init adapter modules + lora_rank = 128 + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank + ) + else: + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank + ) + attn_procs[name].load_state_dict(weights, strict=False) + unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + + ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + # unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + # image_encoder.to(accelerator.device, dtype=weight_dtype) + + # optimizer + params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) + optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) + + # dataloader + train_dataset = MyDataset( + args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Prepare everything with our `accelerator`. + ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) + + global_step = 0 + for epoch in range(0, args.num_train_epochs): + begin = time.perf_counter() + for step, batch in enumerate(train_dataloader): + load_data_time = time.perf_counter() - begin + with accelerator.accumulate(ip_adapter): + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode( + batch["images"].to(accelerator.device, dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype) + + with torch.no_grad(): + encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] + + noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) + + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() + + # Backpropagate + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + if accelerator.is_main_process: + print( + "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( + epoch, step, load_data_time, time.perf_counter() - begin, avg_loss + ) + ) + + global_step += 1 + + if global_step % args.save_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + begin = time.perf_counter() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py b/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py new file mode 100644 index 000000000000..9a3513f4c549 --- /dev/null +++ b/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py @@ -0,0 +1,422 @@ +import argparse +import itertools +import json +import os +import random +import time +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from ip_adapter.ip_adapter import ImageProjModel +from ip_adapter.utils import is_torch2_available +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + + +if is_torch2_available(): + from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor + from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor +else: + from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor + + +# Dataset +class MyDataset(torch.utils.data.Dataset): + def __init__( + self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path="" + ): + super().__init__() + + self.tokenizer = tokenizer + self.size = size + self.i_drop_rate = i_drop_rate + self.t_drop_rate = t_drop_rate + self.ti_drop_rate = ti_drop_rate + self.image_root_path = image_root_path + + self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] + + self.transform = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + self.clip_image_processor = CLIPImageProcessor() + + def __getitem__(self, idx): + item = self.data[idx] + text = item["text"] + image_file = item["image_file"] + + # read image + raw_image = Image.open(os.path.join(self.image_root_path, image_file)) + image = self.transform(raw_image.convert("RGB")) + clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values + + # drop + drop_image_embed = 0 + rand_num = random.random() + if rand_num < self.i_drop_rate: + drop_image_embed = 1 + elif rand_num < (self.i_drop_rate + self.t_drop_rate): + text = "" + elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): + text = "" + drop_image_embed = 1 + # get text and tokenize + text_input_ids = self.tokenizer( + text, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + + return { + "image": image, + "text_input_ids": text_input_ids, + "clip_image": clip_image, + "drop_image_embed": drop_image_embed, + } + + def __len__(self): + return len(self.data) + + +def collate_fn(data): + images = torch.stack([example["image"] for example in data]) + text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) + clip_images = torch.cat([example["clip_image"] for example in data], dim=0) + drop_image_embeds = [example["drop_image_embed"] for example in data] + + return { + "images": images, + "text_input_ids": text_input_ids, + "clip_images": clip_images, + "drop_image_embeds": drop_image_embeds, + } + + +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj_model = image_proj_model + self.adapter_modules = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): + ip_tokens = self.image_proj_model(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + state_dict = torch.load(ckpt_path, map_location="cpu") + + # Load state dict for image_proj_model and adapter_modules + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" + + print(f"Successfully loaded weights from checkpoint {ckpt_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_ip_adapter_path", + type=str, + default=None, + help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", + ) + parser.add_argument( + "--data_json_file", + type=str, + default=None, + required=True, + help="Training data", + ) + parser.add_argument( + "--data_root_path", + type=str, + default="", + required=True, + help="Training data root path", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + required=True, + help="Path to CLIP image encoder", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-ip_adapter", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images"), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate to use.", + ) + parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--save_steps", + type=int, + default=2000, + help=("Save a checkpoint of the training state every X updates"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) + # freeze parameters of models to save more memory + unet.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + image_encoder.requires_grad_(False) + + # ip-adapter + image_proj_model = ImageProjModel( + cross_attention_dim=unet.config.cross_attention_dim, + clip_embeddings_dim=image_encoder.config.projection_dim, + clip_extra_context_tokens=4, + ) + # init adapter modules + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + attn_procs[name].load_state_dict(weights) + unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + + ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + # unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + + # optimizer + params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) + optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) + + # dataloader + train_dataset = MyDataset( + args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Prepare everything with our `accelerator`. + ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) + + global_step = 0 + for epoch in range(0, args.num_train_epochs): + begin = time.perf_counter() + for step, batch in enumerate(train_dataloader): + load_data_time = time.perf_counter() - begin + with accelerator.accumulate(ip_adapter): + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode( + batch["images"].to(accelerator.device, dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with torch.no_grad(): + image_embeds = image_encoder( + batch["clip_images"].to(accelerator.device, dtype=weight_dtype) + ).image_embeds + image_embeds_ = [] + for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]): + if drop_image_embed == 1: + image_embeds_.append(torch.zeros_like(image_embed)) + else: + image_embeds_.append(image_embed) + image_embeds = torch.stack(image_embeds_) + + with torch.no_grad(): + encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] + + noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) + + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() + + # Backpropagate + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + if accelerator.is_main_process: + print( + "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( + epoch, step, load_data_time, time.perf_counter() - begin, avg_loss + ) + ) + + global_step += 1 + + if global_step % args.save_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + begin = time.perf_counter() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/ip_adapter/tutorial_train_plus.py b/examples/research_projects/ip_adapter/tutorial_train_plus.py new file mode 100644 index 000000000000..e777ea1f0047 --- /dev/null +++ b/examples/research_projects/ip_adapter/tutorial_train_plus.py @@ -0,0 +1,445 @@ +import argparse +import itertools +import json +import os +import random +import time +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from ip_adapter.resampler import Resampler +from ip_adapter.utils import is_torch2_available +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + + +if is_torch2_available(): + from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor + from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor +else: + from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor + + +# Dataset +class MyDataset(torch.utils.data.Dataset): + def __init__( + self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path="" + ): + super().__init__() + + self.tokenizer = tokenizer + self.size = size + self.i_drop_rate = i_drop_rate + self.t_drop_rate = t_drop_rate + self.ti_drop_rate = ti_drop_rate + self.image_root_path = image_root_path + + self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] + + self.transform = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + self.clip_image_processor = CLIPImageProcessor() + + def __getitem__(self, idx): + item = self.data[idx] + text = item["text"] + image_file = item["image_file"] + + # read image + raw_image = Image.open(os.path.join(self.image_root_path, image_file)) + image = self.transform(raw_image.convert("RGB")) + clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values + + # drop + drop_image_embed = 0 + rand_num = random.random() + if rand_num < self.i_drop_rate: + drop_image_embed = 1 + elif rand_num < (self.i_drop_rate + self.t_drop_rate): + text = "" + elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): + text = "" + drop_image_embed = 1 + # get text and tokenize + text_input_ids = self.tokenizer( + text, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + + return { + "image": image, + "text_input_ids": text_input_ids, + "clip_image": clip_image, + "drop_image_embed": drop_image_embed, + } + + def __len__(self): + return len(self.data) + + +def collate_fn(data): + images = torch.stack([example["image"] for example in data]) + text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) + clip_images = torch.cat([example["clip_image"] for example in data], dim=0) + drop_image_embeds = [example["drop_image_embed"] for example in data] + + return { + "images": images, + "text_input_ids": text_input_ids, + "clip_images": clip_images, + "drop_image_embeds": drop_image_embeds, + } + + +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj_model = image_proj_model + self.adapter_modules = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): + ip_tokens = self.image_proj_model(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + state_dict = torch.load(ckpt_path, map_location="cpu") + + # Check if 'latents' exists in both the saved state_dict and the current model's state_dict + strict_load_image_proj_model = True + if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict(): + # Check if the shapes are mismatched + if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape: + print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.") + print("Removing 'latents' from checkpoint and loading the rest of the weights.") + del state_dict["image_proj"]["latents"] + strict_load_image_proj_model = False + + # Load state dict for image_proj_model and adapter_modules + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" + + print(f"Successfully loaded weights from checkpoint {ckpt_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_ip_adapter_path", + type=str, + default=None, + help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", + ) + parser.add_argument( + "--num_tokens", + type=int, + default=16, + help="Number of tokens to query from the CLIP image encoding.", + ) + parser.add_argument( + "--data_json_file", + type=str, + default=None, + required=True, + help="Training data", + ) + parser.add_argument( + "--data_root_path", + type=str, + default="", + required=True, + help="Training data root path", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + required=True, + help="Path to CLIP image encoder", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-ip_adapter", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images"), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate to use.", + ) + parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--save_steps", + type=int, + default=2000, + help=("Save a checkpoint of the training state every X updates"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) + # freeze parameters of models to save more memory + unet.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + image_encoder.requires_grad_(False) + + # ip-adapter-plus + image_proj_model = Resampler( + dim=unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=args.num_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4, + ) + # init adapter modules + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens + ) + attn_procs[name].load_state_dict(weights) + unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + + ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + # unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + + # optimizer + params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) + optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) + + # dataloader + train_dataset = MyDataset( + args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Prepare everything with our `accelerator`. + ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) + + global_step = 0 + for epoch in range(0, args.num_train_epochs): + begin = time.perf_counter() + for step, batch in enumerate(train_dataloader): + load_data_time = time.perf_counter() - begin + with accelerator.accumulate(ip_adapter): + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode( + batch["images"].to(accelerator.device, dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + clip_images = [] + for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]): + if drop_image_embed == 1: + clip_images.append(torch.zeros_like(clip_image)) + else: + clip_images.append(clip_image) + clip_images = torch.stack(clip_images, dim=0) + with torch.no_grad(): + image_embeds = image_encoder( + clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True + ).hidden_states[-2] + + with torch.no_grad(): + encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] + + noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds) + + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() + + # Backpropagate + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + if accelerator.is_main_process: + print( + "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( + epoch, step, load_data_time, time.perf_counter() - begin, avg_loss + ) + ) + + global_step += 1 + + if global_step % args.save_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + begin = time.perf_counter() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/ip_adapter/tutorial_train_sdxl.py b/examples/research_projects/ip_adapter/tutorial_train_sdxl.py new file mode 100644 index 000000000000..cd7dffe13a80 --- /dev/null +++ b/examples/research_projects/ip_adapter/tutorial_train_sdxl.py @@ -0,0 +1,520 @@ +import argparse +import itertools +import json +import os +import random +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration +from ip_adapter.ip_adapter import ImageProjModel +from ip_adapter.utils import is_torch2_available +from PIL import Image +from torchvision import transforms +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + + +if is_torch2_available(): + from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor + from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor +else: + from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor + + +# Dataset +class MyDataset(torch.utils.data.Dataset): + def __init__( + self, + json_file, + tokenizer, + tokenizer_2, + size=1024, + center_crop=True, + t_drop_rate=0.05, + i_drop_rate=0.05, + ti_drop_rate=0.05, + image_root_path="", + ): + super().__init__() + + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + self.size = size + self.center_crop = center_crop + self.i_drop_rate = i_drop_rate + self.t_drop_rate = t_drop_rate + self.ti_drop_rate = ti_drop_rate + self.image_root_path = image_root_path + + self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}] + + self.transform = transforms.Compose( + [ + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.clip_image_processor = CLIPImageProcessor() + + def __getitem__(self, idx): + item = self.data[idx] + text = item["text"] + image_file = item["image_file"] + + # read image + raw_image = Image.open(os.path.join(self.image_root_path, image_file)) + + # original size + original_width, original_height = raw_image.size + original_size = torch.tensor([original_height, original_width]) + + image_tensor = self.transform(raw_image.convert("RGB")) + # random crop + delta_h = image_tensor.shape[1] - self.size + delta_w = image_tensor.shape[2] - self.size + assert not all([delta_h, delta_w]) + + if self.center_crop: + top = delta_h // 2 + left = delta_w // 2 + else: + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size) + crop_coords_top_left = torch.tensor([top, left]) + + clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values + + # drop + drop_image_embed = 0 + rand_num = random.random() + if rand_num < self.i_drop_rate: + drop_image_embed = 1 + elif rand_num < (self.i_drop_rate + self.t_drop_rate): + text = "" + elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): + text = "" + drop_image_embed = 1 + + # get text and tokenize + text_input_ids = self.tokenizer( + text, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + + text_input_ids_2 = self.tokenizer_2( + text, + max_length=self.tokenizer_2.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + + return { + "image": image, + "text_input_ids": text_input_ids, + "text_input_ids_2": text_input_ids_2, + "clip_image": clip_image, + "drop_image_embed": drop_image_embed, + "original_size": original_size, + "crop_coords_top_left": crop_coords_top_left, + "target_size": torch.tensor([self.size, self.size]), + } + + def __len__(self): + return len(self.data) + + +def collate_fn(data): + images = torch.stack([example["image"] for example in data]) + text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) + text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0) + clip_images = torch.cat([example["clip_image"] for example in data], dim=0) + drop_image_embeds = [example["drop_image_embed"] for example in data] + original_size = torch.stack([example["original_size"] for example in data]) + crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data]) + target_size = torch.stack([example["target_size"] for example in data]) + + return { + "images": images, + "text_input_ids": text_input_ids, + "text_input_ids_2": text_input_ids_2, + "clip_images": clip_images, + "drop_image_embeds": drop_image_embeds, + "original_size": original_size, + "crop_coords_top_left": crop_coords_top_left, + "target_size": target_size, + } + + +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj_model = image_proj_model + self.adapter_modules = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds): + ip_tokens = self.image_proj_model(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet( + noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs + ).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + state_dict = torch.load(ckpt_path, map_location="cpu") + + # Load state dict for image_proj_model and adapter_modules + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" + + print(f"Successfully loaded weights from checkpoint {ckpt_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_ip_adapter_path", + type=str, + default=None, + help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", + ) + parser.add_argument( + "--data_json_file", + type=str, + default=None, + required=True, + help="Training data", + ) + parser.add_argument( + "--data_root_path", + type=str, + default="", + required=True, + help="Training data root path", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + required=True, + help="Path to CLIP image encoder", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-ip_adapter", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=("The resolution for input images"), + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate to use.", + ) + parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--noise_offset", type=float, default=None, help="noise offset") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--save_steps", + type=int, + default=2000, + help=("Save a checkpoint of the training state every X updates"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2" + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) + # freeze parameters of models to save more memory + unet.requires_grad_(False) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + text_encoder_2.requires_grad_(False) + image_encoder.requires_grad_(False) + + # ip-adapter + num_tokens = 4 + image_proj_model = ImageProjModel( + cross_attention_dim=unet.config.cross_attention_dim, + clip_embeddings_dim=image_encoder.config.projection_dim, + clip_extra_context_tokens=num_tokens, + ) + # init adapter modules + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens + ) + attn_procs[name].load_state_dict(weights) + unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + + ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + # unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device) # use fp32 + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + + # optimizer + params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters()) + optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) + + # dataloader + train_dataset = MyDataset( + args.data_json_file, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + size=args.resolution, + image_root_path=args.data_root_path, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Prepare everything with our `accelerator`. + ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) + + global_step = 0 + for epoch in range(0, args.num_train_epochs): + begin = time.perf_counter() + for step, batch in enumerate(train_dataloader): + load_data_time = time.perf_counter() - begin + with accelerator.accumulate(ip_adapter): + # Convert images to latent space + with torch.no_grad(): + # vae of sdxl should use fp32 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + latents = latents.to(accelerator.device, dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to( + accelerator.device, dtype=weight_dtype + ) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with torch.no_grad(): + image_embeds = image_encoder( + batch["clip_images"].to(accelerator.device, dtype=weight_dtype) + ).image_embeds + image_embeds_ = [] + for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]): + if drop_image_embed == 1: + image_embeds_.append(torch.zeros_like(image_embed)) + else: + image_embeds_.append(image_embed) + image_embeds = torch.stack(image_embeds_) + + with torch.no_grad(): + encoder_output = text_encoder( + batch["text_input_ids"].to(accelerator.device), output_hidden_states=True + ) + text_embeds = encoder_output.hidden_states[-2] + encoder_output_2 = text_encoder_2( + batch["text_input_ids_2"].to(accelerator.device), output_hidden_states=True + ) + pooled_text_embeds = encoder_output_2[0] + text_embeds_2 = encoder_output_2.hidden_states[-2] + text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat + + # add cond + add_time_ids = [ + batch["original_size"].to(accelerator.device), + batch["crop_coords_top_left"].to(accelerator.device), + batch["target_size"].to(accelerator.device), + ] + add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype) + unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids} + + noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds) + + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() + + # Backpropagate + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + if accelerator.is_main_process: + print( + "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format( + epoch, step, load_data_time, time.perf_counter() - begin, avg_loss + ) + ) + + global_step += 1 + + if global_step % args.save_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + begin = time.perf_counter() + + +if __name__ == "__main__": + main()