Skip to content

Commit a7b67d5

Browse files
committed
WIP
1 parent 9a00c91 commit a7b67d5

File tree

6 files changed

+116
-19
lines changed

6 files changed

+116
-19
lines changed

src/torchcodec/_frame.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,16 @@ def __len__(self):
115115
def __repr__(self):
116116
return _frame_repr(self)
117117

118+
118119
@dataclass
119120
class AudioSamples(Iterable):
120121
"""Audio samples with associated metadata."""
122+
121123
# TODO-AUDIO: docs
122124
data: Tensor
123125
pts_seconds: float
124126
sample_rate: int
127+
125128
def __post_init__(self):
126129
# This is called after __init__() when a Frame is created. We can run
127130
# input validation checks here.
@@ -135,4 +138,4 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
135138
yield getattr(self, field.name)
136139

137140
def __repr__(self):
138-
return _frame_repr(self)
141+
return _frame_repr(self)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def get_samples_played_in_range(
8282
offset_beginning = round((start_seconds - first_pts) * sample_rate)
8383
output_pts_seconds = start_seconds
8484
else:
85+
# In normal cases we'll have first_pts <= start_pts, but in some
86+
# edge cases it's possible to have first_pts > start_seconds,
87+
# typically if the stream's first frame's pts isn't exactly 0.
8588
offset_beginning = 0
8689
output_pts_seconds = first_pts
8790

@@ -97,4 +100,3 @@ def get_samples_played_in_range(
97100
pts_seconds=output_pts_seconds,
98101
sample_rate=sample_rate,
99102
)
100-

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
854854

855855
if (startSeconds == stopSeconds) {
856856
// For consistency with video
857-
return AudioFramesOutput{torch::empty({0}), 0.0};
857+
return AudioFramesOutput{torch::empty({0, 0}), 0.0};
858858
}
859859

860860
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

test/decoders/test_decoders.py

Lines changed: 104 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -957,33 +957,124 @@ def test_metadata(self, asset):
957957
assert decoder.metadata.num_channels == asset.num_channels
958958

959959
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
960-
def test_get_all_samples(self, asset):
960+
def test_error(self, asset):
961961
decoder = AudioDecoder(asset.path)
962-
963-
samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=None)
962+
963+
with pytest.raises(ValueError, match="Invalid start seconds"):
964+
decoder.get_samples_played_in_range(start_seconds=-1300)
965+
966+
with pytest.raises(ValueError, match="Invalid start seconds"):
967+
decoder.get_samples_played_in_range(start_seconds=9999)
968+
969+
with pytest.raises(ValueError, match="Invalid start seconds"):
970+
decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2)
971+
972+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
973+
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
974+
def test_get_all_samples(self, asset, stop_seconds):
975+
decoder = AudioDecoder(asset.path)
976+
977+
if stop_seconds == "duration":
978+
stop_seconds = asset.duration_seconds
979+
980+
samples = decoder.get_samples_played_in_range(
981+
start_seconds=0, stop_seconds=stop_seconds
982+
)
964983

965984
reference_frames = asset.get_frame_data_by_range(
966-
start=0,
967-
stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
985+
start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
968986
)
969987

970988
torch.testing.assert_close(samples.data, reference_frames)
971-
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
989+
assert samples.sample_rate == asset.sample_rate
990+
991+
# TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
992+
expected_pts = (
993+
0.072
994+
if asset is NASA_AUDIO_MP3
995+
else asset.get_frame_info(idx=0).pts_seconds
996+
)
997+
assert samples.pts_seconds == expected_pts
972998

973999
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
974-
def test_get_samples_played_in_range(self, asset):
1000+
def test_at_frame_boundaries(self, asset):
9751001
decoder = AudioDecoder(asset.path)
976-
977-
start_seconds, stop_seconds = 2, 4
978-
samples = decoder.get_samples_played_in_range(start_seconds=start_seconds, stop_seconds=stop_seconds)
1002+
1003+
start_frame_index, stop_frame_index = 10, 40
1004+
start_seconds = asset.get_frame_info(start_frame_index).pts_seconds
1005+
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds
1006+
1007+
samples = decoder.get_samples_played_in_range(
1008+
start_seconds=start_seconds, stop_seconds=stop_seconds
1009+
)
9791010

