Skip to content

Commit 9ee63e6

Browse files
committed
more stuff
1 parent ee30cc3 commit 9ee63e6

File tree

5 files changed

+25
-3
lines changed

5 files changed

+25
-3
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,13 +573,15 @@ void VideoDecoder::addAudioStream(int streamIndex) {
573573

574574
VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
575575
auto output = getNextFrameInternal();
576-
output.data = maybePermuteHWC2CHW(output.data);
576+
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
577+
output.data = maybePermuteHWC2CHW(output.data);
578+
}
577579
return output;
578580
}
579581

580582
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
581583
std::optional<torch::Tensor> preAllocatedOutputTensor) {
582-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
584+
validateActiveStream();
583585
AVFrameStream avFrameStream = decodeAVFrame(
584586
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
585587
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);

test/decoders/test_decoders.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,3 +1080,21 @@ def test_frame_start_is_not_zero(self):
10801080

10811081
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
10821082
torch.testing.assert_close(samples.data, reference_frames)
1083+
1084+
def test_single_channel(self):
1085+
asset = SINE_MONO_S32
1086+
decoder = AudioDecoder(asset.path)
1087+
1088+
samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2)
1089+
assert samples.data.shape[0] == asset.num_channels == 1
1090+
1091+
def test_format_conversion(self):
1092+
asset = SINE_MONO_S32
1093+
decoder = AudioDecoder(asset.path)
1094+
assert decoder.metadata.sample_format == asset.sample_format == "s32"
1095+
1096+
all_samples = decoder.get_samples_played_in_range(start_seconds=0)
1097+
assert all_samples.data.dtype == torch.float32
1098+
1099+
reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames)
1100+
torch.testing.assert_close(all_samples.data, reference_frames)

test/decoders/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,6 @@ class TestAudioOps:
626626
partial(get_frames_in_range, start=4, stop=5),
627627
partial(get_frame_at_pts, seconds=2),
628628
partial(get_frames_by_pts, timestamps=[0, 1.5]),
629-
partial(get_next_frame),
630629
),
631630
)
632631
def test_audio_bad_method(self, method):
266 KB
Binary file not shown.

test/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def sample_format(self) -> str:
444444
},
445445
)
446446

447+
# Note that the file itself is s32 sample format, but the reference frames are
448+
# stored as fltp. We can add the s32 original reference frames once we support
449+
# decoding to non-fltp format, but for now we don't need to.
447450
SINE_MONO_S32 = TestAudio(
448451
filename="sine_mono_s32.wav",
449452
default_stream_index=0,

0 commit comments

Comments
 (0)