@@ -65,6 +65,12 @@ def forward(self, x):
6565 return depth .squeeze (1 ).unflatten (0 , (B , T )) # return shape [B, T, H, W]
6666
6767 def infer_video_depth (self , frames , target_fps , input_size = 518 , device = 'cuda' ):
68+ frame_height , frame_width = frames [0 ].shape [:2 ]
69+ ratio = max (frame_height , frame_width ) / min (frame_height , frame_width )
70+ if ratio > 1.78 : # we recommend processing videos with an aspect ratio no wider than 16:9 (~1.78) due to memory limitations
71+ input_size = int (input_size * 1.777 / ratio )
72+ input_size = round (input_size / 14 ) * 14
73+
6874 transform = Compose ([
6975 Resize (
7076 width = input_size ,
@@ -79,7 +85,6 @@ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
7985 PrepareForNet (),
8086 ])
8187
82- frame_size = frames [0 ].shape [:2 ]
8388 frame_list = [frames [i ] for i in range (frames .shape [0 ])]
8489 frame_step = INFER_LEN - OVERLAP
8590 org_video_len = len (frame_list )
@@ -99,7 +104,7 @@ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
99104 with torch .no_grad ():
100105 depth = self .forward (cur_input ) # depth shape: [1, T, H, W]
101106
102- depth = F .interpolate (depth .flatten (0 ,1 ).unsqueeze (1 ), size = frame_size , mode = 'bilinear' , align_corners = True )
107+ depth = F .interpolate (depth .flatten (0 ,1 ).unsqueeze (1 ), size = ( frame_height , frame_width ) , mode = 'bilinear' , align_corners = True )
103108 depth_list += [depth [i ][0 ].cpu ().numpy () for i in range (depth .shape [0 ])]
104109
105110 pre_input = cur_input
0 commit comments