Skip to content

Commit d37d3a5

Browse files
authored
Add public, nonbatch option for benchmarking (#438)
1 parent 374b44a commit d37d3a5

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

benchmarks/decoders/benchmark_decoders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
TorchCodecCoreCompiled,
2727
TorchCodecCoreNonBatch,
2828
TorchCodecPublic,
29+
TorchCodecPublicNonBatch,
2930
TorchVision,
3031
)
3132

@@ -49,6 +50,9 @@ class DecoderKind:
4950
"TorchCodecCoreCompiled", TorchCodecCoreCompiled
5051
),
5152
"torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
53+
"torchcodec_public_nonbatch": DecoderKind(
54+
"TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
55+
),
5256
"torchvision": DecoderKind(
5357
# We don't compare against TorchVision's "pyav" backend because it doesn't support
5458
# accurate seeks.

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,65 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
342342
return frames
343343

344344

345+
class TorchCodecPublicNonBatch(AbstractDecoder):
    """Benchmark wrapper around the public ``VideoDecoder`` API that fetches
    frames one at a time ("non-batch") instead of using batched calls.

    Mirrors the interface of the other ``AbstractDecoder`` subclasses in this
    module so it can be selected interchangeably by the benchmark driver.
    """

    def __init__(self, num_ffmpeg_threads=None, device="cpu"):
        # num_ffmpeg_threads may be None; each decode method resolves its own
        # fallback value (0 for plain decoding, 1 for the resize path).
        self._num_ffmpeg_threads = num_ffmpeg_threads
        self._device = device

        # NOTE(review): import kept local to __init__ as in the original —
        # presumably so torchvision is only required when this decoder is used.
        from torchvision.transforms import v2 as transforms_v2

        self.transforms_v2 = transforms_v2

    def _resolve_num_threads(self, default):
        # A falsy setting (None or 0) falls back to the caller-provided default.
        if self._num_ffmpeg_threads:
            return int(self._num_ffmpeg_threads)
        return default

    def decode_frames(self, video_file, pts_list):
        """Return the frame displayed at each timestamp in ``pts_list``."""
        decoder = VideoDecoder(
            video_file,
            num_ffmpeg_threads=self._resolve_num_threads(0),
            device=self._device,
        )
        # One public-API call per timestamp — deliberately not batched.
        return [decoder.get_frame_played_at(pts) for pts in pts_list]

    def decode_first_n_frames(self, video_file, n):
        """Return the first ``n`` frames by iterating the decoder in order."""
        decoder = VideoDecoder(
            video_file,
            num_ffmpeg_threads=self._resolve_num_threads(0),
            device=self._device,
        )
        collected = []
        for frame in decoder:
            collected.append(frame)
            # Stop once n frames have been gathered; like the original, a
            # non-positive n never matches and thus drains the whole decoder.
            if len(collected) == n:
                break
        return collected

    def decode_and_resize(self, video_file, pts_list, height, width, device):
        """Decode the frame at each pts, then resize each to (height, width).

        Note the thread-count fallback here is 1, unlike the 0 used by the
        other methods — kept as in the original (TODO confirm intent).
        """
        decoder = VideoDecoder(
            video_file,
            num_ffmpeg_threads=self._resolve_num_threads(1),
            device=self._device,
        )
        decoded = [decoder.get_frame_played_at(pts) for pts in pts_list]

        resize = self.transforms_v2.functional.resize
        return [resize(frame.to(device), (height, width)) for frame in decoded]
345404
@torch.compile(fullgraph=True, backend="eager")
346405
def compiled_seek_and_next(decoder, pts):
347406
seek_to_pts(decoder, pts)

0 commit comments

Comments
 (0)