Commit 63d5e9f

Merge branch 'main' into original-lora-hunyuan-video
2 parents f682d76 + f1e0c7c commit 63d5e9f

File tree

5 files changed: +356 additions, -145 deletions

examples/flux-control/README.md

Lines changed: 3 additions & 3 deletions

@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     joint_attention_kwargs={"scale": 0.9},
     guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
 
 gen_images = pipe(
     prompt=prompt,
-    condition_image=image,
+    control_image=image,
     num_inference_steps=50,
     guidance_scale=25.,
 ).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
 ## Things to note
 
 * The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
 * We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
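
For quick reference, the pipeline call in this README now passes the conditioning image via `control_image` rather than `condition_image`. A minimal end-to-end sketch of the updated usage (checkpoint paths are illustrative, and it assumes the `FluxControlPipeline`/`FluxTransformer2DModel` loading pattern used elsewhere in that README):

```python
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils import load_image

# Illustrative paths: base Flux model plus a transformer fine-tuned with train_control_flux.py.
transformer = FluxTransformer2DModel.from_pretrained(
    "path/to/trained/transformer", torch_dtype=torch.bfloat16
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A couple, 4k photo, highly detailed"
image = load_image("path/to/conditioning.png")  # the control/conditioning image

gen_images = pipe(
    prompt=prompt,
    control_image=image,  # renamed from `condition_image`
    num_inference_steps=50,
    guidance_scale=25.,
).images[0]
gen_images.save("output.png")
```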

examples/flux-control/train_control_flux.py

Lines changed: 51 additions & 9 deletions

@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
                 images = log["images"]
                 validation_prompt = log["validation_prompt"]
                 validation_image = log["validation_image"]
-                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                 for image in images:
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
             img_str += f"![images_{i})](./images_{i}.png)\n"
 
     model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}
 
 These are Control weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
         type=str,
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -468,7 +468,7 @@
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to the jsonl file containing the training data.",
     )
-
+    parser.add_argument(
+        "--only_target_transformer_blocks",
+        action="store_true",
+        help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
+    )
     parser.add_argument(
         "--guidance_scale",
         type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )
 
     return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
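
The preprocessing change above (mirrored in `train_control_lora_flux.py`) lets the caption column hold either a single string or a list of candidate captions per sample, keeping the longest entry when a list is given. A tiny standalone illustration of that selection logic (the sample captions are made up):

```python
# Hypothetical caption column: each row may carry a list of captions, as some datasets provide.
caption_column = [
    ["a cat", "a fluffy cat sleeping on a wooden bench"],
    ["a dog running along the shore of a lake", "a dog"],
]

# Same logic as the diff: if rows are lists, keep the longest caption per row.
is_caption_list = isinstance(caption_column[0], list)
if is_caption_list:
    captions = [max(example, key=len) for example in caption_column]
else:
    captions = list(caption_column)

print(captions)
# ['a fluffy cat sleeping on a wooden bench', 'a dog running along the shore of a lake']
```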

@@ -765,7 +774,8 @@ def main(args):
         subfolder="scheduler",
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    flux_transformer.requires_grad_(True)
+    if not args.only_target_transformer_blocks:
+        flux_transformer.requires_grad_(True)
     vae.requires_grad_(False)
 
     # cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
+    if args.only_target_transformer_blocks:
+        flux_transformer.x_embedder.requires_grad_(True)
+        for name, module in flux_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
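
The new `--only_target_transformer_blocks` flag limits training to the input projection (`x_embedder`) and the transformer blocks while everything else stays frozen. A toy sketch of that selective-unfreezing pattern on a stand-in module (the attribute names mirror the script; the real target is the Flux transformer):

```python
import torch.nn as nn

# Stand-in for the Flux transformer: an input projection plus named blocks.
class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.x_embedder = nn.Linear(8, 16)
        self.transformer_blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])
        self.proj_out = nn.Linear(16, 8)

model = TinyTransformer()
model.requires_grad_(False)  # start fully frozen for the illustration

# Same selection logic as the diff: unfreeze x_embedder plus every module
# whose qualified name contains "transformer_blocks".
model.x_embedder.requires_grad_(True)
for name, module in model.named_modules():
    if "transformer_blocks" in name:
        module.requires_grad_(True)

print([n for n, p in model.named_parameters() if p.requires_grad])
# x_embedder.* and transformer_blocks.* are trainable; proj_out.* stays frozen.
```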

examples/flux-control/train_control_lora_flux.py

Lines changed: 39 additions & 8 deletions

@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
     for _ in range(args.num_validation_images):
         with autocast_ctx:
-            # need to fix in pipeline_flux_controlnet
             image = pipeline(
                 prompt=validation_prompt,
                 control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
                 images = log["images"]
                 validation_prompt = log["validation_prompt"]
                 validation_image = log["validation_image"]
-                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                 for image in images:
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
             img_str += f"![images_{i})](./images_{i}.png)\n"
 
     model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}
 
 These are Control LoRA weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-lora",
+        default="control-lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -466,14 +465,15 @@ def parse_args(input_args=None):
         "--conditioning_image_column",
         type=str,
         default="conditioning_image",
-        help="The column of the dataset containing the controlnet conditioning image.",
+        help="The column of the dataset containing the control conditioning image.",
     )
     parser.add_argument(
         "--caption_column",
         type=str,
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
    )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+            "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
             " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
             " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
             " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )
 
     return args
@@ -697,7 +697,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
@@ -1132,6 +1137,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
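
Both training scripts' new `--log_dataset_samples` path rescales batches from the `[-1, 1]` range produced by the training transforms back to `[0, 1]` before wrapping them in `wandb.Image`. A minimal sketch of just that rescaling step (synthetic tensors, no wandb run needed):

```python
import torch

# Synthetic batch normalized to [-1, 1], like `batch["pixel_values"]` from the dataloader.
pixel_values = torch.rand(2, 3, 64, 64) * 2 - 1

# The same conversion the scripts apply before logging: [-1, 1] -> [0, 1].
images = (pixel_values + 1) / 2
assert images.min() >= 0 and images.max() <= 1

# Each sample is then logged by the scripts as wandb.Image(img, caption=prompt).
```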

examples/research_projects/instructpix2pix_lora/README.md

Lines changed: 33 additions & 2 deletions

@@ -2,14 +2,42 @@
 This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
 This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
 
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+
+
 ## Training script example
 
 ```bash
 export MODEL_ID="timbrooks/instruct-pix2pix"
 export DATASET_ID="instruction-tuning-sd/cartoonization"
 export OUTPUT_DIR="instructPix2Pix-cartoonization"
 
-accelerate launch finetune_instruct_pix2pix.py \
+accelerate launch train_instruct_pix2pix_lora.py \
   --pretrained_model_name_or_path=$MODEL_ID \
   --dataset_name=$DATASET_ID \
   --enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
   --rank=4 \
   --output_dir=$OUTPUT_DIR \
   --report_to=wandb \
-  --push_to_hub
+  --push_to_hub \
+  --original_image_column="original_image" \
+  --edited_image_column="cartoonized_image" \
+  --edit_prompt_column="edit_prompt"
 ```
 
 ## Inference
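
For completeness, a rough inference sketch for the LoRA produced by the command above (not part of this commit; it assumes the InstructPix2Pix pipeline exposes `load_lora_weights`, and the paths and prompt are illustrative):

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
# Load the LoRA trained above (OUTPUT_DIR from the launch command).
pipe.load_lora_weights("instructPix2Pix-cartoonization")

image = load_image("path/to/input.png")
edited = pipe(
    "Cartoonize the image",
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7.0,
).images[0]
edited.save("edited.png")
```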
