@@ -121,8 +121,8 @@ def __init__(
         weight_dtype,
         device,
         noise_scheduler,
+        vae_scale_factor,
         transformer,
-        vae,
         optimizer,
         dataloader,
         args,
@@ -131,13 +131,13 @@ def __init__(
         self.device = device
         self.noise_scheduler = noise_scheduler
         self.transformer = transformer
-        self.vae = vae
         self.optimizer = optimizer
         self.args = args
         self.mesh = xs.get_global_mesh()
         self.dataloader = iter(dataloader)
         self.global_step = 0
         self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+        self.vae_scale_factor = vae_scale_factor
 
     def run_optimizer(self):
         self.optimizer.step()
@@ -198,13 +198,7 @@ def step_fn(
         prompt_embeds = batch["prompt_embeds"]
         pooled_prompt_embeds = batch["pooled_prompt_embeds"]
         text_ids = batch["text_ids"]
-
-        pixel_tensor_values = batch["pixel_tensor_values"]
-        model_input = self.vae.encode(pixel_tensor_values).latent_dist.sample()
-        model_input = (model_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-        model_input = model_input.to(dtype=self.weight_dtype)
-
-        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        model_input = batch["model_input"]
 
         latent_image_ids = FluxPipeline._prepare_latent_image_ids(
             model_input.shape[0],
@@ -264,9 +258,9 @@ def step_fn(
         # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
         model_pred = FluxPipeline._unpack_latents(
             model_pred,
-            height=model_input.shape[2] * vae_scale_factor,
-            width=model_input.shape[3] * vae_scale_factor,
-            vae_scale_factor=vae_scale_factor,
+            height=model_input.shape[2] * self.vae_scale_factor,
+            width=model_input.shape[3] * self.vae_scale_factor,
+            vae_scale_factor=self.vae_scale_factor,
         )
 
         # these weighting schemes use a uniform timestep sampling
@@ -626,6 +620,17 @@ def encode_prompt(
 
     return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
 
+def compute_vae_encodings(batch, vae, device, dtype):
+    images = batch.pop("pixel_values")
+    pixel_values = torch.stack(list(images))
+    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+    with torch.no_grad():
+        model_input = vae.encode(pixel_values).latent_dist.sample()
+        model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+    return {"model_input": model_input}
+
 def pixels_to_tensors(batch, device, dtype):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
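A minimal standalone sketch of the helper added in this hunk, handy for sanity-checking the latent shape before the dataset-level `.map` precompute pass that appears later in the diff. The checkpoint id, the dummy 512x512 batch, and the expected shape are illustrative assumptions, not part of the change:

    import torch
    from diffusers import AutoencoderKL

    # Hypothetical smoke test (not in the diff): load the FLUX VAE and encode two dummy images.
    vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
    dummy_batch = {"pixel_values": [torch.randn(3, 512, 512) for _ in range(2)]}
    out = compute_vae_encodings(dummy_batch, vae, device="cpu", dtype=torch.float32)
    # With an 8x-downsampling, 16-channel VAE, each 512x512 image becomes a (16, 64, 64) latent.
    print(out["model_input"].shape)  # expected: torch.Size([2, 16, 64, 64])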
@@ -729,20 +734,20 @@ def main(args):
 
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 
-    # unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    # transformer = apply_xla_patch_to_nn_linear(transformer, xs.xla_patched_nn_linear_forward)
     transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
     FlashAttention.DEFAULT_BLOCK_SIZES = {
-        "block_q": 1536,
-        "block_k_major": 1536,
-        "block_k": 1536,
-        "block_b": 1536,
-        "block_q_major_dkv": 1536,
-        "block_k_major_dkv": 1536,
-        "block_q_dkv": 1536,
-        "block_k_dkv": 1536,
-        "block_q_dq": 1536,
-        "block_k_dq": 1536,
-        "block_k_major_dq": 1536,
+        "block_q": 512,
+        "block_k_major": 512,
+        "block_k": 512,
+        "block_b": 512,
+        "block_q_major_dkv": 512,
+        "block_k_major_dkv": 512,
+        "block_q_dkv": 512,
+        "block_k_dkv": 512,
+        "block_q_dq": 512,
+        "block_k_dq": 768,
+        "block_k_major_dq": 512,
     }
     # For mixed precision training we cast all non-trainable weights (vae,
     # non-lora text_encoder and non-lora unet) to half-precision
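For reference, DEFAULT_BLOCK_SIZES on torch_xla's FlashAttention wrapper is an ordinary class-level dict of kernel tile sizes, so a partial override like the sketch below is also possible instead of replacing the whole mapping; treat this as an assumption about usage rather than part of the change:

    from torch_xla.experimental.custom_kernel import FlashAttention

    # Hypothetical partial override: keep the library defaults and shrink only a few entries.
    FlashAttention.DEFAULT_BLOCK_SIZES = {
        **FlashAttention.DEFAULT_BLOCK_SIZES,
        "block_q": 512,
        "block_k_major": 512,
        "block_k": 512,
    }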
@@ -812,8 +817,7 @@ def preprocess_train(examples):
         tokenizers=tokenizers,
         caption_column=caption_column,
     )
-    # compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
-    pixels_to_tensors_fn = functools.partial(pixels_to_tensors, device=device, dtype=weight_dtype)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae, device=device, dtype=weight_dtype)
     from datasets.fingerprint import Hasher
 
     new_fingerprint = Hasher.hash(args)
@@ -822,24 +826,25 @@ def preprocess_train(examples):
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
     )
     train_dataset_with_tensors = train_dataset.map(
-        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=256
+        compute_vae_encodings_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=8
     )
     precomputed_dataset = concatenate_datasets(
         [train_dataset_with_embeddings, train_dataset_with_tensors.remove_columns(["text", "image"])], axis=1
     )
     precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
-    del compute_embeddings_fn, text_encoder, text_encoder_2
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    del compute_embeddings_fn, text_encoder, text_encoder_2, vae
     del text_encoders, tokenizers
     def collate_fn(examples):
         prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]).to(dtype=weight_dtype)
         text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples]).to(dtype=weight_dtype)
-        pixel_tensor_values = torch.stack([torch.tensor(example["pixel_tensor_values"]) for example in examples]).to(dtype=weight_dtype)
+        model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]).to(dtype=weight_dtype)
         return {
             "prompt_embeds": prompt_embeds,
             "pooled_prompt_embeds": pooled_prompt_embeds,
             "text_ids": text_ids,
-            "pixel_tensor_values": pixel_tensor_values
+            "model_input": model_input,
         }
 
     g = torch.Generator()
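As a quick sanity check of the vae_scale_factor computed above, and of how it feeds the _unpack_latents call in step_fn earlier in this diff; the block_out_channels values and the 512x512 example are an assumption based on the standard FLUX VAE config:

    # block_out_channels = [128, 256, 512, 512]  ->  len() = 4
    vae_scale_factor = 2 ** (4 - 1)  # = 8: the VAE maps 512x512 pixels to 64x64 latents
    # In step_fn, _unpack_latents is then told the original pixel size:
    #   height = model_input.shape[2] * vae_scale_factor  # 64 * 8 = 512
    #   width  = model_input.shape[3] * vae_scale_factor  # 64 * 8 = 512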
@@ -860,7 +865,7 @@ def collate_fn(examples):
         input_sharding={
             "prompt_embeds": xs.ShardingSpec(mesh, ("data", None, None), minibatch=True),
             "pooled_prompt_embeds": xs.ShardingSpec(mesh, ("data", None,), minibatch=True),
-            "pixel_tensor_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "model_input": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
             "text_ids": xs.ShardingSpec(mesh, ("data", None, None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
@@ -881,8 +886,8 @@ def collate_fn(examples):
         weight_dtype=weight_dtype,
         device=device,
         noise_scheduler=noise_scheduler,
+        vae_scale_factor=vae_scale_factor,
         transformer=transformer,
-        vae=vae,
         optimizer=optimizer,
         dataloader=train_dataloader,
         args=args,