@@ -342,6 +342,65 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
342342 return frames
343343
344344
345+ class TorchCodecPublicNonBatch (AbstractDecoder ):
346+ def __init__ (self , num_ffmpeg_threads = None , device = "cpu" ):
347+ self ._num_ffmpeg_threads = num_ffmpeg_threads
348+ self ._device = device
349+
350+ from torchvision .transforms import v2 as transforms_v2
351+
352+ self .transforms_v2 = transforms_v2
353+
354+ def decode_frames (self , video_file , pts_list ):
355+ num_ffmpeg_threads = (
356+ int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
357+ )
358+ decoder = VideoDecoder (
359+ video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
360+ )
361+
362+ frames = []
363+ for pts in pts_list :
364+ frame = decoder .get_frame_played_at (pts )
365+ frames .append (frame )
366+ return frames
367+
368+ def decode_first_n_frames (self , video_file , n ):
369+ num_ffmpeg_threads = (
370+ int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 0
371+ )
372+ decoder = VideoDecoder (
373+ video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
374+ )
375+ frames = []
376+ count = 0
377+ for frame in decoder :
378+ frames .append (frame )
379+ count += 1
380+ if count == n :
381+ break
382+ return frames
383+
384+ def decode_and_resize (self , video_file , pts_list , height , width , device ):
385+ num_ffmpeg_threads = (
386+ int (self ._num_ffmpeg_threads ) if self ._num_ffmpeg_threads else 1
387+ )
388+ decoder = VideoDecoder (
389+ video_file , num_ffmpeg_threads = num_ffmpeg_threads , device = self ._device
390+ )
391+
392+ frames = []
393+ for pts in pts_list :
394+ frame = decoder .get_frame_played_at (pts )
395+ frames .append (frame )
396+
397+ frames = [
398+ self .transforms_v2 .functional .resize (frame .to (device ), (height , width ))
399+ for frame in frames
400+ ]
401+ return frames
402+
403+
345404@torch .compile (fullgraph = True , backend = "eager" )
346405def compiled_seek_and_next (decoder , pts ):
347406 seek_to_pts (decoder , pts )
0 commit comments