Skip to content

Commit 0f9e14d

Browse files
committed
WIP
1 parent 277fac2 commit 0f9e14d

File tree

6 files changed

+97
-16
lines changed

6 files changed

+97
-16
lines changed

src/torchcodec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
88
# but that results in circular import.
9-
from ._frame import Frame, FrameBatch # usort:skip # noqa
9+
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
1010
from . import decoders, samplers # noqa
1111

1212
try:

src/torchcodec/_frame.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def _frame_repr(self):
15-
# Utility to replace Frame and FrameBatch __repr__ method. This prints the
15+
# Utility to replace __repr__ method of dataclasses below. This prints the
1616
# shape of the .data tensor rather than printing the (potentially very long)
1717
# data tensor itself.
1818
s = self.__class__.__name__ + ":\n"
@@ -114,3 +114,25 @@ def __len__(self):
114114

115115
def __repr__(self):
116116
return _frame_repr(self)
117+
118+
@dataclass
class AudioSamples(Iterable):
    """Audio samples with associated metadata."""

    # TODO-AUDIO: docs
    # Decoded samples, 2D tensor of shape (num_channels, num_samples)
    # (the decoder slices dim 1 by sample count).
    data: Tensor
    # Presentation timestamp of the first sample, in seconds.
    pts_seconds: float
    # Number of samples per second.
    sample_rate: int

    def __post_init__(self):
        # Called automatically after the dataclass-generated __init__() when an
        # AudioSamples instance is created. We run input validation checks and
        # normalize field types here.
        if self.data.ndim != 2:
            raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }")
        # Coerce so that e.g. int pts or float sample_rate inputs still yield
        # the declared field types.
        self.pts_seconds = float(self.pts_seconds)
        self.sample_rate = int(self.sample_rate)

    def __iter__(self) -> Iterator[Union[Tensor, float]]:
        # Yield fields in declaration order so instances support
        # tuple-unpacking: data, pts_seconds, sample_rate = samples
        for field in dataclasses.fields(self):
            yield getattr(self, field.name)

    def __repr__(self):
        # Shared repr helper: prints data.shape instead of the full tensor.
        return _frame_repr(self)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torch import Tensor
1111

