Skip to content

Commit 59b0d15

Browse files
committed
remove next() support
1 parent ce12f03 commit 59b0d15

File tree

4 files changed

+4
-81
lines changed

4 files changed

+4
-81
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ void VideoDecoder::addAudioStream(int streamIndex) {
550550
auto& streamInfo = streamInfos_[activeStreamIndex_];
551551
auto& streamMetadata =
552552
containerMetadata_.allStreamMetadata[activeStreamIndex_];
553-
554553
streamMetadata.sampleRate =
555554
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
556555
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
@@ -562,12 +561,13 @@ void VideoDecoder::addAudioStream(int streamIndex) {
562561

563562
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
564563
auto output = getNextFrameInternal();
565-
output.data = maybePermuteOutputTensor(output.data);
564+
output.data = maybePermuteHWC2CHW(output.data);
566565
return output;
567566
}
568567

569568
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
570569
std::optional<torch::Tensor> preAllocatedOutputTensor) {
570+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
571571
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
572572
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
573573
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
@@ -576,7 +576,6 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
576576
}
577577

578578
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
579-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
580579
auto frameOutput = getFrameAtIndexInternal(frameIndex);
581580
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
582581
return frameOutput;
@@ -585,7 +584,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
585584
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
586585
int64_t frameIndex,
587586
std::optional<torch::Tensor> preAllocatedOutputTensor) {
588-
validateActiveStream();
587+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
589588

590589
const auto& streamInfo = streamInfos_[activeStreamIndex_];
591590
const auto& streamMetadata =
@@ -1389,17 +1388,6 @@ torch::Tensor allocateEmptyHWCTensor(
13891388
}
13901389
}
13911390

1392-
torch::Tensor VideoDecoder::maybePermuteOutputTensor(
1393-
torch::Tensor& outputTensor) {
1394-
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
1395-
return maybePermuteHWC2CHW(outputTensor);
1396-
} else {
1397-
// No need to do anything for audio. We always return (numChannels,
1398-
// numSamples) or (numFrames, numChannels, numSamples)
1399-
return outputTensor;
1400-
}
1401-
}
1402-
14031391
// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
14041392
// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
14051393
// or 4D.

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,6 @@ class VideoDecoder {
376376
FrameOutput getNextFrameInternal(
377377
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
378378

379-
torch::Tensor maybePermuteOutputTensor(torch::Tensor& outputTensor);
380379
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
381380

382381
FrameOutput convertAVFrameToFrameOutput(

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,6 @@ void add_audio_stream(
234234
}
235235

236236
void seek_to_pts(at::Tensor& decoder, double seconds) {
237-
// TODO-AUDIO we should prevent more than one call to this op for audio
238-
// streams, for the same reasons we do so for getFramesPlayedInRange(). But we
239-
// can't implement the logic here, because we don't know media type (audio vs
240-
// video). We also can't do it within setCursorPtsInSeconds because it's used
241-
// by all other decoding methods. This isn't un-doable, just not easy with
242-
// the API we currently have.
243237
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
244238
videoDecoder->setCursorPtsInSeconds(seconds);
245239
}

test/decoders/test_ops.py

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ class TestAudioOps:
626626
partial(get_frames_in_range, start=4, stop=5),
627627
partial(get_frame_at_pts, seconds=2),
628628
partial(get_frames_by_pts, timestamps=[0, 1.5]),
629+
partial(get_next_frame),
629630
),
630631
)
631632
def test_audio_bad_method(self, method):
@@ -641,28 +642,6 @@ def test_audio_bad_seek_mode(self):
641642
):
642643
add_audio_stream(decoder)
643644

644-
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
645-
def test_audio_decode_all_samples_with_next(self, asset):
646-
decoder = create_from_file(str(asset.path), seek_mode="approximate")
647-
add_audio_stream(decoder)
648-
649-
reference_frames = [
650-
asset.get_frame_data_by_index(i) for i in range(asset.num_frames)
651-
]
652-
653-
reference_frames = torch.cat(reference_frames, dim=-1)
654-
655-
all_frames = []
656-
while True:
657-
try:
658-
frame, *_ = get_next_frame(decoder)
659-
all_frames.append(frame)
660-
except IndexError:
661-
break
662-
all_frames = torch.cat(all_frames, dim=-1)
663-
664-
assert_frames_equal(all_frames, reference_frames)
665-
666645
@pytest.mark.parametrize(
667646
"range",
668647
(
@@ -736,43 +715,6 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
736715
)
737716
assert frames.shape == expected_shape
738717

739-
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
740-
def test_seek_and_next_audio(self, asset):
741-
decoder = create_from_file(str(asset.path), seek_mode="approximate")
742-
add_audio_stream(decoder)
743-
744-
pts = 2
745-
# Need +1 because we're not at frames boundaries
746-
reference_frame = asset.get_frame_data_by_index(
747-
asset.get_frame_index(pts_seconds=pts) + 1
748-
)
749-
seek_to_pts(decoder, pts)
750-
frame, _, _ = get_next_frame(decoder)
751-
assert_frames_equal(frame, reference_frame)
752-
753-
# Seeking forward is OK
754-
pts = 4
755-
reference_frame = asset.get_frame_data_by_index(
756-
asset.get_frame_index(pts_seconds=pts) + 1
757-
)
758-
seek_to_pts(decoder, pts)
759-
frame, _, _ = get_next_frame(decoder)
760-
assert_frames_equal(frame, reference_frame)
761-
762-
# Seeking backwards doesn't error, but it's wrong. See TODO in
763-
# `seek_to_pts` op.
764-
prev_pts = pts
765-
pts = 1
766-
seek_to_pts(decoder, pts)
767-
frame, _, _ = get_next_frame(decoder)
768-
# the decoder actually didn't seek, so the frame we're getting is just
769-
# the "next" one without seeking. This assertion exists to illustrate
770-
# what currently happens, but it's obviously *wrong*.
771-
reference_frame = asset.get_frame_data_by_index(
772-
asset.get_frame_index(pts_seconds=prev_pts) + 2
773-
)
774-
assert_frames_equal(frame, reference_frame)
775-
776718

777719
if __name__ == "__main__":
778720
pytest.main()

0 commit comments

Comments
 (0)