diff --git a/examples/flux-control/README.md b/examples/flux-control/README.md
index 26ad9d06a2af..14afa499db0d 100644
--- a/examples/flux-control/README.md
+++ b/examples/flux-control/README.md
@@ -121,7 +121,7 @@
 prompt = "A couple, 4k photo, highly detailed"
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     joint_attention_kwargs={"scale": 0.9},
     guidance_scale=25.,
@@ -190,7 +190,7 @@
 prompt = "A couple, 4k photo, highly detailed"
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     guidance_scale=25.,
 ).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
 ## Things to note

 * The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized, but if `--offload` is specified, we offload the VAE and the text encoders to CPU when they are not in use.
 * We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
\ No newline at end of file
diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 35f9a5f80342..7d0e28069054 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
                 images = log["images"]
                 validation_prompt = log["validation_prompt"]
                 validation_image = log["validation_image"]
-                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                 for image in images:
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
            img_str += f"![images_{i})](./images_{i}.png)\n"

    model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}

 These are Control weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -434,7 +433,7 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -442,6 +441,7 @@
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to the jsonl file containing the training data.",
     )
-
+    parser.add_argument(
+        "--only_target_transformer_blocks",
+        action="store_true",
+        help="Whether to train only the transformer blocks, along with the input layer (`x_embedder`).",
+    )
     parser.add_argument(
         "--guidance_scale",
         type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):

     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )

     return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])

         return examples

@@ -765,7 +774,8 @@ def main(args):
         subfolder="scheduler",
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    flux_transformer.requires_grad_(True)
+    if not args.only_target_transformer_blocks:
+        flux_transformer.requires_grad_(True)
     vae.requires_grad_(False)

     # cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

+    if args.only_target_transformer_blocks:
+        flux_transformer.x_embedder.requires_grad_(True)
+        for name, module in flux_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0

+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 99a05d54832f..44c684395849 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
                 images = log["images"]
                 validation_prompt = log["validation_prompt"]
                 validation_image = log["validation_image"]
-                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                 for image in images:
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
            img_str += f"![images_{i})](./images_{i}.png)\n"

    model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}

 These are Control LoRA weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-lora",
+        default="control-lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -466,7 +465,7 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
@@ -474,6 +473,7 @@
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):

     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )

     return args
@@ -697,7 +697,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])

         return examples

@@ -1132,6 +1137,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0

+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
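
For reference, the `--only_target_transformer_blocks` option added to `train_control_flux.py` limits training to the expanded input projection (`x_embedder`) and the transformer blocks. Below is a minimal sketch of that selection pattern, assuming a `FluxTransformer2DModel` loaded from the FLUX.1-dev checkpoint; the explicit freeze-then-unfreeze step is for illustration, while the script arrives at the same set of trainable modules through its own flow.

```python
import torch
from diffusers import FluxTransformer2DModel

# Assumed checkpoint; the script takes it from --pretrained_model_name_or_path.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Freeze everything, then re-enable gradients only for the parts the flag targets.
transformer.requires_grad_(False)
transformer.x_embedder.requires_grad_(True)
for name, module in transformer.named_modules():
    # The substring match covers both `transformer_blocks` and `single_transformer_blocks`.
    if "transformer_blocks" in name:
        module.requires_grad_(True)

trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
total = sum(p.numel() for p in transformer.parameters())
print(f"trainable params: {trainable / 1e6:.0f}M / {total / 1e6:.0f}M")
```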
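
Both training scripts now accept a caption column that holds either a single string or a list of candidate captions per sample, keeping the longest entry in the latter case. The selection rule in isolation, with made-up captions:

```python
def pick_caption(entry):
    # Mirrors the updated preprocess_train: lists keep their longest caption,
    # plain strings pass through unchanged.
    return max(entry, key=len) if isinstance(entry, list) else entry

print(pick_caption(["a couple", "a couple, 4k photo, highly detailed"]))  # longest caption wins
print(pick_caption("a couple, 4k photo"))                                 # unchanged
```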
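
The new `--log_dataset_samples` flag pushes a few conditioning/target pairs to Weights & Biases before training starts. A self-contained sketch of the same logging call, with dummy tensors standing in for a dataloader batch (the project name and tensor shapes are placeholders):

```python
import torch
import wandb

run = wandb.init(project="flux-control-demo")  # placeholder project name

# Stand-ins for batch["conditioning_pixel_values"] / batch["pixel_values"],
# which the scripts keep normalized to [-1, 1].
control_image = torch.rand(3, 64, 64) * 2 - 1
image = torch.rand(3, 64, 64) * 2 - 1
prompt = "A couple, 4k photo, highly detailed"

# Rescale to [0, 1] and log both images under a single key, as the scripts do.
logged_artifacts = [
    wandb.Image((control_image + 1) / 2, caption="Conditioning"),
    wandb.Image((image + 1) / 2, caption=prompt),
]
run.log({"dataset_samples": logged_artifacts})
run.finish()
```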
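
Finally, the README fix renames the inference argument to `control_image`. A sketch of the corrected call, assuming `FluxControlPipeline` and the `FLUX.1-Depth-dev` checkpoint as stand-ins for the user's fine-tuned weights; the README's own snippet loads the locally trained transformer instead.

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder path to the conditioning image.
control_image = load_image("conditioning.png")

gen_image = pipe(
    prompt="A couple, 4k photo, highly detailed",
    control_image=control_image,  # renamed from condition_image
    num_inference_steps=50,
    guidance_scale=25.0,          # follows the README snippet
).images[0]
gen_image.save("output.png")
```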