Skip to content

Commit 10bf85f

Browse files
author
--unset
committed
wandb request pil image Type
1 parent d0f5b05 commit 10bf85f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from torchvision.transforms import InterpolationMode
4747
import torchvision.transforms as TT
4848
import numpy as np
49+
from diffusers.image_processor import VaeImageProcessor
4950

5051

5152
if is_wandb_available():
@@ -740,8 +741,13 @@ def log_validation(
740741

741742
videos = []
742743
for _ in range(args.num_validation_videos):
743-
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
744-
videos.append(video)
744+
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
745+
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
746+
747+
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
748+
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
749+
750+
videos.append(image_pil)
745751

746752
for tracker in accelerator.trackers:
747753
phase_name = "test" if is_final_validation else "validation"

0 commit comments

Comments
 (0)