Skip to content

Commit 0f92d60

Browse files
committed
Handle approximate mode. Sort of.
1 parent 1751c6b commit 0f92d60

File tree

2 files changed

+67
-30
lines changed

2 files changed

+67
-30
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ void VideoDecoder::initializeDecoder() {
156156
"Our stream index, " + std::to_string(i) +
157157
", does not match AVStream's index, " +
158158
std::to_string(avStream->index) + ".");
159+
160+
// TODO figure out audio metadata
161+
159162
streamMetadata.streamIndex = i;
160163
streamMetadata.mediaType = avStream->codecpar->codec_type;
161164
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
@@ -171,12 +174,22 @@ void VideoDecoder::initializeDecoder() {
171174
av_q2d(avStream->time_base) * avStream->duration;
172175
}
173176

174-
double fps = av_q2d(avStream->r_frame_rate);
175-
if (fps > 0) {
176-
streamMetadata.averageFps = fps;
177-
}
178-
179177
if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
178+
double fps = av_q2d(avStream->r_frame_rate);
179+
if (fps > 0) {
180+
streamMetadata.averageFps = fps;
181+
}
182+
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
183+
int numSamplesPerFrame = avStream->codecpar->frame_size;
184+
int sampleRate = avStream->codecpar->sample_rate;
185+
if (numSamplesPerFrame > 0 && sampleRate > 0) {
186+
// This should allow the approximate mode to do its magic.
187+
// fps is numFrames / duration where
188+
// - duration = numSamplesTotal / sampleRate and
189+
// - numSamplesTotal = numSamplesPerFrame * numFrames
190+
streamMetadata.averageFps =
191+
static_cast<double>(sampleRate) / numSamplesPerFrame;
192+
}
180193
containerMetadata_.numVideoStreams++;
181194
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
182195
containerMetadata_.numAudioStreams++;
@@ -465,15 +478,27 @@ void VideoDecoder::addStream(
465478
.value_or(avCodec));
466479
}
467480

468-
// TODO figure out audio metadata
481+
// TODO: For audio, we raise if seek_mode="approximate" and if the number of
482+
// samples per frame is unknown (frame_size field of codec params). But that's
483+
// quite limiting. Ultimately, the most common type of call will be to decode
484+
// an entire file from start to end (possibly with some offsets for start and
485+
// end). And for that, we shouldn't [need to] force the user to scan, because
486+
// all this entails is a single call to seek(start) (if at all) and then just
487+
// a bunch of consecutive calls to getNextFrame(). Maybe there should be a
488+
// third seek mode for audio, e.g. seek_mode="contiguous" where we don't scan,
489+
// and only allow calls to getFramesPlayedAt().
469490
StreamMetadata& streamMetadata =
470491
containerMetadata_.allStreamMetadata[activeStreamIndex_];
471492
if (seekMode_ == SeekMode::approximate &&
472493
!streamMetadata.averageFps.has_value()) {
473-
throw std::runtime_error(
474-
"Seek mode is approximate, but stream " +
475-
std::to_string(activeStreamIndex_) +
476-
" does not have an average fps in its metadata.");
494+
std::string errMsg = "Seek mode is approximate, but stream " +
495+
std::to_string(activeStreamIndex_) + "does not have ";
496+
if (mediaType == AVMEDIA_TYPE_VIDEO) {
497+
errMsg += "an average fps in its metadata.";
498+
} else {
499+
errMsg += "a constant number of samples per frame.";
500+
}
501+
throw std::runtime_error(errMsg);
477502
}
478503

479504
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);

test/decoders/test_video_decoder_ops.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,13 @@ def test_add_stream(self):
109109
),
110110
)
111111
@pytest.mark.parametrize("device", cpu_and_cuda())
112-
def test_seek_and_next(self, test_ref, index_of_frame_after_seeking_at_6, device):
112+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
113+
def test_seek_and_next(
114+
self, test_ref, index_of_frame_after_seeking_at_6, device, seek_mode
115+
):
113116
if device == "cuda" and test_ref is NASA_AUDIO:
114117
pytest.skip(reason="CUDA decoding not supported for audio")
115-
decoder = create_from_file(str(test_ref.path))
118+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
116119
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
117120
frame0, _, _ = get_next_frame(decoder)
118121
reference_frame0 = test_ref.get_frame_data_by_index(0)
@@ -129,11 +132,12 @@ def test_seek_and_next(self, test_ref, index_of_frame_after_seeking_at_6, device
129132

130133
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
131134
@pytest.mark.parametrize("device", cpu_and_cuda())
132-
def test_seek_to_negative_pts(self, test_ref, device):
135+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
136+
def test_seek_to_negative_pts(self, test_ref, device, seek_mode):
133137
if device == "cuda" and test_ref is NASA_AUDIO:
134138
pytest.skip(reason="CUDA decoding not supported for audio")
135139

136-
decoder = create_from_file(str(test_ref.path))
140+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
137141
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
138142
frame0, _, _ = get_next_frame(decoder)
139143
reference_frame0 = test_ref.get_frame_data_by_index(0)
@@ -144,9 +148,10 @@ def test_seek_to_negative_pts(self, test_ref, device):
144148
assert_frames_equal(frame0, reference_frame0.to(device))
145149

146150
@pytest.mark.parametrize("device", cpu_and_cuda())
147-
def test_get_frame_at_pts_video(self, device):
151+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
152+
def test_get_frame_at_pts_video(self, device, seek_mode):
148153

149-
decoder = create_from_file(str(NASA_VIDEO.path))
154+
decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode)
150155
add_video_stream(decoder=decoder, device=device)
151156
# This frame has pts=6.006 and duration=0.033367, so it should be visible
152157
# at timestamps in the range [6.006, 6.039367) (not including the last timestamp).
@@ -168,8 +173,9 @@ def test_get_frame_at_pts_video(self, device):
168173
with pytest.raises(AssertionError):
169174
assert_frames_equal(next_frame, reference_frame6.to(device))
170175