12+
from torchcodec import AudioSamples
1213
from torchcodec.decoders import _core as core
1314
from torchcodec.decoders._decoder_utils import (
1415
create_decoder,
@@ -39,7 +40,7 @@ def __init__(
3940
)
4041

4142
def get_samples_played_in_range(
42-
self, start_seconds: float = 0, stop_seconds: Optional[float] = None
43+
self, start_seconds: float, stop_seconds: Optional[float] = None
4344
) -> Tensor:
4445
"""TODO-AUDIO docs"""
4546
if stop_seconds is not None and not start_seconds <= stop_seconds:
@@ -63,26 +64,37 @@ def get_samples_played_in_range(
6364
#
6465
# first_pts last_pts
6566
# v v
66-
# ....x..........x..........x...........x..........x..........x..........x.....
67+
# ....x..........x..........x...........x..........x..........x.....
6768
# ^ ^
6869
# start_seconds stop_seconds
6970
#
7071
# We want to return the samples in [start_seconds, stop_seconds). But
7172
# because the core API is based on frames, the `frames` tensor contains
7273
# the samples in [first_pts, last_pts)
73-
#
7474
# So we do some basic math to figure out the position of the view that
75-
# we'l; return.
75+
# we'll return.
7676

77-
offset_beginning = round(
78-
(max(0, start_seconds - first_pts)) * self.metadata.sample_rate
79-
)
77+
# TODO: sample_rate is either the original one from metadata, or the
78+
# user-specified one (NIY)
79+
sample_rate = self.metadata.sample_rate
80+
81+
if first_pts < start_seconds:
82+
offset_beginning = round((start_seconds - first_pts) * sample_rate)
83+
output_pts_seconds = start_seconds
84+
else:
85+
offset_beginning = 0
86+
output_pts_seconds = first_pts
8087

8188
num_samples = frames.shape[1]
82-
offset_end = num_samples
8389
last_pts = first_pts + num_samples / self.metadata.sample_rate
8490
if stop_seconds is not None and stop_seconds < last_pts:
85-
offset_end -= round((last_pts - stop_seconds) * self.metadata.sample_rate)
91+
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
92+
else:
93+
offset_end = num_samples
94+
95+
return AudioSamples(
96+
data=frames[:, offset_beginning:offset_end],
97+
pts_seconds=output_pts_seconds,
98+
sample_rate=sample_rate,
99+
)
86100

87-
return frames[:, offset_beginning:offset_end]
88-
# return frames[:, offset_beginning:offset_end]

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class VideoDecoder {
147147
// DECODING AND SEEKING APIs
148148
// --------------------------------------------------------------------------
149149

150-
// All public decoding entry points return either a FrameOutput or a
150+
// All public video decoding entry points return either a FrameOutput or a
151151
// FrameBatchOutput.
152152
// They are the equivalent of the user-facing Frame and FrameBatch classes in
153153
// Python. They contain RGB decoded frames along with some associated data

test/decoders/test_decoders.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,35 @@ def test_metadata(self, asset):
955955
)
956956
assert decoder.metadata.sample_rate == asset.sample_rate
957957
assert decoder.metadata.num_channels == asset.num_channels
958+
959+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_get_all_samples(self, asset):
    # Decoding from 0 with stop_seconds=None must return every sample of the
    # stream, starting at the pts of the very first frame.
    decoder = AudioDecoder(asset.path)
    samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=None)

    # Reference: all frames from the first one up to and including the frame
    # that covers the end of the stream.
    last_frame_index = asset.get_frame_index(pts_seconds=asset.duration_seconds)
    all_frames = asset.get_frame_data_by_range(start=0, stop=last_frame_index + 1)

    torch.testing.assert_close(samples.data, all_frames)
    assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
972+
973+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_get_samples_played_in_range(self, asset):
    # The returned view starts exactly at start_seconds and spans exactly
    # (stop_seconds - start_seconds) seconds -- a strict subset of the frames
    # that overlap the requested range.
    start_seconds, stop_seconds = 2, 4
    decoder = AudioDecoder(asset.path)
    samples = decoder.get_samples_played_in_range(
        start_seconds=start_seconds, stop_seconds=stop_seconds
    )

    overlapping_frames = asset.get_frame_data_by_range(
        start=asset.get_frame_index(pts_seconds=start_seconds),
        stop=asset.get_frame_index(pts_seconds=stop_seconds) + 1,
    )

    assert samples.pts_seconds == start_seconds
    num_samples = samples.data.shape[1]
    assert num_samples < overlapping_frames.shape[1]
    assert num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate
989+

test/test_frame_dataclasses.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pytest
22
import torch
3-
from torchcodec import Frame, FrameBatch
3+
from torchcodec import Frame, FrameBatch, AudioSamples
44

55

6-
def test_unpacking():
    # Frame and AudioSamples are Iterables: they support tuple-unpacking of
    # their fields in declaration order.
    frame = Frame(torch.rand(3, 4, 5), 2, 3)
    data, pts_seconds, duration_seconds = frame  # noqa
    audio_samples = AudioSamples(torch.rand(2, 4), 2, 16_000)
    data, pts_seconds, sample_rate = audio_samples  # noqa
89

910

1011
def test_frame_error():
@@ -139,3 +140,17 @@ def test_framebatch_indexing():
139140
fb_fancy = fb[[[0], [1]]] # select T=0 and N=1.
140141
assert isinstance(fb_fancy, FrameBatch)
141142
assert fb_fancy.data.shape == (1, C, H, W)
143+
144+
def test_audio_samples_error():
    # AudioSamples only accepts 2D data: both 1D and 3D tensors are rejected.
    for bad_shape in ((1,), (1, 2, 3)):
        with pytest.raises(ValueError, match="data must be 2-dimensional"):
            AudioSamples(
                data=torch.rand(bad_shape),
                pts_seconds=1,
                sample_rate=16_000,
            )

0 commit comments

Comments (0)