9801011
reference_frames = asset.get_frame_data_by_range(
981-
start=asset.get_frame_index(pts_seconds=start_seconds),
982-
stop=asset.get_frame_index(pts_seconds=stop_seconds) + 1
1012+
start=start_frame_index, stop=stop_frame_index
1013+
)
1014+
1015+
assert samples.pts_seconds == start_seconds
1016+
num_samples = samples.data.shape[1]
1017+
assert (
1018+
num_samples
1019+
== reference_frames.shape[1]
1020+
== (stop_seconds - start_seconds) * decoder.metadata.sample_rate
1021+
)
1022+
torch.testing.assert_close(samples.data, reference_frames)
1023+
assert samples.sample_rate == asset.sample_rate
1024+
1025+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1026+
def test_not_at_frame_boundaries(self, asset):
1027+
decoder = AudioDecoder(asset.path)
1028+
1029+
start_frame_index, stop_frame_index = 10, 40
1030+
start_frame_info = asset.get_frame_info(start_frame_index)
1031+
stop_frame_info = asset.get_frame_info(stop_frame_index)
1032+
start_seconds = start_frame_info.pts_seconds + (
1033+
start_frame_info.duration_seconds / 2
1034+
)
1035+
stop_seconds = stop_frame_info.pts_seconds + (
1036+
stop_frame_info.duration_seconds / 2
1037+
)
1038+
samples = decoder.get_samples_played_in_range(
1039+
start_seconds=start_seconds, stop_seconds=stop_seconds
1040+
)
1041+
1042+
reference_frames = asset.get_frame_data_by_range(
1043+
start=start_frame_index, stop=stop_frame_index + 1
9831044
)
9841045

9851046
assert samples.pts_seconds == start_seconds
9861047
num_samples = samples.data.shape[1]
9871048
assert num_samples < reference_frames.shape[1]
988-
assert num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate
1049+
assert (
1050+
num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate
1051+
)
1052+
assert samples.sample_rate == asset.sample_rate
1053+
1054+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1055+
def test_start_equals_stop(self, asset):
1056+
decoder = AudioDecoder(asset.path)
1057+
samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3)
1058+
assert samples.data.shape == (0, 0)
1059+
1060+
def test_frame_start_is_not_zero(self):
1061+
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
1062+
# So if we request start = 0.05, we shouldn't be truncating anything.
1063+
#
1064+
# [1] well, really it's at 0.138125, not 0.072 (see
1065+
# https://github.com/pytorch/torchcodec/issues/553), but for the purpose
1066+
# of this test it doesn't matter.
1067+
1068+
asset = NASA_AUDIO_MP3
1069+
start_seconds = 0.05 # this is less than the first frame's pts
1070+
stop_frame_index = 10
1071+
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds
9891072

1073+
decoder = AudioDecoder(asset.path)
1074+
1075+
samples = decoder.get_samples_played_in_range(
1076+
start_seconds=start_seconds, stop_seconds=stop_seconds
1077+
)
1078+
1079+
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
1080+
torch.testing.assert_close(samples.data, reference_frames)

test/decoders/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_decode_start_equal_stop(self, asset):
742742
frames, pts_seconds = get_frames_by_pts_in_range_audio(
743743
decoder, start_seconds=1, stop_seconds=1
744744
)
745-
assert frames.shape == (0,)
745+
assert frames.shape == (0, 0)
746746
assert pts_seconds == 0
747747

748748
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))

test/test_frame_dataclasses.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import torch
3-
from torchcodec import Frame, FrameBatch, AudioSamples
3+
from torchcodec import AudioSamples, Frame, FrameBatch
44

55

66
def test_unpacking():
@@ -141,6 +141,7 @@ def test_framebatch_indexing():
141141
assert isinstance(fb_fancy, FrameBatch)
142142
assert fb_fancy.data.shape == (1, C, H, W)
143143

144+
144145
def test_audio_samples_error():
145146
with pytest.raises(ValueError, match="data must be 2-dimensional"):
146147
AudioSamples(
@@ -153,4 +154,4 @@ def test_audio_samples_error():
153154
data=torch.rand(1, 2, 3),
154155
pts_seconds=1,
155156
sample_rate=16_000,
156-
)
157+
)

0 commit comments

Comments
 (0)