Skip to content

Commit 6082803

Browse files
committed
Add support for get_frames_in_range
1 parent 73fa225 commit 6082803

File tree

6 files changed

+122
-27
lines changed

6 files changed

+122
-27
lines changed

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ int getNumChannels(const AVFrame* avFrame) {
6969
#endif
7070
}
7171

72+
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
73+
// TODO not sure about the bounds of the versions here
74+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
75+
(IBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
76+
return av_get_channel_layout_nb_channels(avCodecContext->channel_layout);
77+
#else
78+
return avCodecContext->channels;
79+
#endif
80+
}
81+
7282
AVIOBytesContext::AVIOBytesContext(
7383
const void* data,
7484
size_t data_size,

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142142
int getNumChannels(const AVFrame* avFrame);
143+
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
143144

144145
// Returns true if sws_scale can handle unaligned data.
145146
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,14 @@ void VideoDecoder::addVideoStream(
617617

618618
void VideoDecoder::addAudioStream(int streamIndex) {
619619
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
620+
621+
// TODO address this, this is currently super limiting. The main thing we'll
622+
// need to handle is the pre-allocation of the output tensor in batch APIs. We
623+
// probably won't be able to pre-allocate anything.
624+
auto& streamInfo = streamInfos_[activeStreamIndex_];
625+
TORCH_CHECK(
626+
streamInfo.codecContext->frame_size > 0,
627+
"No support for variable framerate yet.");
620628
}
621629

622630
// --------------------------------------------------------------------------
@@ -736,9 +744,18 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
736744
step > 0, "Step must be greater than 0; is " + std::to_string(step));
737745

738746
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
739-
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
740-
FrameBatchOutput frameBatchOutput(
741-
numOutputFrames, videoStreamOptions, streamMetadata);
747+
748+
FrameBatchOutput frameBatchOutput;
749+
if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) {
750+
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
751+
frameBatchOutput =
752+
FrameBatchOutput(numOutputFrames, videoStreamOptions, streamMetadata);
753+
} else {
754+
int64_t numSamples = streamInfo.codecContext->frame_size;
755+
int64_t numChannels = getNumChannels(streamInfo.codecContext);
756+
frameBatchOutput =
757+
FrameBatchOutput(numOutputFrames, numChannels, numSamples);
758+
}
742759

743760
for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
744761
FrameOutput frameOutput =
@@ -1200,8 +1217,8 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12001217
frameOutput.durationSeconds = ptsToSeconds(
12011218
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
12021219
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1203-
// TODO: handle preAllocatedTensor for audio
1204-
convertAudioAVFrameToFrameOutputOnCPU(avFrameStream, frameOutput);
1220+
convertAudioAVFrameToFrameOutputOnCPU(
1221+
avFrameStream, frameOutput, preAllocatedOutputTensor);
12051222
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
12061223
convertAVFrameToFrameOutputOnCPU(
12071224
avFrameStream, frameOutput, preAllocatedOutputTensor);
@@ -1380,14 +1397,21 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13801397

13811398
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13821399
VideoDecoder::AVFrameStream& avFrameStream,
1383-
FrameOutput& frameOutput) {
1400+
FrameOutput& frameOutput,
1401+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
13841402
const AVFrame* avFrame = avFrameStream.avFrame.get();
13851403

13861404
auto numSamples = avFrame->nb_samples; // per channel
13871405
auto numChannels = getNumChannels(avFrame);
13881406

13891407
// TODO: dtype should be format-dependent
1390-
torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1408+
// TODO rename data to something else
1409+
torch::Tensor data;
1410+
if (preAllocatedOutputTensor.has_value()) {
1411+
data = preAllocatedOutputTensor.value();
1412+
} else {
1413+
data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1414+
}
13911415

13921416
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13931417
// TODO Implement all formats
@@ -1431,6 +1455,20 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
14311455
height, width, videoStreamOptions.device, numFrames);
14321456
}
14331457

