
Commit dd893b9

readme
1 parent 8aca9ee commit dd893b9

2 files changed: +60 -7 lines changed

examples/dreambooth/README_flux.md

Lines changed: 35 additions & 0 deletions
@@ -294,6 +294,41 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
 Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
 perform as expected.
 
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+  --output_dir="kontext-i2i" \
+  --dataset_name="kontext-community/relighting" \
+  --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+  --mixed_precision="bf16" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --guidance_scale=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --optimizer="adamw" \
+  --use_8bit_adam \
+  --cache_latents \
+  --learning_rate=1e-4 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=500 \
+  --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset like `kontext-community/relighting`
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
 ### Misc notes
 
 * By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
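The column flags in the command above assume the dataset exposes an `output` column (target image), a `file_name` column (condition image), and an `instruction` column (caption). A quick way to verify this before launching training is to peek at the dataset with the Hugging Face `datasets` library; the snippet below is a minimal sketch (the `train` split name is an assumption), not part of the training script:

```python
from datasets import load_dataset

# Load the I2I dataset referenced by --dataset_name in the command above.
# The "train" split name is assumed here; adjust it to match your dataset.
ds = load_dataset("kontext-community/relighting", split="train")

# These names must match --image_column, --cond_image_column, and --caption_column.
print(ds.column_names)

sample = ds[0]
print(type(sample["output"]), type(sample["file_name"]), sample["instruction"])
```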

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 25 additions & 7 deletions
@@ -335,7 +335,6 @@ def parse_args(input_args=None):
         "--instance_prompt",
         type=str,
         default=None,
-        required=True,
         help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
     )
     parser.add_argument(
@@ -740,6 +739,7 @@ def parse_args(input_args=None):
         assert args.image_column is not None
         assert args.caption_column is not None
         assert args.dataset_name is not None
+        assert not args.train_text_encoder
 
     return args
 
@@ -870,7 +870,7 @@ def __init__(
                 random_flip=args.random_flip,
             )
             self.pixel_values.append((image, bucket_idx))
-            if dest_image:
+            if dest_image is not None:
                 self.cond_pixel_values.append((dest_image, bucket_idx))
 
         self.num_instance_images = len(self.instance_images)
@@ -965,7 +965,7 @@ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=
         if dest_image is not None:
             dest_image = normalize(to_tensor(dest_image))
 
-        return (image, dest_image) if dest_image is not None else image
+        return (image, dest_image) if dest_image is not None else (image, None)
 
 
 def collate_fn(examples, with_prior_preservation=False):
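The `paired_transform` change in the hunk above makes the return shape uniform: callers can always unpack two values, with `None` standing in when there is no condition image, instead of sometimes getting a bare image back. A standalone illustration of the difference (not the script's actual code):

```python
def transform_old(image, dest_image=None):
    # Old behaviour: the return type depends on the input, so callers must branch.
    return (image, dest_image) if dest_image is not None else image

def transform_new(image, dest_image=None):
    # New behaviour: always a 2-tuple, so `img, dest = transform_new(...)` never fails.
    return (image, dest_image) if dest_image is not None else (image, None)

img, dest = transform_new("target.png")              # dest is None for T2I samples
img, dest = transform_new("target.png", "cond.png")  # dest holds the condition image
```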
@@ -1606,6 +1606,8 @@ def load_model_hook(models, input_dir):
         center_crop=args.center_crop,
         args=args,
     )
+    if args.cond_image_column is not None:
+        logger.info("I2I fine-tuning enabled.")
     batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
@@ -1645,6 +1647,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+        text_encoder_one.cpu(), text_encoder_two.cpu()
         del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
         free_memory()
 
@@ -1676,6 +1679,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
         tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
+        cached_text_embeddings = []
+        for batch in tqdm(train_dataloader, desc="Embedding prompts"):
+            batch_prompts = batch["prompts"]
+            prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+                batch_prompts, text_encoders, tokenizers
+            )
+            cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
+
+        if args.validation_prompt is None:
+            text_encoder_one.cpu(), text_encoder_two.cpu()
+            del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+            free_memory()
+
     vae_config_shift_factor = vae.config.shift_factor
     vae_config_scaling_factor = vae.config.scaling_factor
     vae_config_block_out_channels = vae.config.block_out_channels
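The hunk above trades a one-time pass over the dataloader for lower steady-state memory: prompt embeddings are computed once with the frozen text encoders, cached per step, and the encoders are then moved to CPU and deleted. A minimal sketch of the same pattern with stand-in objects (the `encoder` and `loader` below are illustrative, not the script's actual models or dataloader):

```python
import gc
import torch

# Stand-ins: a frozen "text encoder" and a tiny dataloader yielding token-id batches.
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = torch.nn.EmbeddingBag(1000, 64).to(device)
loader = [{"prompts": torch.randint(0, 1000, (2, 8), device=device)} for _ in range(4)]

# One pass up front: cache the embeddings for every step.
cached_text_embeddings = []
with torch.no_grad():
    for batch in loader:
        cached_text_embeddings.append(encoder(batch["prompts"]))

# The encoder is no longer needed. Moving it to CPU before dropping the reference lets
# the CUDA allocator reclaim its memory; free_memory() in diffusers roughly covers the
# gc.collect()/empty_cache() part.
encoder.cpu()
del encoder
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Inside the training loop, embeddings are looked up by step instead of re-encoded:
for step, batch in enumerate(loader):
    prompt_embeds = cached_text_embeddings[step]
```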
@@ -1696,6 +1713,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
 
         if args.validation_prompt is None:
+            vae.cpu()
             del vae
             free_memory()
 
@@ -1837,10 +1855,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if not args.train_text_encoder:
-                        # Should find a way to precompute these.
-                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
-                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
                     else:
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
                         tokens_two = tokenize_prompt(
@@ -1942,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     height=model_input.shape[2],
                     width=model_input.shape[3],
                 )
+                orig_inp_shape = packed_noisy_model_input.shape
                 if has_image_input:
                     packed_cond_input = FluxKontextPipeline._pack_latents(
                         cond_model_input,
@@ -1968,6 +1984,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     img_ids=latent_image_ids,
                     return_dict=False,
                 )[0]
+                if has_image_input:
+                    model_pred = model_pred[:, : orig_inp_shape[1]]
                 model_pred = FluxKontextPipeline._unpack_latents(
                     model_pred,
                     height=model_input.shape[2] * vae_scale_factor,
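In the last two hunks, `orig_inp_shape` records the packed sequence length of the noisy target latents before the packed condition latents are appended (presumably along the sequence dimension, in code between the hunks that this diff does not show), and the transformer's prediction is then sliced back to that length so only the target tokens are unpacked and used for the loss. A minimal sketch of the shape bookkeeping with dummy tensors (the identity stands in for the Flux transformer):

```python
import torch

batch_size, target_tokens, cond_tokens, channels = 1, 4096, 4096, 64

packed_noisy_model_input = torch.randn(batch_size, target_tokens, channels)
packed_cond_input = torch.randn(batch_size, cond_tokens, channels)

# Remember the target sequence length before appending the condition tokens.
orig_inp_shape = packed_noisy_model_input.shape
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)

# Stand-in for the transformer: any map that preserves the sequence layout.
model_pred = torch.nn.Identity()(packed_noisy_model_input)

# Keep only the prediction over the target tokens; the condition tokens are discarded.
model_pred = model_pred[:, : orig_inp_shape[1]]
assert model_pred.shape == orig_inp_shape
```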

0 commit comments
