 from skimage.metrics import structural_similarity as ssim
 from flax.training import train_state
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+from jax.experimental import multihost_utils


 class TrainState(train_state.TrainState):
@@ -156,6 +157,11 @@ def get_data_shardings(self, mesh):
     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
     return data_sharding

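+  # Eval batches also carry per-sample "timesteps", so they need their own sharding spec;
+  # this mirrors get_data_shardings above and reuses the same data_sharding axes.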
+  def get_eval_data_shardings(self, mesh):
+    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
+    return data_sharding
+
   def load_dataset(self, mesh, is_training=True):
     # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
     # Image pre-training - txt2img 256px
@@ -170,34 +176,43 @@ def load_dataset(self, mesh, is_training=True):
       raise ValueError(
           "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
       )
-
     feature_description = {
         "latents": tf.io.FixedLenFeature([], tf.string),
         "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
     }

-    def prepare_sample(features):
+    if not is_training:
+      feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)
+
+    def prepare_sample_train(features):
       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}

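+    # For eval, also return the per-sample timestep stored in the TFRecord (fixed per example
+    # rather than sampled at eval time), so losses can later be bucketed by timestep.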
+    def prepare_sample_eval(features):
+      latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
+      encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
+      timesteps = features["timesteps"]
+      return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
+
     data_iterator = make_data_iterator(
         config,
         jax.process_index(),
         jax.process_count(),
         mesh,
         config.global_batch_size_to_load,
         feature_description=feature_description,
-        prepare_sample_fn=prepare_sample,
+        prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
         is_training=is_training,
     )
     return data_iterator

   def start_training(self):

     pipeline = self.load_checkpoint()
-    # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    if self.config.enable_ssim:
+      # Generate a sample before training to compare against generated sample after training.
+      pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

     if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
       # save some memory.
@@ -215,8 +230,57 @@ def start_training(self):
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    if self.config.enable_ssim:
+      posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+      print_ssim(pretrained_video_path, posttrained_video_path)
+
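+  # Runs one pass over the eval dataset, bucketing per-sample losses by diffusion timestep
+  # and logging the mean over timestep buckets (assumes eval TFRecords contain "timesteps").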
+  def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
+    eval_data_iterator = self.load_dataset(mesh, is_training=False)
+    eval_rng = eval_rng_key
+    eval_losses_by_timestep = {}
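+    # Maps diffusion timestep -> list of per-sample eval losses gathered from all hosts.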
+    # Loop indefinitely until the iterator is exhausted
+    while True:
+      try:
+        eval_start_time = datetime.datetime.now()
+        eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+          metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+          metrics["scalar"]["learning/eval_loss"].block_until_ready()
+        losses = metrics["scalar"]["learning/eval_loss"]
+        timesteps = eval_batch["timesteps"]
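+        # Gather per-sample losses and their timesteps from every host so that process 0
+        # can bucket losses by timestep.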
+        gathered_losses = multihost_utils.process_allgather(losses)
+        gathered_losses = jax.device_get(gathered_losses)
+        gathered_timesteps = multihost_utils.process_allgather(timesteps)
+        gathered_timesteps = jax.device_get(gathered_timesteps)
+        if jax.process_index() == 0:
+          for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
+            timestep = int(t)
+            if timestep not in eval_losses_by_timestep:
+              eval_losses_by_timestep[timestep] = []
+            eval_losses_by_timestep[timestep].append(l)
+        eval_end_time = datetime.datetime.now()
+        eval_duration = eval_end_time - eval_start_time
+        max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
+      except StopIteration:
+        # This block is executed when the iterator has no more data
+        break
+    # Check if any evaluation was actually performed
+    if eval_losses_by_timestep and jax.process_index() == 0:
+      mean_per_timestep = []
+      max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+      for timestep, losses in sorted(eval_losses_by_timestep.items()):
+        losses = jnp.array(losses)
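+        # Use at most eval_max_number_of_samples_in_bucket samples per timestep bucket.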
+        losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
+        mean_loss = jnp.mean(losses)
+        max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+        mean_per_timestep.append(mean_loss)
+      final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
+      max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
+      if writer:
+        writer.add_scalar("learning/eval_loss", final_eval_loss, step)

   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -231,6 +295,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     state = jax.lax.with_sharding_constraint(state, state_spec)
     state_shardings = nnx.get_named_sharding(state, mesh)
     data_shardings = self.get_data_shardings(mesh)
+    eval_data_shardings = self.get_eval_data_shardings(mesh)

     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +320,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     )
     p_eval_step = jax.jit(
         functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
-        in_shardings=(state_shardings, data_shardings, None, None),
+        in_shardings=(state_shardings, eval_data_shardings, None, None),
         out_shardings=(None, None),
     )

     rng = jax.random.key(self.config.seed)
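+    # Split off a dedicated RNG key for evaluation so eval noise is decoupled from the training stream.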
+    rng, eval_rng_key = jax.random.split(rng)
     start_step = 0
     last_step_completion = datetime.datetime.now()
     local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -304,27 +370,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
           inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
         # Re-create the iterator each time you start evaluation to reset it
         # This assumes your data loading logic can be called to get a fresh iterator.
-        eval_data_iterator = self.load_dataset(mesh, is_training=False)
-        eval_rng = jax.random.key(self.config.seed + step)
-        eval_metrics = []
-        # Loop indefinitely until the iterator is exhausted
-        while True:
-          try:
-            with mesh:
-              eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-              metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
-          except StopIteration:
-            # This block is executed when the iterator has no more data
-            break
-        # Check if any evaluation was actually performed
-        if eval_metrics:
-          eval_loss = jnp.mean(jnp.array(eval_metrics))
-          max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
-          if writer:
-            writer.add_scalar("learning/eval_loss", eval_loss, step)
-        else:
-          max_logging.log(f"Step {step}, evaluation dataset was empty.")
+        self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)
+
       example_batch = next_batch_future.result()
       if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
         max_logging.log(f"Saving checkpoint for step {step}")
@@ -394,57 +441,54 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
394441 """
395442 Computes the evaluation loss for a single batch without updating model weights.
396443 """
397- _ , new_rng , timestep_rng = jax .random .split (rng , num = 3 )
398-
399- # This ensures the batch size is consistent, though it might be redundant
400- # if the evaluation dataloader is already configured correctly.
401- for k , v in data .items ():
402- data [k ] = v [: config .global_batch_size_to_train_on , :]
403444
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
-  def loss_fn(params):
+  @jax.jit
+  def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
     # Reconstruct the model from its definition and parameters
     model = nnx.merge(state.graphdef, params, state.rest_of_state)

-    # Prepare inputs
-    latents = data["latents"].astype(config.weights_dtype)
-    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-    bsz = latents.shape[0]
-
-    # Sample random timesteps and noise, just as in a training step
-    timesteps = jax.random.randint(
-        timestep_rng,
-        (bsz,),
-        0,
-        scheduler.config.num_train_timesteps,
-    )
-    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

     # Get the model's prediction
     model_pred = model(
         hidden_states=noisy_latents,
         timestep=timesteps,
         encoder_hidden_states=encoder_hidden_states,
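+        # deterministic=True: run without stochastic layers (e.g. dropout) during eval.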
+        deterministic=True,
     )

     # Calculate the loss against the target
     training_target = scheduler.training_target(latents, noise, timesteps)
     training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
     loss = (training_target - model_pred) ** 2
     loss = loss * training_weight
-    loss = jnp.mean(loss)
+    # Calculate the mean loss per sample across all non-batch dimensions.
+    loss = loss.reshape(loss.shape[0], -1).mean(axis=1)

     return loss

   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.
   # The model's state.params are used but not updated.
-  loss = loss_fn(state.params)
+  # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
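+  # The loaded eval batch may be larger than the train batch, so run loss_fn over chunks of
+  # config.global_batch_size_to_train_on samples and stitch the per-sample losses back together.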
+  bs = len(data["latents"])
+  single_batch_size = config.global_batch_size_to_train_on
+  losses = jnp.zeros(bs)
+  for i in range(0, bs, single_batch_size):
+    start = i
+    end = min(i + single_batch_size, bs)
+    latents = data["latents"][start:end, :].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
+    timesteps = data["timesteps"][start:end].astype("int64")
+    _, new_rng = jax.random.split(rng, num=2)
+    loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
+    losses = losses.at[start:end].set(loss)

   # Structure the metrics for logging and aggregation
-  metrics = {"scalar": {"learning/eval_loss": loss}}
+  metrics = {"scalar": {"learning/eval_loss": losses}}

   # Return the computed metrics and the new RNG key for the next eval step
   return metrics, new_rng