2222 get_frames_by_pts ,
2323 get_json_metadata ,
2424 get_next_frame ,
25- scan_all_streams_to_update_metadata ,
2625 seek_to_pts ,
2726)
2827
@@ -154,8 +153,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
154153 self ._device = device
155154
156155 def decode_frames (self , video_file , pts_list ):
157- decoder = create_from_file (video_file )
158- scan_all_streams_to_update_metadata (decoder )
156+ decoder = create_from_file (video_file , seek_mode = "exact" )
159157 _add_video_stream (
160158 decoder ,
161159 num_threads = self ._num_threads ,
@@ -170,7 +168,7 @@ def decode_frames(self, video_file, pts_list):
170168 return frames
171169
172170 def decode_first_n_frames (self , video_file , n ):
173- decoder = create_from_file (video_file )
171+ decoder = create_from_file (video_file , seek_mode = "approximate" )
174172 _add_video_stream (
175173 decoder ,
176174 num_threads = self ._num_threads ,
@@ -197,7 +195,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
197195 self .transforms_v2 = transforms_v2
198196
199197 def decode_frames (self , video_file , pts_list ):
200- decoder = create_from_file (video_file )
198+ decoder = create_from_file (video_file , seek_mode = "approximate" )
201199 num_threads = int (self ._num_threads ) if self ._num_threads else 0
202200 _add_video_stream (
203201 decoder ,
@@ -216,7 +214,7 @@ def decode_frames(self, video_file, pts_list):
216214
217215 def decode_first_n_frames (self , video_file , n ):
218216 num_threads = int (self ._num_threads ) if self ._num_threads else 0
219- decoder = create_from_file (video_file )
217+ decoder = create_from_file (video_file , seek_mode = "approximate" )
220218 _add_video_stream (
221219 decoder ,
222220 num_threads = num_threads ,
@@ -233,7 +231,7 @@ def decode_first_n_frames(self, video_file, n):
233231
234232 def decode_and_resize (self , video_file , pts_list , height , width , device ):
235233 num_threads = int (self ._num_threads ) if self ._num_threads else 1
236- decoder = create_from_file (video_file )
234+ decoder = create_from_file (video_file , seek_mode = "approximate" )
237235 _add_video_stream (
238236 decoder ,
239237 num_threads = num_threads ,
@@ -263,8 +261,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
263261 self ._device = device
264262
265263 def decode_frames (self , video_file , pts_list ):
266- decoder = create_from_file (video_file )
267- scan_all_streams_to_update_metadata (decoder )
264+ decoder = create_from_file (video_file , seek_mode = "exact" )
268265 _add_video_stream (
269266 decoder ,
270267 num_threads = self ._num_threads ,
@@ -279,8 +276,7 @@ def decode_frames(self, video_file, pts_list):
279276 return frames
280277
281278 def decode_first_n_frames (self , video_file , n ):
282- decoder = create_from_file (video_file )
283- scan_all_streams_to_update_metadata (decoder )
279+ decoder = create_from_file (video_file , seek_mode = "exact" )
284280 _add_video_stream (
285281 decoder ,
286282 num_threads = self ._num_threads ,
@@ -297,9 +293,10 @@ def decode_first_n_frames(self, video_file, n):
297293
298294
299295class TorchCodecPublic (AbstractDecoder ):
300- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" ):
296+ def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "exact" ):
301297 self ._num_ffmpeg_threads = num_ffmpeg_threads
302298 self ._device = device
299+ self ._seek_mode = seek_mode
303300
304301 from torchvision .transforms import v2 as transforms_v2
305302
@@ -310,7 +307,10 @@ def decode_frames(self, video_file, pts_list):
310307 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
311308 )
312309 decoder = VideoDecoder (
313- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
310+ video_file ,
311+ num_ffmpeg_threads = num_ffmpeg_threads ,
312+ device = self ._device ,
313+ seek_mode = self ._seek_mode ,
314314 )
315315 return decoder .get_frames_played_at (pts_list )
316316
@@ -319,7 +319,10 @@ def decode_first_n_frames(self, video_file, n):
319319 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
320320 )
321321 decoder = VideoDecoder (
322- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
322+ video_file ,
323+ num_ffmpeg_threads = num_ffmpeg_threads ,
324+ device = self ._device ,
325+ seek_mode = self ._seek_mode ,
323326 )
324327 frames = []
325328 count = 0
@@ -335,17 +338,21 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
335338 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 1
336339 )
337340 decoder = VideoDecoder (
338- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
341+ video_file ,
342+ num_ffmpeg_threads = num_ffmpeg_threads ,
343+ device = self ._device ,
344+ seek_mode = self ._seek_mode ,
339345 )
340346 frames = decoder .get_frames_played_at (pts_list )
341347 frames = self .transforms_v2 .functional .resize (frames .data , (height , width ))
342348 return frames
343349
344350
345351class TorchCodecPublicNonBatch (AbstractDecoder ):
346- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" ):
352+ def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "approximate" ):
347353 self ._num_ffmpeg_threads = num_ffmpeg_threads
348354 self ._device = device
355+ self ._seek_mode = seek_mode
349356
350357 from torchvision .transforms import v2 as transforms_v2
351358
@@ -356,7 +363,10 @@ def decode_frames(self, video_file, pts_list):
356363 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
357364 )
358365 decoder = VideoDecoder (
359- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
366+ video_file ,
367+ num_ffmpeg_threads = num_ffmpeg_threads ,
368+ device = self ._device ,
369+ seek_mode = self ._seek_mode ,
360370 )
361371
362372 frames = []
@@ -370,7 +380,10 @@ def decode_first_n_frames(self, video_file, n):
370380 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
371381 )
372382 decoder = VideoDecoder (
373- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
383+ video_file ,
384+ num_ffmpeg_threads = num_ffmpeg_threads ,
385+ device = self ._device ,
386+ seek_mode = self ._seek_mode ,
374387 )
375388 frames = []
376389 count = 0
@@ -386,7 +399,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
386399 int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 1
387400 )
388401 decoder = VideoDecoder (
389- video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
402+ video_file ,
403+ num_ffmpeg_threads = num_ffmpeg_threads ,
404+ device = self ._device ,
405+ seek_mode = self ._seek_mode ,
390406 )
391407
392408 frames = []
0 commit comments