171-
def test_get_frame_at_pts_audio(self):
172-
decoder = create_from_file(str(NASA_AUDIO.path))
176+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
177+
def test_get_frame_at_pts_audio(self, seek_mode):
178+
decoder = create_from_file(str(NASA_AUDIO.path), seek_mode=seek_mode)
173179
add_audio_stream(decoder=decoder)
174180
# This frame has pts=6.016 and duration=0.064, so it should be played
175181
# at timestamps in the range [6.016, 6.08) (not including the last timestamp).
@@ -191,11 +197,12 @@ def test_get_frame_at_pts_audio(self):
191197

192198
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
193199
@pytest.mark.parametrize("device", cpu_and_cuda())
194-
def test_get_frame_at_index(self, test_ref, device):
200+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
201+
def test_get_frame_at_index(self, test_ref, device, seek_mode):
195202
if device == "cuda" and test_ref is NASA_AUDIO:
196203
pytest.skip(reason="CUDA decoding not supported for audio")
197204

198-
decoder = create_from_file(str(test_ref.path))
205+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
199206
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
200207
frame0, _, _ = get_frame_at_index(decoder, frame_index=0)
201208
reference_frame0 = test_ref.get_frame_data_by_index(0)
@@ -213,12 +220,13 @@ def test_get_frame_at_index(self, test_ref, device):
213220
),
214221
)
215222
@pytest.mark.parametrize("device", cpu_and_cuda())
223+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
216224
def test_get_frame_with_info_at_index(
217-
self, test_ref, expected_pts, expected_duration, device
225+
self, test_ref, expected_pts, expected_duration, device, seek_mode
218226
):
219227
if device == "cuda" and test_ref is NASA_AUDIO:
220228
pytest.skip(reason="CUDA decoding not supported for audio")
221-
decoder = create_from_file(str(test_ref.path))
229+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
222230
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
223231
frame6, pts, duration = get_frame_at_index(decoder, frame_index=180)
224232
reference_frame6 = test_ref.get_frame_data_by_index(180)
@@ -228,10 +236,11 @@ def test_get_frame_with_info_at_index(
228236

229237
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
230238
@pytest.mark.parametrize("device", cpu_and_cuda())
231-
def test_get_frames_at_indices(self, test_ref, device):
239+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
240+
def test_get_frames_at_indices(self, test_ref, device, seek_mode):
232241
if device == "cuda" and test_ref is NASA_AUDIO:
233242
pytest.skip(reason="CUDA decoding not supported for audio")
234-
decoder = create_from_file(str(test_ref.path))
243+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
235244
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
236245
frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180])
237246
reference_frame0 = test_ref.get_frame_data_by_index(0)
@@ -242,11 +251,12 @@ def test_get_frames_at_indices(self, test_ref, device):
242251

243252
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
244253
@pytest.mark.parametrize("device", cpu_and_cuda())
245-
def test_get_frames_at_indices_unsorted_indices(self, test_ref, device):
254+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
255+
def test_get_frames_at_indices_unsorted_indices(self, test_ref, device, seek_mode):
246256
if device == "cuda" and test_ref is NASA_AUDIO:
247257
pytest.skip(reason="CUDA decoding not supported for audio")
248258

249-
decoder = create_from_file(str(test_ref.path))
259+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
250260
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
251261

252262
frame_indices = [2, 0, 1, 0, 2]
@@ -272,8 +282,9 @@ def test_get_frames_at_indices_unsorted_indices(self, test_ref, device):
272282
assert_frames_equal(frames[0], frames[-1])
273283

274284
@pytest.mark.parametrize("device", cpu_and_cuda())
275-
def test_get_frames_by_pts(self, device):
276-
decoder = create_from_file(str(NASA_VIDEO.path))
285+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
286+
def test_get_frames_by_pts(self, device, seek_mode):
287+
decoder = create_from_file(str(NASA_VIDEO.path), seek_mode=seek_mode)
277288
_add_video_stream(decoder=decoder, device=device)
278289

279290
# Note: 13.01 should give the last video frame for the NASA video
@@ -361,10 +372,11 @@ def test_pts_apis_against_index_ref(self, test_ref, device):
361372

362373
@pytest.mark.parametrize("test_ref", (NASA_VIDEO, NASA_AUDIO))
363374
@pytest.mark.parametrize("device", cpu_and_cuda())
364-
def test_get_frames_in_range(self, test_ref, device):
375+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
376+
def test_get_frames_in_range(self, test_ref, device, seek_mode):
365377
if device == "cuda" and test_ref is NASA_AUDIO:
366378
pytest.skip(reason="CUDA decoding not supported for audio")
367-
decoder = create_from_file(str(test_ref.path))
379+
decoder = create_from_file(str(test_ref.path), seek_mode=seek_mode)
368380
_add_stream(decoder=decoder, test_ref=test_ref, device=device)
369381

370382
# ensure that the degenerate case of a range of size 1 works

0 commit comments

Comments
 (0)