Skip to content

Commit 654a4dc

Browse files
author
pytorchbot
committed
2025-01-23 nightly release (f565e86)
1 parent 87675dc commit 654a4dc

19 files changed

+829
-474
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
get_frames_by_pts,
2323
get_json_metadata,
2424
get_next_frame,
25-
scan_all_streams_to_update_metadata,
2625
seek_to_pts,
2726
)
2827

@@ -154,8 +153,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
154153
self._device = device
155154

156155
def decode_frames(self, video_file, pts_list):
157-
decoder = create_from_file(video_file)
158-
scan_all_streams_to_update_metadata(decoder)
156+
decoder = create_from_file(video_file, seek_mode="exact")
159157
_add_video_stream(
160158
decoder,
161159
num_threads=self._num_threads,
@@ -170,7 +168,7 @@ def decode_frames(self, video_file, pts_list):
170168
return frames
171169

172170
def decode_first_n_frames(self, video_file, n):
173-
decoder = create_from_file(video_file)
171+
decoder = create_from_file(video_file, seek_mode="approximate")
174172
_add_video_stream(
175173
decoder,
176174
num_threads=self._num_threads,
@@ -197,7 +195,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
197195
self.transforms_v2 = transforms_v2
198196

199197
def decode_frames(self, video_file, pts_list):
200-
decoder = create_from_file(video_file)
198+
decoder = create_from_file(video_file, seek_mode="approximate")
201199
num_threads = int(self._num_threads) if self._num_threads else 0
202200
_add_video_stream(
203201
decoder,
@@ -216,7 +214,7 @@ def decode_frames(self, video_file, pts_list):
216214

217215
def decode_first_n_frames(self, video_file, n):
218216
num_threads = int(self._num_threads) if self._num_threads else 0
219-
decoder = create_from_file(video_file)
217+
decoder = create_from_file(video_file, seek_mode="approximate")
220218
_add_video_stream(
221219
decoder,
222220
num_threads=num_threads,
@@ -233,7 +231,7 @@ def decode_first_n_frames(self, video_file, n):
233231

234232
def decode_and_resize(self, video_file, pts_list, height, width, device):
235233
num_threads = int(self._num_threads) if self._num_threads else 1
236-
decoder = create_from_file(video_file)
234+
decoder = create_from_file(video_file, seek_mode="approximate")
237235
_add_video_stream(
238236
decoder,
239237
num_threads=num_threads,
@@ -263,8 +261,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
263261
self._device = device
264262

265263
def decode_frames(self, video_file, pts_list):
266-
decoder = create_from_file(video_file)
267-
scan_all_streams_to_update_metadata(decoder)
264+
decoder = create_from_file(video_file, seek_mode="exact")
268265
_add_video_stream(
269266
decoder,
270267
num_threads=self._num_threads,
@@ -279,8 +276,7 @@ def decode_frames(self, video_file, pts_list):
279276
return frames
280277

281278
def decode_first_n_frames(self, video_file, n):
282-
decoder = create_from_file(video_file)
283-
scan_all_streams_to_update_metadata(decoder)
279+
decoder = create_from_file(video_file, seek_mode="exact")
284280
_add_video_stream(
285281
decoder,
286282
num_threads=self._num_threads,
@@ -297,9 +293,10 @@ def decode_first_n_frames(self, video_file, n):
297293

298294

299295
class TorchCodecPublic(AbstractDecoder):
300-
def __init__(self, num_ffmpeg_threads=None, device="cpu"):
296+
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"):
301297
self._num_ffmpeg_threads = num_ffmpeg_threads
302298
self._device = device
299+
self._seek_mode = seek_mode
303300

304301
from torchvision.transforms import v2 as transforms_v2
305302

@@ -310,7 +307,10 @@ def decode_frames(self, video_file, pts_list):
310307
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
311308
)
312309
decoder = VideoDecoder(
313-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
310+
video_file,
311+
num_ffmpeg_threads=num_ffmpeg_threads,
312+
device=self._device,
313+
seek_mode=self._seek_mode,
314314
)
315315
return decoder.get_frames_played_at(pts_list)
316316

@@ -319,7 +319,10 @@ def decode_first_n_frames(self, video_file, n):
319319
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
320320
)
321321
decoder = VideoDecoder(
322-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
322+
video_file,
323+
num_ffmpeg_threads=num_ffmpeg_threads,
324+
device=self._device,
325+
seek_mode=self._seek_mode,
323326
)
324327
frames = []
325328
count = 0
@@ -335,17 +338,21 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
335338
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1
336339
)
337340
decoder = VideoDecoder(
338-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
341+
video_file,
342+
num_ffmpeg_threads=num_ffmpeg_threads,
343+
device=self._device,
344+
seek_mode=self._seek_mode,
339345
)
340346
frames = decoder.get_frames_played_at(pts_list)
341347
frames = self.transforms_v2.functional.resize(frames.data, (height, width))
342348
return frames
343349

344350

345351
class TorchCodecPublicNonBatch(AbstractDecoder):
346-
def __init__(self, num_ffmpeg_threads=None, device="cpu"):
352+
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"):
347353
self._num_ffmpeg_threads = num_ffmpeg_threads
348354
self._device = device
355+
self._seek_mode = seek_mode
349356

350357
from torchvision.transforms import v2 as transforms_v2
351358

@@ -356,7 +363,10 @@ def decode_frames(self, video_file, pts_list):
356363
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
357364
)
358365
decoder = VideoDecoder(
359-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
366+
video_file,
367+
num_ffmpeg_threads=num_ffmpeg_threads,
368+
device=self._device,
369+
seek_mode=self._seek_mode,
360370
)
361371

362372
frames = []
@@ -370,7 +380,10 @@ def decode_first_n_frames(self, video_file, n):
370380
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
371381
)
372382
decoder = VideoDecoder(
373-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
383+
video_file,
384+
num_ffmpeg_threads=num_ffmpeg_threads,
385+
device=self._device,
386+
seek_mode=self._seek_mode,
374387
)
375388
frames = []
376389
count = 0
@@ -386,7 +399,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
386399
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1
387400
)
388401
decoder = VideoDecoder(
389-
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
402+
video_file,
403+
num_ffmpeg_threads=num_ffmpeg_threads,
404+
device=self._device,
405+
seek_mode=self._seek_mode,
390406
)
391407

392408
frames = []

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace facebook::torchcodec {
1616

1717
void convertAVFrameToDecodedOutputOnCuda(
1818
const torch::Device& device,
19-
[[maybe_unused]] const VideoDecoder::VideoStreamDecoderOptions& options,
19+
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
2020
[[maybe_unused]] VideoDecoder::RawDecodedOutput& rawOutput,
2121
[[maybe_unused]] VideoDecoder::DecodedOutput& output,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,18 @@ void initializeContextOnCuda(
185185

186186
void convertAVFrameToDecodedOutputOnCuda(
187187
const torch::Device& device,
188-
const VideoDecoder::VideoStreamDecoderOptions& options,
188+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
189189
VideoDecoder::RawDecodedOutput& rawOutput,
190190
VideoDecoder::DecodedOutput& output,
191191
std::optional<torch::Tensor> preAllocatedOutputTensor) {
192-
AVFrame* src = rawOutput.frame.get();
192+
AVFrame* avFrame = rawOutput.avFrame.get();
193193

194194
TORCH_CHECK(
195-
src->format == AV_PIX_FMT_CUDA,
195+
avFrame->format == AV_PIX_FMT_CUDA,
196196
"Expected format to be AV_PIX_FMT_CUDA, got " +
197-
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
198-
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(options, *src);
197+
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
198+
auto frameDims =
199+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
199200
int height = frameDims.height;
200201
int width = frameDims.width;
201202
torch::Tensor& dst = output.frame;
@@ -212,28 +213,28 @@ void convertAVFrameToDecodedOutputOnCuda(
212213
"x3, got ",
213214
shape);
214215
} else {
215-
dst = allocateEmptyHWCTensor(height, width, options.device);
216+
dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
216217
}
217218

218219
// Use the user-requested GPU for running the NPP kernel.
219220
c10::cuda::CUDAGuard deviceGuard(device);
220221

221222
NppiSize oSizeROI = {width, height};
222-
Npp8u* input[2] = {src->data[0], src->data[1]};
223+
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
223224

224225
auto start = std::chrono::high_resolution_clock::now();
225226
NppStatus status;
226-
if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
227+
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {
227228
status = nppiNV12ToRGB_709CSC_8u_P2C3R(
228229
input,
229-
src->linesize[0],
230+
avFrame->linesize[0],
230231
static_cast<Npp8u*>(dst.data_ptr()),
231232
dst.stride(0),
232233
oSizeROI);
233234
} else {
234235
status = nppiNV12ToRGB_8u_P2C3R(
235236
input,
236-
src->linesize[0],
237+
avFrame->linesize[0],
237238
static_cast<Npp8u*>(dst.data_ptr()),
238239
dst.stride(0),
239240
oSizeROI);

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void initializeContextOnCuda(
3131

3232
void convertAVFrameToDecodedOutputOnCuda(
3333
const torch::Device& device,
34-
const VideoDecoder::VideoStreamDecoderOptions& options,
34+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
3535
VideoDecoder::RawDecodedOutput& rawOutput,
3636
VideoDecoder::DecodedOutput& output,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,25 @@
1010

1111
namespace facebook::torchcodec {
1212

13+
AutoAVPacket::AutoAVPacket() : avPacket_(av_packet_alloc()) {
14+
TORCH_CHECK(avPacket_ != nullptr, "Couldn't allocate avPacket.");
15+
}
16+
AutoAVPacket::~AutoAVPacket() {
17+
av_packet_free(&avPacket_);
18+
}
19+
20+
ReferenceAVPacket::ReferenceAVPacket(AutoAVPacket& shared)
21+
: avPacket_(shared.avPacket_) {}
22+
ReferenceAVPacket::~ReferenceAVPacket() {
23+
av_packet_unref(avPacket_);
24+
}
25+
AVPacket* ReferenceAVPacket::get() {
26+
return avPacket_;
27+
}
28+
AVPacket* ReferenceAVPacket::operator->() {
29+
return avPacket_;
30+
}
31+
1332
AVCodecOnlyUseForCallingAVFindBestStream
1433
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
1534
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ using UniqueAVCodecContext = std::unique_ptr<
5757
Deleterp<AVCodecContext, void, avcodec_free_context>>;
5858
using UniqueAVFrame =
5959
std::unique_ptr<AVFrame, Deleterp<AVFrame, void, av_frame_free>>;
60-
using UniqueAVPacket =
61-
std::unique_ptr<AVPacket, Deleterp<AVPacket, void, av_packet_free>>;
6260
using UniqueAVFilterGraph = std::unique_ptr<
6361
AVFilterGraph,
6462
Deleterp<AVFilterGraph, void, avfilter_graph_free>>;
@@ -70,6 +68,44 @@ using UniqueAVIOContext = std::
7068
using UniqueSwsContext =
7169
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
7270

71+
// These 2 classes share the same underlying AVPacket object. They are meant to
72+
// be used in tandem, like so:
73+
//
74+
// AutoAVPacket autoAVPacket; // <-- malloc for AVPacket happens here
75+
// while(...){
76+
// ReferenceAVPacket packet(autoAVPacket);
77+
// av_read_frame(..., packet.get()); <-- av_packet_ref() called by FFmpeg
78+
// } <-- av_packet_unref() called here
79+
//
80+
// This achieves a few desirable things:
81+
// - Memory allocation of the underlying AVPacket happens only once, when
82+
// autoAVPacket is created.
83+
// - av_packet_free() is called when autoAVPacket goes out of scope.
84+
// - av_packet_unref() is automatically called when needed, i.e. at the end of
85+
// each loop iteration (or when hitting break / continue). This prevents the
86+
// risk of us forgetting to call it.
87+
class AutoAVPacket {
88+
friend class ReferenceAVPacket;
89+
90+
private:
91+
AVPacket* avPacket_;
92+
93+
public:
94+
AutoAVPacket();
95+
~AutoAVPacket();
96+
};
97+
98+
class ReferenceAVPacket {
99+
private:
100+
AVPacket* avPacket_;
101+
102+
public:
103+
ReferenceAVPacket(AutoAVPacket& shared);
104+
~ReferenceAVPacket();
105+
AVPacket* get();
106+
AVPacket* operator->();
107+
};
108+
73109
// av_find_best_stream is not const-correct before commit:
74110
// https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c
75111
// which was released in FFMPEG version=5.0.3

0 commit comments

Comments
 (0)