 from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
 from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig

-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree

 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib

 # Reuse the TrainState from the DiffusionTrainer
-from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
+from .diffusion_trainer import TrainState, DiffusionTrainer
 import shutil

 def generate_modelname(
@@ -103,6 +103,15 @@ def generate_modelname(
     # model_name = f"{model_name}-{config_hash}"
     return model_name

+@dataclass
+class EvaluationMetric:
+    """
+    Evaluation metric for the diffusion model.
+    The function is given the generated samples batch [B, H, W, C] and the original batch.
+    """
+    function: Callable
+    name: str
+
 class GeneralDiffusionTrainer(DiffusionTrainer):
     """
     General trainer for diffusion models supporting both images and videos.
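For orientation, a minimal sketch of how the new EvaluationMetric hook could be wired up; the mse_metric helper, the "image" batch key, and the import path are illustrative assumptions rather than part of this change:

import numpy as np

# Import path assumed from this package's layout (flaxdiff.trainer).
from flaxdiff.trainer.general_diffusion_trainer import EvaluationMetric

def mse_metric(samples, batch):
    # samples: generated batch [B, H, W, C]; batch: the original validation batch.
    # The "image" key is an assumption about the dataset's batch structure.
    reference = np.asarray(batch["image"], dtype=np.float32)
    generated = np.asarray(samples, dtype=np.float32)
    return float(np.mean((generated - reference) ** 2))

eval_metrics = [EvaluationMetric(function=mse_metric, name="mse")]
# Passed to the trainer via the new constructor argument:
#   GeneralDiffusionTrainer(..., eval_metrics=eval_metrics)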
@@ -126,6 +135,7 @@ def __init__(self,
                  native_resolution: int = None,
                  frames_per_sample: int = None,
                  wandb_config: Dict[str, Any] = None,
+                 eval_metrics: List[EvaluationMetric] = None,
                  **kwargs
                  ):
         """
@@ -150,6 +160,7 @@ def __init__(self,
             autoencoder=autoencoder,
         )
         self.input_config = input_config
+        self.eval_metrics = eval_metrics

         if wandb_config is not None:
             # If input_config is not in wandb_config, add it
@@ -363,7 +374,6 @@ def _define_validation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSamp
         def generate_samples(
             val_state: TrainState,
             batch,
-            sampler: DiffusionSampler,
             diffusion_steps: int,
         ):
             # Process all conditional inputs
@@ -385,7 +395,7 @@ def generate_samples(
                 model_conditioning_inputs=tuple(model_conditioning_inputs),
             )

-        return sampler, generate_samples
+        return generate_samples

     def _get_image_size(self):
         """Helper to determine image size from available information."""
@@ -415,32 +425,73 @@ def validation_loop(
         """
         Run validation and log samples for both image and video diffusion.
         """
-        sampler, generate_samples = val_step_fn
-        val_ds = iter(val_ds()) if val_ds else None
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        generate_samples = val_step_fn

+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
         try:
-            # Generate samples
-            samples = generate_samples(
-                val_state,
-                next(val_ds),
-                sampler,
-                diffusion_steps,
-            )
-
-            # Log samples to wandb
-            if getattr(self, 'wandb', None) is not None and self.wandb:
-                import numpy as np
+            metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                    if self.distributed_training and global_device_count > 1:
+                        batch = convert_to_global_tree(self.mesh, batch)
+                # Generate samples
+                samples = generate_samples(
+                    val_state,
+                    batch,
+                    diffusion_steps,
+                )

-                # Process samples differently based on dimensionality
-                if len(samples.shape) == 5:  # [B,T,H,W,C] - Video data
-                    self._log_video_samples(samples, current_step)
-                else:  # [B,H,W,C] - Image data
-                    self._log_image_samples(samples, current_step)
+                if self.eval_metrics is not None:
+                    for metric in self.eval_metrics:
+                        try:
+                            # Evaluate metrics
+                            metric_val = metric.function(samples, batch)
+                            metrics[metric.name].append(metric_val)
+                        except Exception as e:
+                            print("Error in evaluation metrics:", e)
+                            import traceback
+                            traceback.print_exc()
+                            pass

+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                    # Log samples to wandb
+                    if getattr(self, 'wandb', None) is not None and self.wandb:
+                        import numpy as np
+
+                        # Process samples differently based on dimensionality
+                        if len(samples.shape) == 5:  # [B,T,H,W,C] - Video data
+                            self._log_video_samples(samples, current_step)
+                        else:  # [B,H,W,C] - Image data
+                            self._log_image_samples(samples, current_step)
+
+            if getattr(self, 'wandb', None) is not None and self.wandb:
+                # metrics is a dict of metrics
+                if metrics and type(metrics) == dict:
+                    # Flatten the metrics
+                    metrics = {k: np.mean(v) for k, v in metrics.items()}
+                    # Log the metrics
+                    for key, value in metrics.items():
+                        if isinstance(value, jnp.ndarray):
+                            value = np.array(value)
+                        self.wandb.log({
+                            f"val/{key}": value,
+                        }, step=current_step)
+
+        except StopIteration:
+            print(f"Validation dataset exhausted for process index {process_index}")
         except Exception as e:
-            print("Error in validation loop:", e)
+            print(f"Error during validation for process index {process_index}: {e}")
             import traceback
             traceback.print_exc()
+

     def _log_video_samples(self, samples, current_step):
         """Helper to log video samples to wandb."""
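To make the logging path concrete: each metric's per-step values are collected into a list, reduced with np.mean once the loop finishes, and logged under a val/ prefix. A tiny illustration with made-up numbers:

import numpy as np

metrics = {"mse": [0.021, 0.019, 0.024]}                 # one value per validation step
metrics = {k: np.mean(v) for k, v in metrics.items()}    # -> {"mse": 0.0213...}
# Each entry is then logged to wandb as {"val/mse": 0.0213...} at the current training step.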