 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
@@ -896,13 +896,19 @@ def preprocess_train(examples):
     # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
     new_fingerprint = Hasher.hash(args)
     new_fingerprint_for_vae = Hasher.hash(vae_path)
-    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-    train_dataset = train_dataset.map(
+    train_dataset_with_embeddings = train_dataset.map(
+        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+    )
+    train_dataset_with_vae = train_dataset.map(
         compute_vae_encodings_fn,
         batched=True,
         batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
         new_fingerprint=new_fingerprint_for_vae,
     )
+    precomputed_dataset = concatenate_datasets(
+        [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+    )
+    precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
 
     del text_encoders, tokenizers, vae
     gc.collect()
@@ -925,7 +931,7 @@ def collate_fn(examples):
 
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
@@ -976,7 +982,7 @@ def unwrap_model(model):
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {len(precomputed_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
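
Note on the first hunk: instead of chaining two `map` calls on `train_dataset` in place, the text-embedding and VAE-encoding passes now each run over the same base dataset, and the two cached results are joined column-wise with `concatenate_datasets(..., axis=1)`. A minimal sketch of that column-wise join, using toy columns in place of the script's real features (the `prompt_embeds`/`model_input` names and the transform bodies here are illustrative only):

```python
from datasets import Dataset, concatenate_datasets

# Toy stand-in for train_dataset (column names are illustrative).
base = Dataset.from_dict({"image": ["img0", "img1"], "text": ["a cat", "a dog"]})

# Two independent map passes, mirroring compute_embeddings_fn / compute_vae_encodings_fn.
with_embeddings = base.map(lambda batch: {"prompt_embeds": [len(t) for t in batch["text"]]}, batched=True)
with_vae = base.map(lambda batch: {"model_input": [len(i) for i in batch["image"]]}, batched=True)

# Column-wise join: both datasets must have the same number of rows in the same order,
# and overlapping columns are dropped from the second dataset before concatenating.
joined = concatenate_datasets([with_embeddings, with_vae.remove_columns(["image", "text"])], axis=1)
print(joined.column_names)  # ['image', 'text', 'prompt_embeds', 'model_input']
```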
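Also worth noting: unlike `map`, the trailing `with_transform(preprocess_train)` applies its function lazily each time rows are accessed, so per-epoch random augmentations are recomputed on the fly rather than baked into the on-disk cache. A rough illustration with a made-up transform (`upper_text` is not part of the script):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a cat", "a dog"]})

def upper_text(batch):
    # Runs on access, not at definition time; nothing is written to the cache.
    return {"text_upper": [t.upper() for t in batch["text"]]}

ds = ds.with_transform(upper_text)
print(ds[0])  # {'text_upper': 'A CAT'}; the transform's output replaces the example dict
```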