Skip to content

Commit 46b93ec

Browse files
authored
Merge pull request #67 from guohengkai/ghk/dev/streaming
Support streaming mode (experimental feature)
2 parents 6b16e1c + ee9750a commit 46b93ec

File tree

6 files changed

+350
-43
lines changed

6 files changed

+350
-43
lines changed

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ This work presents **Video Depth Anything** based on [Depth Anything V2](https:/
2121
![teaser](assets/teaser_video_v2.png)
2222

2323
## News
24+
- **2025-07-03:** 🚀🚀🚀 Release an experimental version of training-free **streaming video depth estimation**.
2425
- **2025-07-03:** Release our implementation of [training loss](https://github.com/DepthAnything/Video-Depth-Anything/tree/main/loss).
2526
- **2025-04-25:** 🌟🌟🌟 Release [metric depth model](https://github.com/DepthAnything/Video-Depth-Anything/tree/main/metric_depth) based on Video-Depth-Anything-Large.
2627
- **2025-04-05:** Our paper has been accepted for a **highlight** presentation at [CVPR 2025](https://cvpr.thecvf.com/) (13.5% of the accepted papers).
@@ -107,6 +108,24 @@ Options:
107108
- `--save_npz` (optional): Save the depth map in `npz` format.
108109
- `--save_exr` (optional): Save the depth map in `exr` format.
109110

111+
### Inference a video using streaming mode (experimental feature)
112+
We implement an experimental streaming mode **without training**. In detail, we cache the hidden states of the temporal attentions for each frame, and feed only a single frame into our video depth model during inference, reusing these past hidden states in the temporal attentions. We hack our pipeline to align with the original inference setting of the offline mode. Due to the inevitable gap between training and testing, we observe a **performance drop** between the streaming model and the offline model (e.g. the `d1` of ScanNet drops from `0.926` to `0.836`). Finetuning the model in the streaming mode would greatly improve the performance; we leave that for future work.
113+
114+
To run the streaming model:
115+
```bash
116+
python3 run_streaming.py --input_video ./assets/example_videos/davis_rollercoaster.mp4 --output_dir ./outputs_streaming --encoder vitl
117+
```
118+
Options:
119+
- `--input_video`: path of input video
120+
- `--output_dir`: path to save the output results
121+
- `--input_size` (optional): By default, we use input size `518` for model inference.
122+
- `--max_res` (optional): By default, we use maximum resolution `1280` for model inference.
123+
- `--encoder` (optional): `vits` for Video-Depth-Anything-V2-Small, `vitl` for Video-Depth-Anything-V2-Large.
124+
- `--max_len` (optional): maximum length of the input video, `-1` means no limit
125+
- `--target_fps` (optional): target fps of the input video, `-1` means the original fps
126+
- `--fp32` (optional): Use `fp32` precision for inference. By default, we use `fp16`.
127+
- `--grayscale` (optional): Save the grayscale depth map, without applying color palette.
128+
110129
### Training Loss
111130
Our training loss is in `loss/` directory. Please see the `loss/test_loss.py` for usage.
112131

run_streaming.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
# Copyright (2025) Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streaming video depth estimation (experimental).

Reads the input video frame by frame and runs single-frame inference via
``infer_video_depth_one``, which reuses cached temporal-attention hidden
states inside the model instead of processing a whole clip at once.
The per-frame depth maps are stacked and saved as a visualization video.
"""
import argparse
import numpy as np
import os
import torch
import time
import cv2

from video_depth_anything.video_depth_stream import VideoDepthAnything
from utils.dc_utils import save_video

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Video Depth Anything')
    parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--input_size', type=int, default=518)
    parser.add_argument('--max_res', type=int, default=1280)
    parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
    parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
    parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
    parser.add_argument('--fp32', action='store_true', help='model infer with torch.float32, default is torch.float16')
    parser.add_argument('--grayscale', action='store_true', help='do not apply colorful palette')

    args = parser.parse_args()

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    }

    video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
    video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
    video_depth_anything = video_depth_anything.to(DEVICE).eval()

    cap = cv2.VideoCapture(args.input_video)
    # Fail fast with a clear message instead of silently producing zero frames.
    if not cap.isOpened():
        raise IOError(f'Cannot open input video: {args.input_video}')
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Downscale so the longer side does not exceed --max_res (hoisted out of
    # the read loop; the condition is invariant across frames).
    need_resize = args.max_res > 0 and max(original_height, original_width) > args.max_res
    if need_resize:
        scale = args.max_res / max(original_height, original_width)
        height = round(original_height * scale)
        width = round(original_width * scale)

    fps = original_fps if args.target_fps < 0 else args.target_fps

    # Guard against broken fps metadata (some containers report 0), which
    # previously raised ZeroDivisionError here.
    stride = max(round(original_fps / fps), 1) if fps > 0 else 1

    depths = []
    frame_count = 0
    start = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret or (args.max_len > 0 and frame_count >= args.max_len):
            break
        if frame_count % stride == 0:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR; the model expects RGB
            if need_resize:
                frame = cv2.resize(frame, (width, height))

            # Single-frame inference; temporal context comes from the model's
            # internal hidden-state caches.
            depth = video_depth_anything.infer_video_depth_one(frame, input_size=args.input_size, device=DEVICE, fp32=args.fp32)
            depths.append(depth)
        frame_count += 1
        if frame_count % 50 == 0:
            print(f"frame: {frame_count}/{total_frames}")
    end = time.time()

    cap.release()
    print(f"time: {end - start}s")

    # np.stack on an empty list raises an opaque error; report the real cause.
    if not depths:
        raise RuntimeError('No frames were processed; check the input video and the --max_len/--target_fps options.')

    video_name = os.path.basename(args.input_video)
    os.makedirs(args.output_dir, exist_ok=True)  # race-free; replaces exists()+makedirs

    depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0] + '_vis.mp4')
    depths = np.stack(depths, axis=0)
    save_video(depths, depth_vis_path, fps=fps, is_depths=True, grayscale=args.grayscale)

video_depth_anything/dpt_temporal.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self,
5050
**motion_module_kwargs)
5151
])
5252

53-
def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4):
53+
def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4, cached_hidden_state_list=None):
5454
out = []
5555
for i, x in enumerate(out_features):
5656
if self.use_clstoken:
@@ -71,19 +71,27 @@ def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size
7171
layer_1, layer_2, layer_3, layer_4 = out
7272

7373
B, T = layer_1.shape[0] // frame_length, frame_length
74+
if cached_hidden_state_list is not None:
75+
N = len(cached_hidden_state_list) // len(self.motion_modules)
76+
else:
77+
N = 0
7478

75-
layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
76-
layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
79+
layer_3, h0 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[0:N] if N else None)
80+
layer_3 = layer_3.permute(0, 2, 1, 3, 4).flatten(0, 1)
81+
layer_4, h1 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[N:2*N] if N else None)
82+
layer_4 = layer_4.permute(0, 2, 1, 3, 4).flatten(0, 1)
7783

7884
layer_1_rn = self.scratch.layer1_rn(layer_1)
7985
layer_2_rn = self.scratch.layer2_rn(layer_2)
8086
layer_3_rn = self.scratch.layer3_rn(layer_3)
8187
layer_4_rn = self.scratch.layer4_rn(layer_4)
8288

8389
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
84-
path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
90+
path_4, h2 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[2*N:3*N] if N else None)
91+
path_4 = path_4.permute(0, 2, 1, 3, 4).flatten(0, 1)
8592
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
86-
path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
93+
path_3, h3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[3*N:] if N else None)
94+
path_3 = path_3.permute(0, 2, 1, 3, 4).flatten(0, 1)
8795

8896
batch_size = layer_1_rn.shape[0]
8997
if batch_size <= micro_batch_size or batch_size % micro_batch_size != 0:
@@ -97,7 +105,8 @@ def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size
97105
ori_type = out.dtype
98106
with torch.autocast(device_type="cuda", enabled=False):
99107
out = self.scratch.output_conv2(out.float())
100-
return out.to(ori_type)
108+
109+
output = out.to(ori_type)
101110
else:
102111
ret = []
103112
for i in range(0, batch_size, micro_batch_size):
@@ -111,4 +120,6 @@ def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size
111120
with torch.autocast(device_type="cuda", enabled=False):
112121
out = self.scratch.output_conv2(out.float())
113122
ret.append(out.to(ori_type))
114-
return torch.cat(ret, dim=0)
123+
output = torch.cat(ret, dim=0)
124+
125+
return output, h0 + h1 + h2 + h3

video_depth_anything/motion_module/motion_module.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ def __init__(
5757
if zero_initialize:
5858
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
5959

60-
def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
60+
def forward(self, input_tensor, encoder_hidden_states, attention_mask=None, cached_hidden_state_list=None):
6161
hidden_states = input_tensor
62-
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
62+
hidden_states, output_hidden_state_list = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cached_hidden_state_list)
6363

6464
output = hidden_states
65-
return output
65+
return output, output_hidden_state_list # list of hidden states
6666

6767

6868
class TemporalTransformer3DModel(nn.Module):
@@ -99,8 +99,10 @@ def __init__(
9999
)
100100
self.proj_out = nn.Linear(inner_dim, in_channels)
101101

102-
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
102+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cached_hidden_state_list=None):
103103
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
104+
output_hidden_state_list = []
105+
104106
video_length = hidden_states.shape[2]
105107
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
106108

@@ -113,8 +115,14 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
113115
hidden_states = self.proj_in(hidden_states)
114116

115117
# Transformer Blocks
116-
for block in self.transformer_blocks:
117-
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
118+
if cached_hidden_state_list is not None:
119+
n = len(cached_hidden_state_list) // len(self.transformer_blocks)
120+
else:
121+
n = 0
122+
for i, block in enumerate(self.transformer_blocks):
123+
hidden_states, hidden_state_list = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask,
124+
cached_hidden_state_list=cached_hidden_state_list[i*n:(i+1)*n] if n else None)
125+
output_hidden_state_list.extend(hidden_state_list)
118126

119127
# output
120128
hidden_states = self.proj_out(hidden_states)
@@ -123,7 +131,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
123131
output = hidden_states + residual
124132
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
125133

126-
return output
134+
return output, output_hidden_state_list
127135

128136

129137
class TemporalTransformerBlock(nn.Module):
@@ -161,20 +169,24 @@ def __init__(
161169
self.ff_norm = nn.LayerNorm(dim)
162170

163171

164-
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
165-
for attention_block, norm in zip(self.attention_blocks, self.norms):
172+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_state_list=None):
173+
output_hidden_state_list = []
174+
for i, (attention_block, norm) in enumerate(zip(self.attention_blocks, self.norms)):
166175
norm_hidden_states = norm(hidden_states)
167-
hidden_states = attention_block(
176+
residual_hidden_states, output_hidden_states = attention_block(
168177
norm_hidden_states,
169178
encoder_hidden_states=encoder_hidden_states,
170179
video_length=video_length,
171180
attention_mask=attention_mask,
172-
) + hidden_states
181+
cached_hidden_states=cached_hidden_state_list[i] if cached_hidden_state_list is not None else None,
182+
)
183+
hidden_states = residual_hidden_states + hidden_states
184+
output_hidden_state_list.append(output_hidden_states)
173185

174186
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
175187

176188
output = hidden_states
177-
return output
189+
return output, output_hidden_state_list
178190

179191

180192
class PositionalEncoding(nn.Module):
@@ -227,9 +239,21 @@ def __init__(
227239
else:
228240
raise NotImplementedError
229241

230-
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
242+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_states=None):
243+
# TODO: support cache for these
244+
assert encoder_hidden_states is None
245+
assert attention_mask is None
246+
231247
d = hidden_states.shape[1]
232-
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
248+
d_in = 0
249+
if cached_hidden_states is None:
250+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
251+
input_hidden_states = hidden_states # (bxd) f c
252+
else:
253+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=1)
254+
input_hidden_states = hidden_states
255+
d_in = cached_hidden_states.shape[1]
256+
hidden_states = torch.cat([cached_hidden_states, hidden_states], dim=1)
233257

234258
if self.pos_encoder is not None:
235259
hidden_states = self.pos_encoder(hidden_states)
@@ -239,7 +263,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
239263
if self.group_norm is not None:
240264
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
241265

242-
query = self.to_q(hidden_states)
266+
query = self.to_q(hidden_states[:, d_in:, ...])
243267
dim = query.shape[-1]
244268

245269
if self.added_kv_proj_dim is not None:
@@ -294,4 +318,4 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
294318

295319
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
296320

297-
return hidden_states
321+
return hidden_states, input_hidden_states

video_depth_anything/video_depth.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
# Copyright (2025) Bytedance Ltd. and/or its affiliates
1+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
22

3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66

7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# http://www.apache.org/licenses/LICENSE-2.0
88

9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import torch
1515
import torch.nn.functional as F
1616
import torch.nn as nn
@@ -36,9 +36,9 @@ class VideoDepthAnything(nn.Module):
3636
def __init__(
3737
self,
3838
encoder='vitl',
39-
features=256,
40-
out_channels=[256, 512, 1024, 1024],
41-
use_bn=False,
39+
features=256,
40+
out_channels=[256, 512, 1024, 1024],
41+
use_bn=False,
4242
use_clstoken=False,
4343
num_frames=32,
4444
pe='ape'
@@ -49,7 +49,7 @@ def __init__(
4949
'vits': [2, 5, 8, 11],
5050
'vitl': [4, 11, 17, 23]
5151
}
52-
52+
5353
self.encoder = encoder
5454
self.pretrained = DINOv2(model_name=encoder)
5555

@@ -59,11 +59,11 @@ def forward(self, x):
5959
B, T, C, H, W = x.shape
6060
patch_h, patch_w = H // 14, W // 14
6161
features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
62-
depth = self.head(features, patch_h, patch_w, T)
62+
depth = self.head(features, patch_h, patch_w, T)[0]
6363
depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
6464
depth = F.relu(depth)
6565
return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
66-
66+
6767
def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', fp32=False):
6868
frame_height, frame_width = frames[0].shape[:2]
6969
ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
@@ -90,7 +90,7 @@ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', f
9090
org_video_len = len(frame_list)
9191
append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
9292
frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
93-
93+
9494
depth_list = []
9595
pre_input = None
9696
for frame_id in tqdm(range(0, org_video_len, frame_step)):
@@ -149,8 +149,8 @@ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', f
149149
new_depth = depth_list[frame_id+kf_id] * scale + shift
150150
new_depth[new_depth<0] = 0
151151
ref_align.append(new_depth)
152-
152+
153153
depth_list = depth_list_aligned
154-
154+
155155
return np.stack(depth_list[:org_video_len], axis=0), target_fps
156-
156+

0 commit comments

Comments
 (0)