1458+
// Allocates output for an audio batch: data is shaped
// (numFrames, numChannels, numSamples), while ptsSeconds and durationSeconds
// hold one entry per frame.
VideoDecoder::FrameBatchOutput::FrameBatchOutput(
    int64_t numFrames,
    int64_t numChannels,
    int64_t numSamples)
    // ptsSeconds/durationSeconds are per-frame metadata, 1D of shape
    // (numFrames,) — sizing them by numSamples was a bug.
    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
  // TODO handle dtypes other than float
  auto tensorOptions = torch::TensorOptions()
                           .dtype(torch::kFloat32)
                           .layout(torch::kStrided)
                           .device(torch::kCPU);
  data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
}
1471+
14341472
torch::Tensor allocateEmptyHWCTensor(
14351473
int height,
14361474
int width,
@@ -1459,8 +1497,13 @@ torch::Tensor allocateEmptyHWCTensor(
14591497
// https://pytorch.org/docs/stable/generated/torch.permute.html
14601498
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) {
14611499
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_AUDIO) {
1462-
// TODO: Is this really how we want to handle audio?
1463-
return hwcTensor;
1500+
// TODO: Do something better
1501+
auto shape = hwcTensor.sizes();
1502+
auto numFrames = shape[0];
1503+
auto numChannels = shape[1];
1504+
auto numSamples = shape[2];
1505+
return hwcTensor.permute({1, 0, 2}).reshape(
1506+
{numChannels, numSamples * numFrames});
14641507
}
14651508
if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
14661509
"NHWC") {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,15 @@ class VideoDecoder {
162162
torch::Tensor ptsSeconds; // 1D of shape (N,)
163163
torch::Tensor durationSeconds; // 1D of shape (N,)
164164

165+
FrameBatchOutput(){};
165166
explicit FrameBatchOutput(
166167
int64_t numFrames,
167168
const VideoStreamOptions& videoStreamOptions,
168169
const StreamMetadata& streamMetadata);
170+
explicit FrameBatchOutput(
171+
int64_t numFrames,
172+
int64_t numChannels,
173+
int64_t numSamples);
169174
};
170175

171176
// Places the cursor at the first frame on or after the position in seconds.
@@ -385,7 +390,8 @@ class VideoDecoder {
385390

386391
void convertAudioAVFrameToFrameOutputOnCPU(
387392
AVFrameStream& avFrameStream,
388-
FrameOutput& frameOutput);
393+
FrameOutput& frameOutput,
394+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
389395

390396
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
391397

test/decoders/test_video_decoder_ops.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -350,45 +350,51 @@ def test_pts_apis_against_index_ref(self, device):
350350
)
351351
torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0)
352352

353+
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
353354
@pytest.mark.parametrize("device", cpu_and_cuda())
354-
def test_get_frames_in_range(self, device):
355-
decoder = create_from_file(str(NASA_VIDEO.path))
356-
add_video_stream(decoder, device=device)
355+
def test_get_frames_in_range(self, test_ref, device):
356+
if device == "cuda" and test_ref is NASA_AUDIO:
357+
pytest.skip(reason="CUDA decoding not supported for audio")
358+
decoder = create_from_file(str(test_ref.path))
359+
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
357360

358361
# ensure that the degenerate case of a range of size 1 works
359-
ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1)
362+
ref_frame0 = test_ref.get_frame_data_by_range(0, 1)
360363
bulk_frame0, *_ = get_frames_in_range(decoder, start=0, stop=1)
361364
assert_frames_equal(bulk_frame0, ref_frame0.to(device))
362365

363-
ref_frame1 = NASA_VIDEO.get_frame_data_by_range(1, 2)
366+
ref_frame1 = test_ref.get_frame_data_by_range(1, 2)
364367
bulk_frame1, *_ = get_frames_in_range(decoder, start=1, stop=2)
365368
assert_frames_equal(bulk_frame1, ref_frame1.to(device))
366369

367-
ref_frame389 = NASA_VIDEO.get_frame_data_by_range(389, 390)
368-
bulk_frame389, *_ = get_frames_in_range(decoder, start=389, stop=390)
370+
last_index = 389 if test_ref is NASA_VIDEO else 203 # TODO ew
371+
ref_frame389 = test_ref.get_frame_data_by_range(last_index, last_index + 1)
372+
bulk_frame389, *_ = get_frames_in_range(
373+
decoder, start=last_index, stop=last_index + 1
374+
)
369375
assert_frames_equal(bulk_frame389, ref_frame389.to(device))
370376

