
Commit aaf9504

Remove streamType field from DecodedOutput
1 parent d26bfbc commit aaf9504

File tree

2 files changed: 67 additions & 78 deletions


src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 67 additions & 76 deletions
@@ -869,7 +869,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
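(Aside: ptsToSeconds in the context lines above is a plain time-base rescale. A minimal sketch of what such a helper typically looks like with FFmpeg's AVRational — an illustration of the idea, not necessarily the project's exact implementation:)

#include <cstdint>
extern "C" {
#include <libavutil/rational.h>  // AVRational, av_q2d
}

// Convert a presentation timestamp expressed in stream time-base units
// to seconds: pts * (time_base.num / time_base.den).
inline double ptsToSeconds(int64_t pts, AVRational timeBase) {
  return static_cast<double>(pts) * av_q2d(timeBase);
}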
@@ -930,86 +929,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
-
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
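The refactored branch keeps the existing caching pattern: the swscale or filtergraph conversion object is rebuilt only when the incoming frame's geometry or pixel format no longer matches the context it was built for. A minimal self-contained sketch of that pattern, using hypothetical FrameContext and Converter stand-ins rather than the real DecodedFrameContext and FFmpeg contexts:

#include <memory>
#include <optional>

// Hypothetical stand-ins for DecodedFrameContext and the cached
// SwsContext/AVFilterGraph; illustration only.
struct FrameContext {
  int width;
  int height;
  int pixelFormat;
  bool operator==(const FrameContext& other) const {
    return width == other.width && height == other.height &&
        pixelFormat == other.pixelFormat;
  }
  bool operator!=(const FrameContext& other) const { return !(*this == other); }
};

struct Converter {
  explicit Converter(const FrameContext& /*ctx*/) {}  // built from frame params
};

struct StreamState {
  std::unique_ptr<Converter> converter;
  std::optional<FrameContext> prevFrameContext;
};

// Rebuild the converter only when there is none yet, or when the incoming
// frame's parameters differ from those the cached converter was built for.
Converter& getConverter(StreamState& state, const FrameContext& current) {
  if (!state.converter || state.prevFrameContext != current) {
    state.converter = std::make_unique<Converter>(current);
    state.prevFrameContext = current;
  }
  return *state.converter;
}

Creating the conversion objects lazily like this avoids trusting the header metadata, and comparing against the previous frame context handles mid-stream resolution changes while still reusing the objects across frames for performance.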

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 2 deletions
@@ -191,8 +191,6 @@ class VideoDecoder {
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;
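For reference, after this commit the portion of DecodedOutput visible in this hunk reduces to the following (members outside the hunk are elided):

struct DecodedOutput {
  // The actual decoded output as a Tensor.
  torch::Tensor frame;
  // The stream index of the decoded frame. Used to distinguish
  // between streams that are of the same type.
  int streamIndex;
  // ... further members not shown in this diff ...
};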
