Skip to content

Commit 4495150

Browse files
committed
Add get_samples_played_in_range public method
1 parent d75fc58 commit 4495150

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,52 @@ def __init__(
3737
) = get_and_validate_stream_metadata(
3838
decoder=self._decoder, stream_index=stream_index, media_type="audio"
3939
)
40+
41+
def get_samples_played_in_range(
42+
self, start_seconds: float = 0, stop_seconds: Optional[float] = None
43+
) -> Tensor:
44+
"""TODO-AUDIO docs"""
45+
if stop_seconds is not None and not start_seconds <= stop_seconds:
46+
raise ValueError(
47+
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
48+
)
49+
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
50+
raise ValueError(
51+
f"Invalid start seconds: {start_seconds}. "
52+
f"It must be greater than or equal to {self._begin_stream_seconds} "
53+
f"and less than or equal to {self._end_stream_seconds}."
54+
)
55+
frames, first_pts = core.get_frames_by_pts_in_range_audio(
56+
self._decoder,
57+
start_seconds=start_seconds,
58+
stop_seconds=stop_seconds,
59+
)
60+
first_pts = first_pts.item()
61+
62+
# x = frame boundaries
63+
#
64+
# first_pts last_pts
65+
# v v
66+
# ....x..........x..........x...........x..........x..........x..........x.....
67+
# ^ ^
68+
# start_seconds stop_seconds
69+
#
70+
# We want to return the samples in [start_seconds, stop_seconds). But
71+
# because the core API is based on frames, the `frames` tensor contains
72+
# the samples in [first_pts, last_pts).pts
73+
#
74+
# So we return a view on that tensor and do some basic math to figure
75+
# out where to chunk it.
76+
77+
offset_beginning = round(
78+
(max(0, start_seconds - first_pts)) * self.metadata.sample_rate
79+
)
80+
81+
num_samples = frames.shape[1]
82+
offset_end = num_samples
83+
last_pts = first_pts + num_samples / self.metadata.sample_rate
84+
if stop_seconds is not None and stop_seconds < last_pts:
85+
offset_end -= round((last_pts - stop_seconds) * self.metadata.sample_rate)
86+
87+
return frames[:, offset_beginning:offset_end]
88+
# return frames[:, offset_beginning:offset_end]

0 commit comments

Comments
 (0)