371377
# contiguous ranges
372-
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9)
378+
ref_frames0_9 = test_ref.get_frame_data_by_range(0, 9)
373379
bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9)
374380
assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device))
375381

376-
ref_frames4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8)
382+
ref_frames4_8 = test_ref.get_frame_data_by_range(4, 8)
377383
bulk_frames4_8, *_ = get_frames_in_range(decoder, start=4, stop=8)
378384
assert_frames_equal(bulk_frames4_8, ref_frames4_8.to(device))
379385

380386
# ranges with a stride
381-
ref_frames15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5)
387+
ref_frames15_35 = test_ref.get_frame_data_by_range(15, 36, 5)
382388
bulk_frames15_35, *_ = get_frames_in_range(decoder, start=15, stop=36, step=5)
383389
assert_frames_equal(bulk_frames15_35, ref_frames15_35.to(device))
384390

385-
ref_frames0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2)
391+
ref_frames0_9_2 = test_ref.get_frame_data_by_range(0, 9, 2)
386392
bulk_frames0_9_2, *_ = get_frames_in_range(decoder, start=0, stop=9, step=2)
387393
assert_frames_equal(bulk_frames0_9_2, ref_frames0_9_2.to(device))
388394

389395
# an empty range is valid!
390396
empty_frame, *_ = get_frames_in_range(decoder, start=5, stop=5)
391-
assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device))
397+
assert_frames_equal(empty_frame, test_ref.empty_chw_tensor.to(device))
392398

393399
@pytest.mark.parametrize(
394400
"test_ref, last_frame_index", ((NASA_VIDEO, 289), (NASA_AUDIO, 203))

test/utils.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,7 @@ def get_frame_data_by_range(
119119
*,
120120
stream_index: Optional[int] = None,
121121
) -> torch.Tensor:
122-
tensors = [
123-
self.get_frame_data_by_index(i, stream_index=stream_index)
124-
for i in range(start, stop, step)
125-
]
126-
return torch.stack(tensors)
122+
raise NotImplementedError("Override in child classes")
127123

128124
def get_pts_seconds_by_range(
129125
self,
@@ -197,6 +193,20 @@ def get_frame_data_by_index(
197193
)
198194
return torch.load(file_path, weights_only=True).permute(2, 0, 1)
199195

196+
def get_frame_data_by_range(
    self,
    start: int,
    stop: int,
    step: int = 1,
    *,
    stream_index: Optional[int] = None,
) -> torch.Tensor:
    """Return frames [start, stop) at the given step, stacked along a new
    leading (frame) dimension."""
    frames = tuple(
        self.get_frame_data_by_index(index, stream_index=stream_index)
        for index in range(start, stop, step)
    )
    return torch.stack(frames)
209+
200210
@property
201211
def width(self) -> int:
202212
return self.stream_infos[self.default_stream_index].width
@@ -337,6 +347,25 @@ def get_frame_data_by_index(
337347

338348
return self._reference_frames[idx]
339349

350+
def get_frame_data_by_range(
    self,
    start: int,
    stop: int,
    step: int = 1,
    *,
    stream_index: Optional[int] = None,
) -> torch.Tensor:
    """Return audio frames [start, stop) at the given step, concatenated
    along the samples dimension (dim=1).

    NOTE(review): ``torch.cat`` raises on an empty list, so an empty range
    fails here — confirm callers never request one.
    """
    frames = [
        self.get_frame_data_by_index(index, stream_index=stream_index)
        for index in range(start, stop, step)
    ]
    return torch.cat(frames, dim=1)
363+
364+
# TODO: this shouldn't be named chw
@property
def empty_chw_tensor(self) -> torch.Tensor:
    """A (2, 0) float32 tensor: the expected decode result for an empty
    audio frame range (presumably 2 channels, zero samples -- confirm)."""
    return torch.empty((2, 0), dtype=torch.float32)
368+
340369

341370
NASA_AUDIO = TestAudio(
342371
filename="nasa_13013.mp4",

0 commit comments

Comments
 (0)