Skip to content

Commit 29e0b8d

Browse files
committed
Add more tests
1 parent ae15304 commit 29e0b8d

File tree

5 files changed

+69
-37
lines changed

5 files changed

+69
-37
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ void VideoDecoder::addAudioStream(int streamIndex) {
575575
// TODO-AUDIO
576576
TORCH_CHECK(
577577
streamMetadata.averageFps.has_value(),
578-
"frame_size or sampl_rate aren't known. Cannot decode.");
578+
"frame_size or sample_rate aren't known. Cannot decode.");
579579

580580
streamMetadata.sampleRate =
581581
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
@@ -1311,20 +1311,18 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13111311
auto numSamples = avFrame->nb_samples; // per channel
13121312
auto numChannels = getNumChannels(avFrame);
13131313

1314-
// TODO-AUDIO: dtype should be format-dependent
1315-
// TODO-AUDIO rename data to something else
1316-
torch::Tensor data;
1314+
torch::Tensor outputData;
13171315
if (preAllocatedOutputTensor.has_value()) {
1318-
data = preAllocatedOutputTensor.value();
1316+
outputData = preAllocatedOutputTensor.value();
13191317
} else {
1320-
data = torch::empty({numChannels, numSamples}, torch::kFloat32);
1318+
outputData = torch::empty({numChannels, numSamples}, torch::kFloat32);
13211319
}
13221320

13231321
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1324-
// TODO Implement all formats
1322+
// TODO-AUDIO Implement all formats.
13251323
switch (format) {
13261324
case AV_SAMPLE_FMT_FLTP: {
1327-
uint8_t* outputChannelData = static_cast<uint8_t*>(data.data_ptr());
1325+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
13281326
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
13291327
for (auto channel = 0; channel < numChannels;
13301328
++channel, outputChannelData += numBytesPerChannel) {
@@ -1341,7 +1339,7 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13411339
"Unsupported audio format (yet!): ",
13421340
av_get_sample_fmt_name(format));
13431341
}
1344-
frameOutput.data = data;
1342+
frameOutput.data = outputData;
13451343
}
13461344

13471345
// --------------------------------------------------------------------------

test/decoders/test_ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
assert_frames_equal,
4141
cpu_and_cuda,
4242
NASA_AUDIO,
43+
NASA_AUDIO_MP3,
4344
NASA_VIDEO,
4445
needs_cuda,
4546
)
@@ -637,6 +638,45 @@ def test_audio_bad_seek_mode(self):
637638
):
638639
add_audio_stream(decoder)
639640

