Skip to content

Commit c453a3c

Browse files
committed
Address comments
1 parent c35ae47 commit c453a3c

File tree

6 files changed

+33
-29
lines changed

6 files changed

+33
-29
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,22 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int64_t getNumChannels(const AVFrame* avFrame) {
63+
int getNumChannels(const AVFrame* avFrame) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66-
int numChannels = avFrame->ch_layout.nb_channels;
66+
return avFrame->ch_layout.nb_channels;
6767
#else
68-
int numChannels = av_get_channel_layout_nb_channels(avFrame->channel_layout);
68+
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
6969
#endif
70-
71-
return static_cast<int64_t>(numChannels);
7270
}
7371

74-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
72+
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
7573
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
7674
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
77-
int numChannels = avCodecContext->ch_layout.nb_channels;
75+
return avCodecContext->ch_layout.nb_channels;
7876
#else
79-
int numChannels = avCodecContext->channels;
77+
return avCodecContext->channels;
8078
#endif
81-
82-
return static_cast<int64_t>(numChannels);
8379
}
8480

8581
AVIOBytesContext::AVIOBytesContext(

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142-
int64_t getNumChannels(const AVFrame* avFrame);
143-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
142+
int getNumChannels(const AVFrame* avFrame);
143+
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
144144

145145
// Returns true if sws_scale can handle unaligned data.
146146
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
553553
containerMetadata_.allStreamMetadata[activeStreamIndex_];
554554
streamMetadata.sampleRate =
555555
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
556-
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
556+
streamMetadata.numChannels =
557+
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
557558
}
558559

559560
// --------------------------------------------------------------------------
@@ -875,16 +876,16 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
875876
std::vector<torch::Tensor> tensors;
876877

877878
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
878-
auto shouldStopDecoding = false;
879-
while (!shouldStopDecoding) {
879+
auto finished = false;
880+
while (!finished) {
880881
try {
881882
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
882883
return cursor_ < avFrame->pts + getDuration(avFrame);
883884
});
884885
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
885886
tensors.push_back(frameOutput.data);
886887
} catch (const EndOfFileException& e) {
887-
shouldStopDecoding = true;
888+
finished = true;
888889
}
889890

890891
// If stopSeconds is in [begin, end] of the last decoded frame, we should
@@ -893,7 +894,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
893894
// stopSeconds, which isn't what we want!
894895
auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
895896
streamInfo.lastDecodedAvFrameDuration;
896-
shouldStopDecoding |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
897+
finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
897898
(stopPts <= lastDecodedAvFrameEnd);
898899
}
899900
return torch::cat(tensors, 1);

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,6 @@ class VideoDecoder {
395395
const AVFrame* avFrame,
396396
torch::Tensor& outputTensor);
397397

398-
FrameBatchOutput makeFrameBatchOutput(int64_t numFrames);
399398
// --------------------------------------------------------------------------
400399
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
401400
// --------------------------------------------------------------------------

test/decoders/test_ops.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -697,22 +697,24 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
697697

698698
torch.testing.assert_close(frames, reference_frames)
699699

700-
@pytest.mark.parametrize(
701-
"asset, expected_shape", ((NASA_AUDIO, (2, 1024)), (NASA_AUDIO_MP3, (2, 576)))
702-
)
703-
def test_decode_epsilon_range(self, asset, expected_shape):
700+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
701+
def test_decode_epsilon_range(self, asset):
704702
decoder = create_from_file(str(asset.path), seek_mode="approximate")
705703
add_audio_stream(decoder)
706704

705+
start_seconds = 5
707706
frames = get_frames_by_pts_in_range_audio(
708-
decoder, start_seconds=5, stop_seconds=5 + 1e-5
707+
decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-5
708+
)
709+
torch.testing.assert_close(
710+
frames,
711+
asset.get_frame_data_by_index(
712+
asset.get_frame_index(pts_seconds=start_seconds)
713+
),
709714
)
710-
assert frames.shape == expected_shape
711715

712-
@pytest.mark.parametrize(
713-
"asset, expected_shape", ((NASA_AUDIO, (2, 1024)), (NASA_AUDIO_MP3, (2, 576)))
714-
)
715-
def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
716+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
717+
def test_decode_just_one_frame_at_boundaries(self, asset):
716718
decoder = create_from_file(str(asset.path), seek_mode="approximate")
717719
add_audio_stream(decoder)
718720

@@ -721,7 +723,12 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
721723
frames = get_frames_by_pts_in_range_audio(
722724
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
723725
)
724-
assert frames.shape == expected_shape
726+
torch.testing.assert_close(
727+
frames,
728+
asset.get_frame_data_by_index(
729+
asset.get_frame_index(pts_seconds=start_seconds)
730+
),
731+
)
725732

726733
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
727734
def test_decode_start_equal_stop(self, asset):

test/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def get_frame_index(
377377
# 0.13~, not 0.
378378
return 0
379379
try:
380+
# Could use bisect() to make this faster if needed
380381
return next(
381382
frame_index
382383
for (frame_index, frame_info) in self.frames[stream_index].items()

0 commit comments

Comments
 (0)