
Commit 9a2600e

Authored by kopyl, sayakpaul, and lhoestq
Map speedup (#6745)
* Speed up dataset mapping
* Fix missing columns
* Remove cache files cleanup
* Update examples/text_to_image/train_text_to_image_sdxl.py
* make style
* Fix code style
* style
* Empty-Commit

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 5f150c4 commit 9a2600e

File tree

1 file changed: +11, -5 lines changed


examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 11 additions & 5 deletions
@@ -35,7 +35,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
@@ -896,13 +896,19 @@ def preprocess_train(examples):
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
         new_fingerprint_for_vae = Hasher.hash(vae_path)
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-        train_dataset = train_dataset.map(
+        train_dataset_with_embeddings = train_dataset.map(
+            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        )
+        train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
             batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
             new_fingerprint=new_fingerprint_for_vae,
         )
+        precomputed_dataset = concatenate_datasets(
+            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+        )
+        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
 
     del text_encoders, tokenizers, vae
     gc.collect()
@@ -925,7 +931,7 @@ def collate_fn(examples):
 
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
@@ -976,7 +982,7 @@ def unwrap_model(model):
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {len(precomputed_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