641+
def test_audio_decode_all_samples_with_get_frames_by_pts_in_range(self):
642+
decoder = create_from_file(str(NASA_AUDIO.path), seek_mode="approximate")
643+
add_audio_stream(decoder)
644+
645+
reference_frames = [
646+
NASA_AUDIO.get_frame_data_by_index(i) for i in range(NASA_AUDIO.num_frames)
647+
]
648+
reference_frames = torch.stack(
649+
reference_frames
650+
) # shape is (num_frames, C, num_samples_per_frame)
651+
652+
all_frames, *_ = get_frames_by_pts_in_range(
653+
decoder, start_seconds=0, stop_seconds=NASA_AUDIO.duration_seconds
654+
)
655+
assert_frames_equal(all_frames, reference_frames)
656+
657+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
658+
def test_audio_decode_all_samples_with_next(self, asset):
659+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
660+
add_audio_stream(decoder)
661+
662+
reference_frames = [
663+
asset.get_frame_data_by_index(i) for i in range(asset.num_frames)
664+
]
665+
666+
# shape is (C, num_frames * num_samples_per_frame) while preserving frame order and boundaries
667+
reference_frames = torch.cat(reference_frames, dim=-1)
668+
669+
all_frames = []
670+
while True:
671+
try:
672+
frame, *_ = get_next_frame(decoder)
673+
all_frames.append(frame)
674+
except IndexError:
675+
break
676+
all_frames = torch.cat(all_frames, axis=-1)
677+
678+
assert_frames_equal(all_frames, reference_frames)
679+
640680
@pytest.mark.parametrize(
641681
"start_seconds, stop_seconds",
642682
(
864 KB
Binary file not shown.
51.9 KB
Binary file not shown.

test/utils.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pathlib
44
import sys
55

6-
from dataclasses import dataclass
7-
from typing import Dict, Optional
6+
from dataclasses import dataclass, field
7+
from typing import Dict, List, Optional
88

99
import numpy as np
1010
import pytest
@@ -203,6 +203,8 @@ class TestVideoStreamInfo:
203203

204204
@dataclass
205205
class TestVideo(TestContainerFile):
206+
"""Base class for the *video* streams of a video container"""
207+
206208
stream_infos: Dict[int, TestVideoStreamInfo]
207209

208210
def get_frame_data_by_index(
@@ -318,13 +320,16 @@ class TestAudioStreamInfo:
318320
sample_rate: int
319321
num_channels: int
320322
duration_seconds: float
323+
num_frames: int
321324

322325

323326
@dataclass
324327
class TestAudio(TestContainerFile):
328+
"""Base class for the *audio* streams of a container (potentially a video),
329+
or a pure audio file"""
325330

326331
stream_infos: Dict[int, TestAudioStreamInfo]
327-
_reference_frames: tuple[torch.Tensor] = tuple()
332+
_reference_frames: Dict[int, List[torch.Tensor]] = field(default_factory=dict)
328333

329334
# Storing each individual frame is too expensive for audio, because there's
330335
# a massive overhead in the binary format saved by pytorch. Saving all the
@@ -333,32 +338,22 @@ class TestAudio(TestContainerFile):
333338
# So we store the reference frames in a single file, and load/cache those
334339
# when the TestAudio instance is created.
335340
def __post_init__(self):
336-
# We hard-code the default stream index, see TODO below.
337-
file_path = _get_file_path(
338-
f"{self.filename}.stream{self.default_stream_index}.all_frames.pt"
339-
)
340-
if not file_path.exists():
341-
return # TODO-audio
342-
t = torch.load(file_path, weights_only=True)
341+
for stream_index in self.stream_infos:
342+
file_path = _get_file_path(
343+
f"{self.filename}.stream{stream_index}.all_frames.pt"
344+
)
343345

344-
# These are hard-coded value assuming stream 4 of nasa_13013.mp4. Each
345-
# of the 204 frames contains 1024 samples.
346-
# TODO make this more generic
347-
assert t.shape == (2, 204 * 1024)
348-
self._reference_frames = torch.chunk(t, chunks=204, dim=1)
346+
self._reference_frames[stream_index] = torch.load(
347+
file_path, weights_only=True
348+
)
349349

350350
def get_frame_data_by_index(
351351
self, idx: int, *, stream_index: Optional[int] = None
352352
) -> torch.Tensor:
353-
if stream_index is not None and stream_index != self.default_stream_index:
354-
# TODO address this, the fix should be to let _reference_frames be a
355-
# dict[tuple[torch.Tensor]] where keys are stream indices, and load
356-
# all of those indices in __post_init__.
357-
raise ValueError(
358-
"Can only use default stream index with TestAudio for now."
359-
)
353+
if stream_index is None:
354+
stream_index = self.default_stream_index
360355

361-
return self._reference_frames[idx]
356+
return self._reference_frames[stream_index][idx]
362357

363358
def pts_to_frame_index(self, pts_seconds: float) -> int:
364359
# These are hard-coded value assuming stream 4 of nasa_13013.mp4. Each
@@ -379,10 +374,9 @@ def num_channels(self) -> int:
379374
def duration_seconds(self) -> float:
380375
return self.stream_infos[self.default_stream_index].duration_seconds
381376

382-
# TODO: this shouldn't be named chw. Also values are hard-coded
383377
@property
384-
def empty_chw_tensor(self) -> torch.Tensor:
385-
return torch.empty([0, 2, 1024], dtype=torch.float32)
378+
def num_frames(self) -> int:
379+
return self.stream_infos[self.default_stream_index].num_frames
386380

387381

388382
NASA_AUDIO_MP3 = TestAudio(
@@ -391,7 +385,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
391385
frames={}, # TODO
392386
stream_infos={
393387
0: TestAudioStreamInfo(
394-
sample_rate=8_000, num_channels=2, duration_seconds=13.248
388+
sample_rate=8_000, num_channels=2, duration_seconds=13.248, num_frames=183
395389
)
396390
},
397391
)
@@ -402,7 +396,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
402396
frames={}, # TODO
403397
stream_infos={
404398
4: TestAudioStreamInfo(
405-
sample_rate=16_000, num_channels=2, duration_seconds=13.056
399+
sample_rate=16_000, num_channels=2, duration_seconds=13.056, num_frames=204
406400
)
407401
},
408402
)

0 commit comments

Comments
 (0)