Skip to content

Commit aaf43f7

Browse files
2 parents d276658 + 1ea9c8d commit aaf43f7

File tree

1 file changed

+52
-10
lines changed

1 file changed

+52
-10
lines changed

run.py

Lines changed: 52 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import torch
1818
import time
19+
import cv2
1920

2021
from video_depth_anything.video_depth import VideoDepthAnything
2122
from utils.dc_utils import read_video_frames, save_video
@@ -51,14 +52,59 @@
5152
video_depth_anything = video_depth_anything.to(DEVICE).eval()
5253
model_load_time = time.time() - start_time
5354

54-
# Video reading
55+
# Video reading and processing in batches
5556
read_start = time.time()
56-
frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
57+
inference_start = time.time() # Add timing marker here
58+
batch_size = 300 # Process 300 frames at a time
59+
total_depths = []
60+
total_frames = []
61+
62+
# Initialize video capture
63+
cap = cv2.VideoCapture(args.input_video)
64+
original_fps = cap.get(cv2.CAP_PROP_FPS)
65+
target_fps = args.target_fps if args.target_fps > 0 else original_fps
66+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
67+
68+
if args.max_len > 0:
69+
frame_count = min(frame_count, args.max_len)
70+
71+
batch_frames = []
72+
frame_idx = 0
73+
74+
while frame_idx < frame_count:
75+
ret, frame = cap.read()
76+
if not ret:
77+
break
78+
79+
if args.max_res > 0:
80+
h, w = frame.shape[:2]
81+
scale = min(args.max_res / h, args.max_res / w)
82+
if scale < 1:
83+
new_h, new_w = int(h * scale), int(w * scale)
84+
frame = cv2.resize(frame, (new_w, new_h))
85+
86+
batch_frames.append(frame)
87+
frame_idx += 1
88+
89+
# Process batch when it reaches batch_size or end of video
90+
if len(batch_frames) == batch_size or frame_idx == frame_count:
91+
# Convert batch_frames list to numpy array
92+
batch_frames_array = np.array(batch_frames)
93+
depths, _ = video_depth_anything.infer_video_depth(
94+
batch_frames_array, target_fps, input_size=args.input_size,
95+
device=DEVICE, fp32=args.fp32
96+
)
97+
total_depths.extend(depths)
98+
total_frames.extend(batch_frames)
99+
batch_frames = [] # Clear batch
100+
print(f"Processed {frame_idx}/{frame_count} frames")
101+
102+
cap.release()
103+
frames = total_frames
104+
depths = total_depths
57105
read_time = time.time() - read_start
58106

59-
# Depth inference
60-
inference_start = time.time()
61-
depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE, fp32=args.fp32)
107+
# Remove redundant inference
62108
inference_time = time.time() - inference_start
63109

64110
# Video saving
@@ -110,8 +156,4 @@
110156
num_frames = len(frames)
111157
print(f"\nPer-frame Statistics:")
112158
print(f"Number of Frames: {num_frames}")
113-
print(f"Average Processing Time per Frame: {inference_time/num_frames:.3f}s ({(num_frames/inference_time):.1f} FPS)")
114-
115-
116-
117-
159+
print(f"Average Processing Time per Frame: {inference_time/num_frames:.3f}s ({(num_frames/inference_time):.1f} FPS)")

0 commit comments

Comments (0)