
Commit 6bfd13f

asomoza and sayakpaul authored
[SD3 Training] T5 token limit (#8564)
* initial commit
* default back to 77
* better text
* text correction

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent eeb7003 commit 6bfd13f

File tree: 3 files changed (+29, -4 lines)


examples/dreambooth/README_sd3.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -106,6 +106,9 @@ To better track our training experiments, we're using the following flags in the
 * `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
 * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
 
+> [!NOTE]
+> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
+
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
 
```
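The note above is the user-facing part of this change. As a quick illustration of what the token limit does, here is a minimal sketch (not part of the commit) that uses a small, publicly available T5 tokenizer as a stand-in for SD3's T5-XXL tokenizer:

```python
# Minimal sketch: prompts longer than `max_sequence_length` T5 tokens are
# truncated; a higher limit keeps them intact and pads the rest instead.
# "google/t5-v1_1-small" is only an illustrative stand-in tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
long_prompt = "a photo of sks dog wearing a tiny wizard hat " * 20  # far beyond 77 tokens

for max_len in (77, 512):
    inputs = tokenizer(
        long_prompt,
        padding="max_length",
        max_length=max_len,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    print(max_len, inputs.input_ids.shape)  # (1, 77) truncated vs. (1, 512) padded
```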

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -298,6 +298,12 @@ def parse_args(input_args=None):
         default=None,
         help="The prompt to specify images in the same class as provided instance images.",
     )
+    parser.add_argument(
+        "--max_sequence_length",
+        type=int,
+        default=77,
+        help="Maximum sequence length to use with with the T5 text encoder",
+    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -830,6 +836,7 @@ def tokenize_prompt(tokenizer, prompt):
 def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
+    max_sequence_length,
     prompt=None,
     num_images_per_prompt=1,
     device=None,
@@ -840,7 +847,7 @@ def _encode_prompt_with_t5(
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
-        max_length=77,
+        max_length=max_sequence_length,
         truncation=True,
         add_special_tokens=True,
         return_tensors="pt",
@@ -897,6 +904,7 @@ def encode_prompt(
     text_encoders,
     tokenizers,
     prompt: str,
+    max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
 ):
@@ -924,6 +932,7 @@ def encode_prompt(
     t5_prompt_embed = _encode_prompt_with_t5(
         text_encoders[-1],
         tokenizers[-1],
+        max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[-1].device,
@@ -1297,7 +1306,9 @@ def load_model_hook(models, input_dir):
 
     def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
+            )
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
```
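Taken together, these hunks thread one value from the CLI down to the tokenizer call. Below is a condensed, standalone sketch of that data flow for illustration only: the function names mirror the script's, but the signatures are simplified, the bodies are reduced to the tokenization step, and the tokenizer checkpoint is a stand-in.

```python
# Condensed illustration of the flow added by this commit:
# --max_sequence_length -> args -> _encode_prompt_with_t5 -> tokenizer(max_length=...)
import argparse
from transformers import AutoTokenizer

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_sequence_length", type=int, default=77)
    return parser.parse_args(argv)

def _encode_prompt_with_t5(tokenizer, max_sequence_length, prompt):
    # Same tokenizer call as the script, with the former hard-coded 77
    # replaced by the configurable limit.
    return tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).input_ids

args = parse_args(["--max_sequence_length", "256"])
t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")  # stand-in checkpoint
ids = _encode_prompt_with_t5(t5_tokenizer, args.max_sequence_length, "a photo of sks dog")
print(ids.shape)  # torch.Size([1, 256])
```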

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -297,6 +297,12 @@ def parse_args(input_args=None):
         default=None,
         help="The prompt to specify images in the same class as provided instance images.",
     )
+    parser.add_argument(
+        "--max_sequence_length",
+        type=int,
+        default=77,
+        help="Maximum sequence length to use with with the T5 text encoder",
+    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -828,6 +834,7 @@ def tokenize_prompt(tokenizer, prompt):
 def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
+    max_sequence_length,
     prompt=None,
     num_images_per_prompt=1,
     device=None,
@@ -838,7 +845,7 @@ def _encode_prompt_with_t5(
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
-        max_length=77,
+        max_length=max_sequence_length,
         truncation=True,
         add_special_tokens=True,
         return_tensors="pt",
@@ -895,6 +902,7 @@ def encode_prompt(
     text_encoders,
     tokenizers,
     prompt: str,
+    max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
 ):
@@ -922,6 +930,7 @@ def encode_prompt(
     t5_prompt_embed = _encode_prompt_with_t5(
         text_encoders[-1],
         tokenizers[-1],
+        max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[-1].device,
@@ -1324,7 +1333,9 @@ def load_model_hook(models, input_dir):
 
     def compute_text_embeddings(prompt, text_encoders, tokenizers):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                text_encoders, tokenizers, prompt, args.max_sequence_length
+            )
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
             return prompt_embeds, pooled_prompt_embeds
```
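The same change lands in the full-finetuning script, so the README's warning about higher resource usage applies to both. A rough back-of-the-envelope sketch of why, assuming a T5-XXL hidden size of 4096 and fp16 embeddings (in practice the dominant cost is the SD3 transformer attending over the longer padded sequence):

```python
# Rough arithmetic only (assumes T5-XXL hidden size 4096, fp16 storage).
# The T5 prompt embedding is always padded to max_sequence_length tokens,
# so its size, and the sequence length the SD3 transformer attends over,
# grows linearly with the limit.
HIDDEN_SIZE = 4096   # assumed T5-XXL d_model
BYTES_PER_VALUE = 2  # fp16

for max_len in (77, 256, 512):
    mib = max_len * HIDDEN_SIZE * BYTES_PER_VALUE / 2**20
    print(f"max_sequence_length={max_len}: ~{mib:.2f} MiB per prompt embedding")
```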
