@@ -150,19 +150,19 @@ class OpenCVDecoder(AbstractDecoder):
def __init__(self, backend):
    """Set up the decoder and resolve the requested OpenCV capture backend.

    Args:
        backend: Name of the capture backend; "FFMPEG" is the only key in
            the lookup table. Any other name resolves to ``None``.
    """
    # Imported here rather than at module level, then kept on the instance
    # so the decode methods can reach it without importing it again.
    import cv2

    backend_table = {"FFMPEG": cv2.CAP_FFMPEG}

    self.cv2 = cv2
    self._print_each_iteration_time = False
    self._available_backends = backend_table
    # dict.get() yields None (no exception) for an unrecognised backend name.
    self._backend = backend_table.get(backend)
157159
158160 def decode_frames (self , video_file , pts_list ):
159- import cv2
160-
161- cap = cv2 .VideoCapture (video_file , self ._backend )
161+ cap = self .cv2 .VideoCapture (video_file , self ._backend )
162162 if not cap .isOpened ():
163163 raise ValueError ("Could not open video stream" )
164164
165- fps = cap .get (cv2 .CAP_PROP_FPS )
165+ fps = cap .get (self . cv2 .CAP_PROP_FPS )
166166 approx_frame_indices = [int (pts * fps ) for pts in pts_list ]
167167
168168 current_frame = 0
@@ -174,6 +174,11 @@ def decode_frames(self, video_file, pts_list):
174174 if current_frame in approx_frame_indices : # only decompress needed
175175 ret , frame = cap .retrieve ()
176176 if ret :
177+ # OpenCV uses BGR, change to RGB
178+ frame = self .cv2 .cvtColor (frame , self .cv2 .COLOR_BGR2RGB )
179+ # Update to C, H, W
180+ frame = np .transpose (frame , (2 , 0 , 1 ))
181+ frame = torch .from_numpy (frame )
177182 frames .append (frame )
178183
179184 if len (frames ) == len (approx_frame_indices ):
@@ -184,9 +189,7 @@ def decode_frames(self, video_file, pts_list):
184189 return frames
185190
186191 def decode_first_n_frames (self , video_file , n ):
187- import cv2
188-
189- cap = cv2 .VideoCapture (video_file , self ._backend )
192+ cap = self .cv2 .VideoCapture (video_file , self ._backend )
190193 if not cap .isOpened ():
191194 raise ValueError ("Could not open video stream" )
192195
@@ -197,16 +200,21 @@ def decode_first_n_frames(self, video_file, n):
197200 raise ValueError ("Could not grab video frame" )
198201 ret , frame = cap .retrieve ()
199202 if ret :
203+ # OpenCV uses BGR, change to RGB
204+ frame = self .cv2 .cvtColor (frame , self .cv2 .COLOR_BGR2RGB )
205+ # Update to C, H, W
206+ frame = np .transpose (frame , (2 , 0 , 1 ))
207+ frame = torch .from_numpy (frame )
200208 frames .append (frame )
201209 cap .release ()
202210 assert len (frames ) == n
203211 return frames
204212
def decode_and_resize(self, video_file, pts_list, height, width, device):
    """Decode the frames at ``pts_list`` and resize each one with OpenCV.

    ``decode_frames()`` returns RGB ``torch`` tensors laid out as (C, H, W),
    while ``cv2.resize`` only accepts NumPy arrays laid out as (H, W, C).
    Each frame is therefore converted to a channels-last array for the
    resize and converted back to a (C, H, W) tensor afterwards, so the
    result has the same layout as ``decode_frames()``.

    Args:
        video_file: Path of the video, forwarded to ``decode_frames()``.
        pts_list: Timestamps of the frames to decode, forwarded to
            ``decode_frames()``.
        height: Target frame height in pixels.
        width: Target frame width in pixels.
        device: Not used by this implementation; kept so the signature is
            unchanged for existing callers.

    Returns:
        A list with one (C, height, width) tensor per decoded frame.
    """
    # Local import, in keeping with how this class pulls in cv2.
    import torch

    # OpenCV doesn't apply antialias, while other `decode_and_resize()`
    # implementations apply antialias by default.
    resized_frames = []
    for frame in self.decode_frames(video_file, pts_list):
        # (C, H, W) tensor -> contiguous (H, W, C) array that cv2 accepts.
        hwc = frame.permute(1, 2, 0).contiguous().numpy()
        # cv2.resize takes the target size as (width, height).
        resized = self.cv2.resize(hwc, (width, height))
        # Back to (C, H, W) to match the layout produced by decode_frames().
        resized_frames.append(torch.from_numpy(resized).permute(2, 0, 1))
    return resized_frames
0 commit comments