|
3 | 3 | # |
4 | 4 | # This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification] |
5 | 5 | # Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file]. |
6 | | -from typing import Union, List |
7 | | -import tempfile |
8 | 6 | import numpy as np |
9 | | -import PIL.Image |
10 | 7 | import matplotlib.cm as cm |
11 | | -import mediapy |
12 | | -import torch |
| 8 | +import imageio |
13 | 9 | try: |
14 | 10 | from decord import VideoReader, cpu |
15 | 11 | DECORD_AVAILABLE = True |
16 | 12 | except: |
17 | 13 | import cv2 |
18 | 14 | DECORD_AVAILABLE = False |
19 | 15 |
|
| 16 | +def ensure_even(value): |
| 17 | + return value if value % 2 == 0 else value + 1 |
20 | 18 |
|
21 | | -def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"): |
| 19 | +def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1): |
22 | 20 | if DECORD_AVAILABLE: |
23 | 21 | vid = VideoReader(video_path, ctx=cpu(0)) |
24 | 22 | original_height, original_width = vid.get_batch([0]).shape[1:3] |
25 | 23 | height = original_height |
26 | 24 | width = original_width |
27 | 25 | if max_res > 0 and max(height, width) > max_res: |
28 | 26 | scale = max_res / max(original_height, original_width) |
29 | | - height = round(original_height * scale) |
30 | | - width = round(original_width * scale) |
| 27 | + height = ensure_even(round(original_height * scale)) |
| 28 | + width = ensure_even(round(original_width * scale)) |
31 | 29 |
|
32 | 30 | vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height) |
33 | 31 |
|
@@ -71,46 +69,18 @@ def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dat |
71 | 69 | return frames, fps |
72 | 70 |
|
73 | 71 |
|
74 | | -def save_video( |
75 | | - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], |
76 | | - output_video_path: str = None, |
77 | | - fps: int = 10, |
78 | | - crf: int = 18, |
79 | | -) -> str: |
80 | | - if output_video_path is None: |
81 | | - output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name |
82 | | - |
83 | | - if isinstance(video_frames[0], np.ndarray): |
84 | | - video_frames = [frame.astype(np.uint8) for frame in video_frames] |
85 | | - |
86 | | - elif isinstance(video_frames[0], PIL.Image.Image): |
87 | | - video_frames = [np.array(frame) for frame in video_frames] |
88 | | - mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf) |
89 | | - return output_video_path |
90 | | - |
91 | | - |
92 | | -class ColorMapper: |
93 | | - # a color mapper to map depth values to a certain colormap |
94 | | - def __init__(self, colormap: str = "inferno"): |
95 | | - self.colormap = torch.tensor(cm.get_cmap(colormap).colors) |
96 | | - |
97 | | - def apply(self, image: torch.Tensor, v_min=None, v_max=None): |
98 | | - # assert len(image.shape) == 2 |
99 | | - if v_min is None: |
100 | | - v_min = image.min() |
101 | | - if v_max is None: |
102 | | - v_max = image.max() |
103 | | - image = (image - v_min) / (v_max - v_min) |
104 | | - image = (image * 255).long() |
105 | | - image = self.colormap[image] * 255 |
106 | | - return image |
107 | | - |
| 72 | +def save_video(frames, output_video_path, fps=10, is_depths=False): |
| 73 | + writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18']) |
| 74 | + if is_depths: |
| 75 | + colormap = np.array(cm.get_cmap("inferno").colors) |
| 76 | + d_min, d_max = frames.min(), frames.max() |
| 77 | + for i in range(frames.shape[0]): |
| 78 | + depth = frames[i] |
| 79 | + depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8) |
| 80 | + depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) |
| 81 | + writer.append_data(depth_vis) |
| 82 | + else: |
| 83 | + for i in range(frames.shape[0]): |
| 84 | + writer.append_data(frames[i]) |
108 | 85 |
|
109 | | -def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None): |
110 | | - visualizer = ColorMapper() |
111 | | - if v_min is None: |
112 | | - v_min = depths.min() |
113 | | - if v_max is None: |
114 | | - v_max = depths.max() |
115 | | - res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy() |
116 | | - return res |
| 86 | + writer.close() |
0 commit comments