Skip to content

Commit 3999b5e

Browse files
committed
update video writer
1 parent 013235f commit 3999b5e

File tree

5 files changed

+32
-64
lines changed

5 files changed

+32
-64
lines changed

app.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
# limitations under the License.
1414
import gradio as gr
1515

16-
1716
import numpy as np
1817
import os
1918
import torch
2019

2120
from video_depth_anything.video_depth import VideoDepthAnything
22-
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
21+
from utils.dc_utils import read_video_frames, save_video
2322

2423
examples = [
25-
['assets/example_videos/davis_rollercoaster.mp4'],
24+
['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280],
2625
]
2726

2827
model_configs = {
@@ -46,17 +45,16 @@ def infer_video_depth(
4645
input_size: int = 518,
4746
):
4847
frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
49-
depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
50-
depth_list = np.stack(depth_list, axis=0)
51-
vis = vis_sequence_depth(depth_list)
48+
depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
49+
5250
video_name = os.path.basename(input_video)
5351
if not os.path.exists(output_dir):
5452
os.makedirs(output_dir)
5553

5654
processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
5755
depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
5856
save_video(frames, processed_video_path, fps=fps)
59-
save_video(vis, depth_vis_path, fps=fps)
57+
save_video(depths, depth_vis_path, fps=fps, is_depths=True)
6058

6159
return [processed_video_path, depth_vis_path]
6260

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ torchvision
44
opencv-python
55
matplotlib
66
pillow
7-
mediapy
7+
imageio
8+
imageio-ffmpeg
89
decord
910
xformers
1011
einops

run.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818

1919
from video_depth_anything.video_depth import VideoDepthAnything
20-
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
20+
from utils.dc_utils import read_video_frames, save_video
2121

2222
if __name__ == '__main__':
2323
parser = argparse.ArgumentParser(description='Video Depth Anything')
@@ -43,17 +43,16 @@
4343
video_depth_anything = video_depth_anything.to(DEVICE).eval()
4444

4545
frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
46-
depth_list, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE)
47-
depth_list = np.stack(depth_list, axis=0)
48-
vis = vis_sequence_depth(depth_list)
46+
depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE)
47+
4948
video_name = os.path.basename(args.input_video)
5049
if not os.path.exists(args.output_dir):
5150
os.makedirs(args.output_dir)
5251

5352
processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
5453
depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
5554
save_video(frames, processed_video_path, fps=fps)
56-
save_video(vis, depth_vis_path, fps=fps)
55+
save_video(depths, depth_vis_path, fps=fps, is_depths=True)
5756

5857

5958

utils/dc_utils.py

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,29 @@
33
#
44
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
55
# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
6-
from typing import Union, List
7-
import tempfile
86
import numpy as np
9-
import PIL.Image
107
import matplotlib.cm as cm
11-
import mediapy
12-
import torch
8+
import imageio
139
try:
1410
from decord import VideoReader, cpu
1511
DECORD_AVAILABLE = True
1612
except:
1713
import cv2
1814
DECORD_AVAILABLE = False
1915

16+
def ensure_even(value):
17+
return value if value % 2 == 0 else value + 1
2018

21-
def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
19+
def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1):
2220
if DECORD_AVAILABLE:
2321
vid = VideoReader(video_path, ctx=cpu(0))
2422
original_height, original_width = vid.get_batch([0]).shape[1:3]
2523
height = original_height
2624
width = original_width
2725
if max_res > 0 and max(height, width) > max_res:
2826
scale = max_res / max(original_height, original_width)
29-
height = round(original_height * scale)
30-
width = round(original_width * scale)
27+
height = ensure_even(round(original_height * scale))
28+
width = ensure_even(round(original_width * scale))
3129

3230
vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
3331

@@ -71,46 +69,18 @@ def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dat
7169
return frames, fps
7270

7371

74-
def save_video(
75-
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
76-
output_video_path: str = None,
77-
fps: int = 10,
78-
crf: int = 18,
79-
) -> str:
80-
if output_video_path is None:
81-
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
82-
83-
if isinstance(video_frames[0], np.ndarray):
84-
video_frames = [frame.astype(np.uint8) for frame in video_frames]
85-
86-
elif isinstance(video_frames[0], PIL.Image.Image):
87-
video_frames = [np.array(frame) for frame in video_frames]
88-
mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
89-
return output_video_path
90-
91-
92-
class ColorMapper:
93-
# a color mapper to map depth values to a certain colormap
94-
def __init__(self, colormap: str = "inferno"):
95-
self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
96-
97-
def apply(self, image: torch.Tensor, v_min=None, v_max=None):
98-
# assert len(image.shape) == 2
99-
if v_min is None:
100-
v_min = image.min()
101-
if v_max is None:
102-
v_max = image.max()
103-
image = (image - v_min) / (v_max - v_min)
104-
image = (image * 255).long()
105-
image = self.colormap[image] * 255
106-
return image
107-
72+
def save_video(frames, output_video_path, fps=10, is_depths=False):
73+
writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18'])
74+
if is_depths:
75+
colormap = np.array(cm.get_cmap("inferno").colors)
76+
d_min, d_max = frames.min(), frames.max()
77+
for i in range(frames.shape[0]):
78+
depth = frames[i]
79+
depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
80+
depth_vis = (colormap[depth_norm] * 255).astype(np.uint8)
81+
writer.append_data(depth_vis)
82+
else:
83+
for i in range(frames.shape[0]):
84+
writer.append_data(frames[i])
10885

109-
def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
110-
visualizer = ColorMapper()
111-
if v_min is None:
112-
v_min = depths.min()
113-
if v_max is None:
114-
v_max = depths.max()
115-
res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
116-
return res
86+
writer.close()

video_depth_anything/video_depth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,5 @@ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
150150

151151
depth_list = depth_list_aligned
152152

153-
return depth_list[:org_video_len], target_fps
153+
return np.stack(depth_list[:org_video_len], axis=0), target_fps
154154

0 commit comments

Comments
 (0)