File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 4646from torchvision .transforms import InterpolationMode
4747import torchvision .transforms as TT
4848import numpy as np
49+ from diffusers .image_processor import VaeImageProcessor
4950
5051
5152if is_wandb_available ():
@@ -740,8 +741,13 @@ def log_validation(
740741
741742 videos = []
742743 for _ in range (args .num_validation_videos ):
743- video = pipe (** pipeline_args , generator = generator , output_type = "np" ).frames [0 ]
744- videos .append (video )
744+ pt_images = pipe (** pipeline_args , generator = generator , output_type = "pt" ).frames [0 ]
745+ pt_images = torch .stack ([pt_images [i ] for i in range (pt_images .shape [0 ])])
746+
747+ image_np = VaeImageProcessor .pt_to_numpy (pt_images )
748+ image_pil = VaeImageProcessor .numpy_to_pil (image_np )
749+
750+ videos .append (image_pil )
745751
746752 for tracker in accelerator .trackers :
747753 phase_name = "test" if is_final_validation else "validation"
You can’t perform that action at this time.
0 commit comments