Skip to content

Commit d914615

Browse files
committed
Handle options
1 parent 85569fb commit d914615

File tree

7 files changed

+71
-71
lines changed

7 files changed

+71
-71
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace facebook::torchcodec {
1616

1717
void convertAVFrameToDecodedOutputOnCuda(
1818
const torch::Device& device,
19-
[[maybe_unused]] const VideoDecoder::VideoStreamDecoderOptions& options,
19+
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& options,
2020
[[maybe_unused]] VideoDecoder::RawDecodedOutput& rawOutput,
2121
[[maybe_unused]] VideoDecoder::DecodedOutput& output,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ void initializeContextOnCuda(
185185

186186
void convertAVFrameToDecodedOutputOnCuda(
187187
const torch::Device& device,
188-
const VideoDecoder::VideoStreamDecoderOptions& options,
188+
const VideoDecoder::VideoStreamOptions& options,
189189
VideoDecoder::RawDecodedOutput& rawOutput,
190190
VideoDecoder::DecodedOutput& output,
191191
std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void initializeContextOnCuda(
3131

3232
void convertAVFrameToDecodedOutputOnCuda(
3333
const torch::Device& device,
34-
const VideoDecoder::VideoStreamDecoderOptions& options,
34+
const VideoDecoder::VideoStreamOptions& options,
3535
VideoDecoder::RawDecodedOutput& rawOutput,
3636
VideoDecoder::DecodedOutput& output,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
124124
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
125125
int streamIndex,
126126
torch::Tensor& hwcTensor) {
127-
if (streamInfos_[streamIndex].options.dimensionOrder == "NHWC") {
127+
if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") {
128128
return hwcTensor;
129129
}
130130
auto numDimensions = hwcTensor.dim();
@@ -141,7 +141,7 @@ torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
141141
}
142142
}
143143

144-
VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
144+
VideoDecoder::VideoStreamOptions::VideoStreamOptions(
145145
const std::string& optionsString) {
146146
std::vector<std::string> tokens =
147147
splitStringWithDelimiters(optionsString, ",");
@@ -194,14 +194,14 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
194194

195195
VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
196196
int64_t numFrames,
197-
const VideoStreamDecoderOptions& options,
197+
const VideoStreamOptions& videoStreamOptions,
198198
const StreamMetadata& streamMetadata)
199199
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
200200
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
201-
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, streamMetadata);
201+
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(videoStreamOptions, streamMetadata);
202202
int height = frameDims.height;
203203
int width = frameDims.width;
204-
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
204+
frames = allocateEmptyHWCTensor(height, width, videoStreamOptions.device, numFrames);
205205
}
206206

207207
bool VideoDecoder::DecodedFrameContext::operator==(
@@ -338,9 +338,9 @@ void VideoDecoder::createFilterGraph(
338338
filterState.filterGraph.reset(avfilter_graph_alloc());
339339
TORCH_CHECK(filterState.filterGraph.get() != nullptr);
340340

341-
if (streamInfo.options.ffmpegThreadCount.has_value()) {
341+
if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) {
342342
filterState.filterGraph->nb_threads =
343-
streamInfo.options.ffmpegThreadCount.value();
343+
streamInfo.videoStreamOptions.ffmpegThreadCount.value();
344344
}
345345

346346
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
@@ -444,7 +444,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
444444

445445
void VideoDecoder::addVideoStreamDecoder(
446446
int preferredStreamIndex,
447-
const VideoStreamDecoderOptions& options) {
447+
const VideoStreamOptions& videoStreamOptions) {
448448
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
449449
throw std::invalid_argument(
450450
"Stream with index " + std::to_string(preferredStreamIndex) +
@@ -484,26 +484,26 @@ void VideoDecoder::addVideoStreamDecoder(
484484
" is not a video stream.");
485485
}
486486

487-
if (options.device.type() == torch::kCUDA) {
488-
codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id)
487+
if (videoStreamOptions.device.type() == torch::kCUDA) {
488+
codec = findCudaCodec(videoStreamOptions.device, streamInfo.stream->codecpar->codec_id)
489489
.value_or(codec);
490490
}
491491

492492
AVCodecContext* codecContext = avcodec_alloc_context3(codec);
493493
TORCH_CHECK(codecContext != nullptr);
494-
codecContext->thread_count = options.ffmpegThreadCount.value_or(0);
494+
codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
495495
streamInfo.codecContext.reset(codecContext);
496496

497497
int retVal = avcodec_parameters_to_context(
498498
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
499499
TORCH_CHECK_EQ(retVal, AVSUCCESS);
500500

501-
if (options.device.type() == torch::kCPU) {
501+
if (videoStreamOptions.device.type() == torch::kCPU) {
502502
// No more initialization needed for CPU.
503-
} else if (options.device.type() == torch::kCUDA) {
504-
initializeContextOnCuda(options.device, codecContext);
503+
} else if (videoStreamOptions.device.type() == torch::kCUDA) {
504+
initializeContextOnCuda(videoStreamOptions.device, codecContext);
505505
} else {
506-
TORCH_CHECK(false, "Invalid device type: " + options.device.str());
506+
TORCH_CHECK(false, "Invalid device type: " + videoStreamOptions.device.str());
507507
}
508508

509509
retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr);
@@ -514,7 +514,7 @@ void VideoDecoder::addVideoStreamDecoder(
514514
codecContext->time_base = streamInfo.stream->time_base;
515515
activeStreamIndices_.insert(streamIndex);
516516
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
517-
streamInfo.options = options;
517+
streamInfo.videoStreamOptions = videoStreamOptions;
518518

519519
// By default, we want to use swscale for color conversion because it is
520520
// faster. However, it has width requirements, so we may need to fall back
@@ -523,10 +523,10 @@ void VideoDecoder::addVideoStreamDecoder(
523523
// swscale's width requirements to be violated. We don't expose the ability to
524524
// choose color conversion library publicly; we only use this ability
525525
// internally.
526-
int width = options.width.value_or(codecContext->width);
526+
int width = videoStreamOptions.width.value_or(codecContext->width);
527527
auto defaultLibrary = getDefaultColorConversionLibrary(width);
528528
streamInfo.colorConversionLibrary =
529-
options.colorConversionLibrary.value_or(defaultLibrary);
529+
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
530530
}
531531

532532
void VideoDecoder::updateMetadataWithCodecContext(
@@ -920,19 +920,19 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
920920
output.durationSeconds = ptsToSeconds(
921921
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
922922
// TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
923-
if (streamInfo.options.device.type() == torch::kCPU) {
923+
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
924924
convertAVFrameToDecodedOutputOnCPU(
925925
rawOutput, output, preAllocatedOutputTensor);
926-
} else if (streamInfo.options.device.type() == torch::kCUDA) {
926+
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
927927
convertAVFrameToDecodedOutputOnCuda(
928-
streamInfo.options.device,
929-
streamInfo.options,
928+
streamInfo.videoStreamOptions.device,
929+
streamInfo.videoStreamOptions,
930930
rawOutput,
931931
output,
932932
preAllocatedOutputTensor);
933933
} else {
934934
TORCH_CHECK(
935-
false, "Invalid device type: " + streamInfo.options.device.str());
935+
false, "Invalid device type: " + streamInfo.videoStreamOptions.device.str());
936936
}
937937
return output;
938938
}
@@ -955,7 +955,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
955955
auto& streamInfo = streamInfos_[streamIndex];
956956

957957
auto frameDims =
958-
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *avFrame);
958+
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.videoStreamOptions, *avFrame);
959959
int expectedOutputHeight = frameDims.height;
960960
int expectedOutputWidth = frameDims.width;
961961

@@ -1262,8 +1262,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
12621262

12631263
const auto& streamMetadata = containerMetadata_.streamMetadatas[streamIndex];
12641264
const auto& streamInfo = streamInfos_[streamIndex];
1265-
const auto& options = streamInfo.options;
1266-
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
1265+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
1266+
BatchDecodedOutput output(frameIndices.size(), videoStreamOptions, streamMetadata);
12671267

12681268
auto previousIndexInVideo = -1;
12691269
for (size_t f = 0; f < frameIndices.size(); ++f) {
@@ -1344,8 +1344,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
13441344
step > 0, "Step must be greater than 0; is " + std::to_string(step));
13451345

13461346
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
1347-
const auto& options = streamInfo.options;
1348-
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
1347+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
1348+
BatchDecodedOutput output(numOutputFrames, videoStreamOptions, streamMetadata);
13491349

13501350
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
13511351
DecodedOutput singleOut =
@@ -1372,7 +1372,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
13721372
std::to_string(stopSeconds) + ".");
13731373

13741374
const auto& streamInfo = streamInfos_[streamIndex];
1375-
const auto& options = streamInfo.options;
1375+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
13761376

13771377
// Special case needed to implement a half-open range. At first glance, this
13781378
// may seem unnecessary, as our search for stopFrame can return the end, and
@@ -1392,7 +1392,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
13921392
// values of the intervals will map to the same frame indices below. Hence, we
13931393
// need this special case below.
13941394
if (startSeconds == stopSeconds) {
1395-
BatchDecodedOutput output(0, options, streamMetadata);
1395+
BatchDecodedOutput output(0, videoStreamOptions, streamMetadata);
13961396
output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
13971397
return output;
13981398
}
@@ -1429,7 +1429,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
14291429
secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata);
14301430
int64_t numFrames = stopFrameIndex - startFrameIndex;
14311431

1432-
BatchDecodedOutput output(numFrames, options, streamMetadata);
1432+
BatchDecodedOutput output(numFrames, videoStreamOptions, streamMetadata);
14331433
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
14341434
DecodedOutput singleOut =
14351435
getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
@@ -1584,7 +1584,7 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
15841584

15851585
VideoDecoder::~VideoDecoder() {
15861586
for (auto& [streamIndex, streamInfo] : streamInfos_) {
1587-
auto& device = streamInfo.options.device;
1587+
auto& device = streamInfo.videoStreamOptions.device;
15881588
if (device.type() == torch::kCPU) {
15891589
} else if (device.type() == torch::kCUDA) {
15901590
releaseContextOnCuda(device, streamInfo.codecContext.get());
@@ -1599,19 +1599,19 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
15991599
}
16001600

16011601
FrameDims getHeightAndWidthFromOptionsOrMetadata(
1602-
const VideoDecoder::VideoStreamDecoderOptions& options,
1602+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
16031603
const VideoDecoder::StreamMetadata& streamMetadata) {
16041604
return FrameDims(
1605-
options.height.value_or(*streamMetadata.height),
1606-
options.width.value_or(*streamMetadata.width));
1605+
videoStreamOptions.height.value_or(*streamMetadata.height),
1606+
videoStreamOptions.width.value_or(*streamMetadata.width));
16071607
}
16081608

16091609
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
1610-
const VideoDecoder::VideoStreamDecoderOptions& options,
1610+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
16111611
const AVFrame& avFrame) {
16121612
return FrameDims(
1613-
options.height.value_or(avFrame.height),
1614-
options.width.value_or(avFrame.width));
1613+
videoStreamOptions.height.value_or(avFrame.height),
1614+
videoStreamOptions.width.value_or(avFrame.width));
16151615
}
16161616

16171617
torch::Tensor allocateEmptyHWCTensor(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ class VideoDecoder {
130130
// Use the libswscale library for color conversion.
131131
SWSCALE
132132
};
133-
struct VideoStreamDecoderOptions {
134-
VideoStreamDecoderOptions() {}
135-
explicit VideoStreamDecoderOptions(const std::string& optionsString);
133+
struct VideoStreamOptions {
134+
VideoStreamOptions() {}
135+
explicit VideoStreamOptions(const std::string& optionsString);
136136
// Number of threads we pass to FFMPEG for decoding.
137137
// 0 means FFMPEG will choose the number of threads automatically to fully
138138
// utilize all cores. If not set, it will be the default FFMPEG behavior for
@@ -149,13 +149,13 @@ class VideoDecoder {
149149
// By default we use CPU for decoding for both C++ and python users.
150150
torch::Device device = torch::kCPU;
151151
};
152-
struct AudioStreamDecoderOptions {};
152+
struct AudioStreamOptions {};
153153
void addVideoStreamDecoder(
154154
int streamIndex,
155-
const VideoStreamDecoderOptions& options = VideoStreamDecoderOptions());
155+
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
156156
void addAudioStreamDecoder(
157157
int streamIndex,
158-
const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
158+
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
159159

160160
torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
161161

@@ -214,7 +214,7 @@ class VideoDecoder {
214214

215215
explicit BatchDecodedOutput(
216216
int64_t numFrames,
217-
const VideoStreamDecoderOptions& options,
217+
const VideoStreamOptions& videoStreamOptions,
218218
const StreamMetadata& streamMetadata);
219219
};
220220

@@ -313,7 +313,7 @@ class VideoDecoder {
313313
// this pts to the user when they request a frame.
314314
// We update this field if the user requested a seek.
315315
int64_t discardFramesBeforePts = INT64_MIN;
316-
VideoStreamDecoderOptions options;
316+
VideoStreamOptions videoStreamOptions;
317317
// The filter state associated with this stream (for video streams). The
318318
// actual graph will be nullptr for inactive streams.
319319
FilterState filterState;
@@ -488,11 +488,11 @@ struct FrameDims {
488488
FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
489489

490490
FrameDims getHeightAndWidthFromOptionsOrMetadata(
491-
const VideoDecoder::VideoStreamDecoderOptions& options,
491+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
492492
const VideoDecoder::StreamMetadata& streamMetadata);
493493

494494
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
495-
const VideoDecoder::VideoStreamDecoderOptions& options,
495+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
496496
const AVFrame& avFrame);
497497

498498
torch::Tensor allocateEmptyHWCTensor(

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -180,23 +180,23 @@ void _add_video_stream(
180180
std::optional<int64_t> stream_index,
181181
std::optional<std::string_view> device,
182182
std::optional<std::string_view> color_conversion_library) {
183-
VideoDecoder::VideoStreamDecoderOptions options;
184-
options.width = width;
185-
options.height = height;
186-
options.ffmpegThreadCount = num_threads;
183+
VideoDecoder::VideoStreamOptions videoStreamOptions;
184+
videoStreamOptions.width = width;
185+
videoStreamOptions.height = height;
186+
videoStreamOptions.ffmpegThreadCount = num_threads;
187187

188188
if (dimension_order.has_value()) {
189189
std::string stdDimensionOrder{dimension_order.value()};
190190
TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW");
191-
options.dimensionOrder = stdDimensionOrder;
191+
videoStreamOptions.dimensionOrder = stdDimensionOrder;
192192
}
193193
if (color_conversion_library.has_value()) {
194194
std::string stdColorConversionLibrary{color_conversion_library.value()};
195195
if (stdColorConversionLibrary == "filtergraph") {
196-
options.colorConversionLibrary =
196+
videoStreamOptions.colorConversionLibrary =
197197
VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
198198
} else if (stdColorConversionLibrary == "swscale") {
199-
options.colorConversionLibrary =
199+
videoStreamOptions.colorConversionLibrary =
200200
VideoDecoder::ColorConversionLibrary::SWSCALE;
201201
} else {
202202
throw std::runtime_error(
@@ -206,10 +206,10 @@ void _add_video_stream(
206206
}
207207
if (device.has_value()) {
208208
if (device.value() == "cpu") {
209-
options.device = torch::Device(torch::kCPU);
209+
videoStreamOptions.device = torch::Device(torch::kCPU);
210210
} else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda"
211211
std::string deviceStr(device.value());
212-
options.device = torch::Device(deviceStr);
212+
videoStreamOptions.device = torch::Device(deviceStr);
213213
} else {
214214
throw std::runtime_error(
215215
"Invalid device=" + std::string(device.value()) +
@@ -218,7 +218,7 @@ void _add_video_stream(
218218
}
219219

220220
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
221-
videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options);
221+
videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), videoStreamOptions);
222222
}
223223

224224
void seek_to_pts(at::Tensor& decoder, double seconds) {

0 commit comments

Comments
 (0)