Skip to content

Commit ce9b591

Browse files
[Training] add caption to validation log (#582)
1 parent d0e5a62 commit ce9b591

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

fastvideo/v1/training/training_pipeline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ def _prepare_validation_inputs(
546546
negative_prompt_attention_mask: torch.Tensor | None
547547
) -> ForwardBatch:
548548

549+
assert len(validation_batch['info_list']
550+
) == 1, "Only batch size 1 is supported for validation"
549551
prompt = validation_batch['info_list'][0]['prompt']
550552
prompt_embeds = validation_batch['text_embedding']
551553
prompt_attention_mask = validation_batch['text_attention_mask']
@@ -629,15 +631,17 @@ def _log_validation(self, transformer, training_args, global_step) -> None:
629631
# Process each validation prompt for each validation step
630632
for num_inference_steps in validation_steps:
631633
step_videos: List[np.ndarray] = []
632-
step_captions: List[str | None] = []
634+
step_captions: List[str] = []
633635

634636
for validation_batch in validation_dataloader:
635637
batch = self._prepare_validation_inputs(
636638
sampling_param, training_args, validation_batch,
637639
num_inference_steps, negative_prompt_embeds,
638640
negative_prompt_attention_mask)
639641

640-
step_captions.extend([None]) # TODO(peiyuan): add caption
642+
assert batch.prompt is not None and isinstance(
643+
batch.prompt, str)
644+
step_captions.append(batch.prompt)
641645

642646
# Run validation inference
643647
with torch.no_grad(), torch.autocast("cuda",

0 commit comments

Comments
 (0)