[research_projects] add shortened flux training script with quantization #11743
The diff adds a new 338-line training script:

```python
import copy
import logging
import math
import os
from pathlib import Path
import shutil

import numpy as np
import pandas as pd
import torch
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
import diffusers
from diffusers import (
    AutoencoderKL, BitsAndBytesConfig, FlowMatchEulerDiscreteScheduler,
    FluxPipeline, FluxTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    cast_training_params, compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3, free_memory,
)
from diffusers.utils import convert_unet_state_dict_to_peft, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module


logger = get_logger(__name__)

class DreamBoothDataset(Dataset):
    def __init__(self, data_df_path, dataset_name, width, height, max_sequence_length=77):
        self.width, self.height, self.max_sequence_length = width, height, max_sequence_length
        self.data_df_path = Path(data_df_path)
        if not self.data_df_path.exists():
            raise ValueError("`data_df_path` doesn't exist.")

        dataset = load_dataset(dataset_name, split="train")
        self.instance_images = [sample["image"] for sample in dataset]
        self.image_hashes = [insecure_hashlib.sha256(img.tobytes()).hexdigest() for img in self.instance_images]
        self.pixel_values = self._apply_transforms()
        self.data_dict = self._map_embeddings()
        self._length = len(self.instance_images)
    def __len__(self):
        return self._length

    def __getitem__(self, index):
        idx = index % len(self.instance_images)
        hash_key = self.image_hashes[idx]
        prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[hash_key]
        return {
            "instance_images": self.pixel_values[idx],
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "text_ids": text_ids,
        }

    def _apply_transforms(self):
        transform = transforms.Compose([
            transforms.Resize((self.height, self.width), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop((self.height, self.width)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        pixel_values = []
        for image in self.instance_images:
            image = exif_transpose(image)
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixel_values.append(transform(image))
        return pixel_values
    def _map_embeddings(self):
        df = pd.read_parquet(self.data_df_path)
        data_dict = {}
        for _, row in df.iterrows():
            prompt_embeds = torch.from_numpy(np.array(row["prompt_embeds"]).reshape(self.max_sequence_length, 4096))
            pooled_prompt_embeds = torch.from_numpy(np.array(row["pooled_prompt_embeds"]).reshape(768))
            text_ids = torch.from_numpy(np.array(row["text_ids"]).reshape(77, 3))
            data_dict[row["image_hash"]] = (prompt_embeds, pooled_prompt_embeds, text_ids)
        return data_dict


def collate_fn(examples):
    pixel_values = torch.stack([ex["instance_images"] for ex in examples]).float()
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    prompt_embeds = torch.stack([ex["prompt_embeds"] for ex in examples])
    pooled_prompt_embeds = torch.stack([ex["pooled_prompt_embeds"] for ex in examples])
    text_ids = torch.stack([ex["text_ids"] for ex in examples])[0]

    return {
        "pixel_values": pixel_values,
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "text_ids": text_ids,
    }
def main(args):
    # Setup accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=ProjectConfiguration(project_dir=args.output_dir, logging_dir=Path(args.output_dir, "logs")),
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
    )

    # Setup logging
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    # Load models with quantization
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")

    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
    transformer = FluxTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer",
        quantization_config=nf4_config, torch_dtype=torch.float16,
    )
    transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

    # Freeze models and setup LoRA
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    vae.to(accelerator.device, dtype=torch.float16)
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    print(f"trainable params: {transformer.num_parameters(only_trainable=True)} || all params: {transformer.num_parameters()}")
    # Setup optimizer
    import bitsandbytes as bnb

    optimizer = bnb.optim.AdamW8bit(
        [{"params": list(filter(lambda p: p.requires_grad, transformer.parameters())), "lr": args.learning_rate}],
        betas=(0.9, 0.999), weight_decay=1e-04, eps=1e-08,
    )

    # Setup dataset and dataloader
    train_dataset = DreamBoothDataset(args.data_df_path, "derekl35/alphonse-mucha-style", args.width, args.height)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )
    # Cache latents
    vae_config = vae.config
    latents_cache = []
    for batch in tqdm(train_dataloader, desc="Caching latents"):
        with torch.no_grad():
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float16)
            latents_cache.append(vae.encode(pixel_values).latent_dist)

    del vae
    free_memory()
    # Setup scheduler and training steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    args.max_train_steps = args.max_train_steps or args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler("constant", optimizer=optimizer, num_warmup_steps=0, num_training_steps=args.max_train_steps)

    # Prepare for training
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler)

    # Register save/load hooks
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        return model._orig_mod if is_compiled_module(model) else model

    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                    lora_layers = get_peft_model_state_dict(unwrap_model(model))
                    FluxPipeline.save_lora_weights(output_dir, transformer_lora_layers=lora_layers, text_encoder_lora_layers=None)
                if weights:
                    weights.pop()

    accelerator.register_save_state_pre_hook(save_model_hook)
    if args.mixed_precision == "fp16":
        cast_training_params([transformer], dtype=torch.float32)
    # Initialize tracking
    accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args)) if accelerator.is_main_process else None
```

In review, one suggestion replaces the tracker-initialization one-liner with an explicit main-process block:

```python
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args))
```
Reviewer: Can we do it while creating the output folder?
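A minimal sketch of that rearrangement, assuming the `args.output_dir` handling stays as in the current diff; the tracker name is taken from the existing call, and the block would live inside `main()`:

```python
# Sketch: fold tracker initialization into the same main-process block that
# creates the output directory, rather than keeping two separate one-liners.
if accelerator.is_main_process:
    os.makedirs(args.output_dir, exist_ok=True)
    accelerator.init_trackers("dreambooth-flux-dev-lora-alphonse-mucha", config=vars(args))
```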
Reviewer (on a line using `args.weighting_scheme`): Let's also make constants for the magic numbers.
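The truncated diff doesn't show exactly which magic numbers the comment targets; as one example from the visible code, the hard-coded embedding shapes in `_map_embeddings` could become named constants. The constant names and the helper below are hypothetical, not from the PR:

```python
import numpy as np
import torch

# Hypothetical constant names (not from the PR) for the shapes hard-coded in
# DreamBoothDataset._map_embeddings.
T5_EMBED_DIM = 4096       # per-token width of the T5 prompt embeddings
CLIP_POOLED_DIM = 768     # width of the pooled CLIP prompt embedding
TEXT_IDS_SHAPE = (77, 3)  # Flux text ids: (max_sequence_length, 3)


def row_to_tensors(row, max_sequence_length=77):
    """Same reshapes as `_map_embeddings`, but without bare magic numbers."""
    prompt_embeds = torch.from_numpy(np.array(row["prompt_embeds"]).reshape(max_sequence_length, T5_EMBED_DIM))
    pooled_prompt_embeds = torch.from_numpy(np.array(row["pooled_prompt_embeds"]).reshape(CLIP_POOLED_DIM))
    text_ids = torch.from_numpy(np.array(row["text_ids"]).reshape(*TEXT_IDS_SHAPE))
    return prompt_embeds, pooled_prompt_embeds, text_ids
```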
Reviewer: We have the `weighting_scheme` arg, so let's use it from there.
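A hedged sketch of what that could look like, following the pattern in the full-length `train_dreambooth_lora_flux.py` script and the helpers already imported at the top of this file. The surrounding training-loop variables (`bsz`, `sigmas`, the model input device) and `args.logit_mean`/`args.logit_std`/`args.mode_scale` are assumed to exist as they do in the full script; they are not visible in the truncated diff:

```python
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)


def sample_timesteps_and_weighting(args, noise_scheduler_copy, bsz, sigmas, device):
    """Hypothetical helper: drive timestep sampling and loss weighting from args.weighting_scheme."""
    u = compute_density_for_timestep_sampling(
        weighting_scheme=args.weighting_scheme,
        batch_size=bsz,
        logit_mean=args.logit_mean,
        logit_std=args.logit_std,
        mode_scale=args.mode_scale,
    )
    indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
    timesteps = noise_scheduler_copy.timesteps[indices].to(device=device)
    weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
    return timesteps, weighting
```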
Reviewer: Add a note on where this is coming from.
Reviewer: We should cast the LoRA params to FP32. Do you have a full run with this script that works without FP32 upcasting?
Author: I think I was casting to FP32 below with `cast_training_params([transformer], dtype=torch.float32) if args.mixed_precision == "fp16" else None` (I will probably move it up here and change it to better match the original training script). I do have a full run with this script that gives reasonable results without FP32 upcasting.
However, I noticed the loss curves are slightly different between the nano script (rare-voice-24 run) and the original script (fanciful-totem-2), so I will need to find where the discrepancy is coming from.
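For reference, a sketch of the placement the author describes, assuming the cast moves up next to the adapter setup; `cast_training_params` only touches parameters that require grad, so the quantized base weights stay as they are:

```python
# Sketch: upcast the trainable LoRA parameters right after adding the adapter,
# mirroring the original training script.
transformer.add_adapter(transformer_lora_config)
if args.mixed_precision == "fp16":
    # Only parameters with requires_grad=True (the LoRA layers) are cast to FP32.
    cast_training_params([transformer], dtype=torch.float32)
```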
Reviewer: If it doesn't affect results, it's probably okay.