@@ -546,6 +546,8 @@ def _prepare_validation_inputs(
            negative_prompt_attention_mask: torch.Tensor | None
    ) -> ForwardBatch:

+        assert len(validation_batch['info_list']
+                   ) == 1, "Only batch size 1 is supported for validation"
         prompt = validation_batch['info_list'][0]['prompt']
         prompt_embeds = validation_batch['text_embedding']
         prompt_attention_mask = validation_batch['text_attention_mask']
@@ -629,15 +631,17 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
         # Process each validation prompt for each validation step
         for num_inference_steps in validation_steps:
             step_videos: List[np.ndarray] = []
-            step_captions: List[str | None] = []
+            step_captions: List[str] = []

             for validation_batch in validation_dataloader:
                 batch = self._prepare_validation_inputs(
                     sampling_param, training_args, validation_batch,
                     num_inference_steps, negative_prompt_embeds,
                     negative_prompt_attention_mask)

-                step_captions.extend([None])  # TODO(peiyuan): add caption
+                assert batch.prompt is not None and isinstance(
+                    batch.prompt, str)
+                step_captions.append(batch.prompt)

                 # Run validation inference
                 with torch.no_grad(), torch.autocast("cuda",
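For context, a minimal runnable sketch of the behavior these two hunks introduce: the batch-size-1 guard in `_prepare_validation_inputs`, and `_log_validation` switching from `None` placeholder captions to the real prompt, which lets `step_captions` be typed `List[str]`. `FakeBatch`, `prepare_inputs`, and `collect_captions` below are hypothetical stand-ins for illustration, not the real `ForwardBatch` pipeline:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeBatch:
    # Hypothetical stand-in for ForwardBatch; only the field this
    # diff touches is modeled here.
    prompt: str | None = None


def prepare_inputs(validation_batch: dict) -> FakeBatch:
    # Mirrors the new guard: validation only supports batch size 1.
    assert len(validation_batch['info_list']) == 1, \
        "Only batch size 1 is supported for validation"
    return FakeBatch(prompt=validation_batch['info_list'][0]['prompt'])


def collect_captions(validation_batches: List[dict]) -> List[str]:
    # Mirrors the new loop body: each caption comes from batch.prompt,
    # so the result is List[str] rather than List[str | None].
    step_captions: List[str] = []
    for validation_batch in validation_batches:
        batch = prepare_inputs(validation_batch)
        assert batch.prompt is not None and isinstance(batch.prompt, str)
        step_captions.append(batch.prompt)
    return step_captions


print(collect_captions([{'info_list': [{'prompt': 'a cat playing piano'}]}]))
# ['a cat playing piano']
```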