 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import pandas as pd
 import torch
+import torchvision.transforms as TT
 from accelerate.logging import get_logger
 from torch.utils.data import Dataset, Sampler
 from torchvision import transforms
+from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
 
 
@@ -281,6 +284,71 @@ def _find_nearest_resolution(self, height, width):
         return nearest_res[1], nearest_res[2]
 
 
+class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
+    def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.video_reshape_mode = video_reshape_mode
+
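+    # Resize so that one side of the frames matches the bucket resolution,
+    # then crop the overshooting side down to the exact target size.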
+    def _resize_for_rectangle_crop(self, arr, image_size):
+        reshape_mode = self.video_reshape_mode
+        # arr is [F, C, H, W]; image_size is (height, width).
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            # Wider than the target aspect ratio: match heights, overshoot width.
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            # Taller than the target aspect ratio: match widths, overshoot height.
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)  # no-op while the frame dimension is > 1
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode in ("random", "none"):
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError(f"Unsupported video_reshape_mode: {reshape_mode}")
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
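+    # Decode the video, sample frames into the nearest frame bucket, then
+    # resize/crop into the nearest resolution bucket before the transforms.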
+    def _preprocess_video(self, path: Path) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]:
+        if self.load_tensors:
+            return self._load_preprocessed_latents_and_embeds(path)
+        else:
+            video_reader = decord.VideoReader(uri=path.as_posix())
+            video_num_frames = len(video_reader)
+            nearest_frame_bucket = min(
+                self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
+            )
+
+            # Evenly spaced sampling; guard against a zero step for clips
+            # shorter than the smallest frame bucket.
+            frame_skip = max(1, video_num_frames // nearest_frame_bucket)
+            frame_indices = list(range(0, video_num_frames, frame_skip))
+
+            frames = video_reader.get_batch(frame_indices)
+            frames = frames[:nearest_frame_bucket].float()
+            frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, H, W, C] -> [F, C, H, W]
+
+            nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
+            frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
+            frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
+
+            # In image-to-video mode, the first frame doubles as the conditioning image.
+            image = frames[:1].clone() if self.image_to_video else None
+
+            return image, frames, None
+
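+    # Mirrors the base-class helper: pick the resolution bucket that minimizes
+    # |Δheight| + |Δwidth| relative to the decoded frames.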
+    def _find_nearest_resolution(self, height, width):
+        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
+        return nearest_res[1], nearest_res[2]
+
+
 class BucketSampler(Sampler):
     def __init__(self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True) -> None:
         self.data_source = data_source
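
For reference, here is a minimal standalone sketch of the resize-then-crop behavior the new `_resize_for_rectangle_crop` method implements. The helper name and the example shapes below are illustrative only; it re-derives the same arithmetic with a center crop rather than calling the dataset class:

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

# Illustrative stand-alone version of the resize-then-crop logic in the diff:
# scale so the constraining side matches the target, then center-crop the rest.
def resize_and_center_crop(frames: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    f, c, h, w = frames.shape
    if w / h > target_w / target_h:
        # Wider than the target aspect ratio: match heights, crop width.
        frames = resize(frames, size=[target_h, int(w * target_h / h)], interpolation=InterpolationMode.BICUBIC)
    else:
        # Taller than the target aspect ratio: match widths, crop height.
        frames = resize(frames, size=[int(h * target_w / w), target_w], interpolation=InterpolationMode.BICUBIC)
    top = (frames.shape[2] - target_h) // 2
    left = (frames.shape[3] - target_w) // 2
    return crop(frames, top=top, left=left, height=target_h, width=target_w)

# A 49-frame 720x1280 clip cropped into a 480x720 bucket keeps the full height
# of the scaled video and trims the sides.
frames = torch.rand(49, 3, 720, 1280)
out = resize_and_center_crop(frames, target_h=480, target_w=720)
assert out.shape == (49, 3, 480, 720)
```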