Skip to content

Commit 809c7c7

Browse files
committed
More stuff, implement planar
2 parents befcabc + 6e9267b commit 809c7c7

File tree

2 files changed

+47
-28
lines changed

2 files changed

+47
-28
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,16 @@ void VideoDecoder::addVideoStreamDecoder(
478478
TORCH_CHECK(formatContext_.get() != nullptr);
479479

480480
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
481-
int streamIndex = av_find_best_stream(
482-
formatContext_.get(),
483-
AVMEDIA_TYPE_AUDIO,
484-
preferredStreamIndex,
485-
-1,
486-
&avCodec,
487-
0);
481+
// int streamIndex = av_find_best_stream(
482+
// formatContext_.get(),
483+
// AVMEDIA_TYPE_AUDIO,
484+
// preferredStreamIndex,
485+
// -1,
486+
// &avCodec,
487+
// 0);
488+
int streamIndex = preferredStreamIndex;
489+
avCodec = avcodec_find_decoder(formatContext_->streams[streamIndex]->codecpar->codec_id);
490+
488491
if (streamIndex < 0) {
489492
throw std::invalid_argument("No valid stream found in input file.");
490493
}
@@ -519,11 +522,7 @@ void VideoDecoder::addVideoStreamDecoder(
519522

520523
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
521524
TORCH_CHECK(codecContext != nullptr);
522-
codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
523-
if (!codecContext->channel_layout) {
524-
codecContext->channel_layout =
525-
av_get_default_channel_layout(codecContext->channels);
526-
}
525+
// codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
527526
streamInfo.codecContext.reset(codecContext);
528527

529528
int retVal = avcodec_parameters_to_context(
@@ -539,23 +538,34 @@ void VideoDecoder::addVideoStreamDecoder(
539538
false, "Invalid device type: " + videoStreamOptions.device.str());
540539
}
541540

541+
if (!streamInfo.codecContext->channel_layout) {
542+
streamInfo.codecContext->channel_layout =
543+
av_get_default_channel_layout(streamInfo.codecContext->channels);
544+
}
545+
546+
AVDictionary* opt = nullptr;
547+
av_dict_set(&opt, "threads", "1", 0);
542548
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
543549
if (retVal < AVSUCCESS) {
544550
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
545551
}
546552

547-
codecContext->time_base = streamInfo.stream->time_base;
553+
// codecContext->time_base = streamInfo.stream->time_base;
554+
// AVRational tb{0, 1};
555+
// codecContext->time_base = tb;
548556
activeStreamIndex_ = streamIndex;
549557
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
550558
streamInfo.videoStreamOptions = videoStreamOptions;
551559

552560
// We will only need packets from the active stream, so we tell FFmpeg to
553561
// discard packets from the other streams. Note that av_read_frame() may still
554-
// return some of those undesired packets under some conditions, so it's still
555-
// important to discard/demux packets correctly in the inner decoding loop.
562+
// return some of those undesired packets under some conditions, so it's still
563+
// important to discard/demux packets correctly in the inner decoding loop.
556564
for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
557565
if (i != static_cast<unsigned int>(activeStreamIndex_)) {
558566
formatContext_->streams[i]->discard = AVDISCARD_ALL;
567+
} else {
568+
formatContext_->streams[i]->discard = AVDISCARD_DEFAULT;
559569
}
560570
}
561571

@@ -898,7 +908,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
898908
// --------------------------------------------------------------------------
899909
// SEEKING APIs
900910
// --------------------------------------------------------------------------
901-
902911
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
903912
desiredPtsSeconds_ = seconds;
904913
}
@@ -986,6 +995,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
986995
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
987996
}
988997

998+
printf("Seeking to PTS = %ld\n", desiredPts);
999+
9891000
int ffmepgStatus = avformat_seek_file(
9901001
formatContext_.get(),
9911002
streamInfo.streamIndex,
@@ -999,6 +1010,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
9991010
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
10001011
}
10011012
decodeStats_.numFlushes++;
1013+
1014+
printf("Flushing\n");
10021015
avcodec_flush_buffers(streamInfo.codecContext.get());
10031016
}
10041017

@@ -1232,13 +1245,19 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12321245
av_get_bytes_per_sample(streamInfo.codecContext->sample_fmt);
12331246
// printf("bytes per sample: %d\n", bytesPerSample);
12341247

1235-
torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1236-
for (auto channel = 0; channel < numChannels; ++channel) {
1237-
float* dataFloatPtr = (float*)(avFrame->data[channel]);
1238-
for (auto sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
1239-
data[channel][sampleIndex] = dataFloatPtr[sampleIndex];
1240-
}
1241-
}
1248+
// float32 Planar
1249+
// torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1250+
// for (auto channel = 0; channel < numChannels; ++channel) {
1251+
// float* dataFloatPtr = (float*)(avFrame->data[channel]);
1252+
// for (auto sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
1253+
// data[channel][sampleIndex] = dataFloatPtr[sampleIndex];
1254+
// }
1255+
// }
1256+
// float32 non-Planar
1257+
torch::Tensor data = torch::empty({numSamples, numChannels}, torch::kFloat32);
1258+
uint8_t* pData = static_cast<uint8_t*>(data.data_ptr());
1259+
memcpy(pData, avFrame->extended_data[0], numSamples * numChannels * bytesPerSample);
1260+
data = data.permute({1, 0});
12421261

12431262
frameOutput.data = data;
12441263
return frameOutput;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) {
237237
} catch (const VideoDecoder::EndOfFileException& e) {
238238
C10_THROW_ERROR(IndexError, e.what());
239239
}
240-
if (result.data.sizes().size() != 3) {
241-
throw std::runtime_error(
242-
"image_size is unexpected. Expected 3, got: " +
243-
std::to_string(result.data.sizes().size()));
244-
}
240+
// if (result.data.sizes().size() != 3) {
241+
// throw std::runtime_error(
242+
// "image_size is unexpected. Expected 3, got: " +
243+
// std::to_string(result.data.sizes().size()));
244+
// }
245245
return makeOpsFrameOutput(result);
246246
}
247247

0 commit comments

Comments
 (0)