|
9 | 9 |
|
10 | 10 | from torch import Tensor |
11 | 11 |
|
| 12 | +from torchcodec import AudioSamples |
12 | 13 | from torchcodec.decoders import _core as core |
13 | 14 | from torchcodec.decoders._decoder_utils import ( |
14 | 15 | create_decoder, |
@@ -37,3 +38,70 @@ def __init__( |
37 | 38 | ) = get_and_validate_stream_metadata( |
38 | 39 | decoder=self._decoder, stream_index=stream_index, media_type="audio" |
39 | 40 | ) |
| 41 | + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy |
| 42 | + |
| 43 | + def get_samples_played_in_range( |
| 44 | + self, start_seconds: float, stop_seconds: Optional[float] = None |
| 45 | + ) -> AudioSamples: |
| 46 | + """TODO-AUDIO docs""" |
| 47 | + if stop_seconds is not None and not start_seconds <= stop_seconds: |
| 48 | + raise ValueError( |
| 49 | + f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." |
| 50 | + ) |
| 51 | + if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: |
| 52 | + raise ValueError( |
| 53 | + f"Invalid start seconds: {start_seconds}. " |
| 54 | + f"It must be greater than or equal to {self._begin_stream_seconds} " |
| 55 | + f"and less than or equal to {self._end_stream_seconds}." |
| 56 | + ) |
| 57 | + frames, first_pts = core.get_frames_by_pts_in_range_audio( |
| 58 | + self._decoder, |
| 59 | + start_seconds=start_seconds, |
| 60 | + stop_seconds=stop_seconds, |
| 61 | + ) |
| 62 | + first_pts = first_pts.item() |
| 63 | + |
| 64 | + # x = frame boundaries |
| 65 | + # |
| 66 | + # first_pts last_pts |
| 67 | + # v v |
| 68 | + # ....x..........x..........x...........x..........x..........x..... |
| 69 | + # ^ ^ |
| 70 | + # start_seconds stop_seconds |
| 71 | + # |
| 72 | + # We want to return the samples in [start_seconds, stop_seconds). But |
| 73 | + # because the core API is based on frames, the `frames` tensor contains |
| 74 | + # the samples in [first_pts, last_pts) |
| 75 | + # So we do some basic math to figure out the position of the view that |
| 76 | + # we'll return. |
| 77 | + |
| 78 | + # TODO: sample_rate is either the original one from metadata, or the |
| 79 | + # user-specified one (NIY) |
| 80 | + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy |
| 81 | + sample_rate = self.metadata.sample_rate |
| 82 | + |
| 83 | + # TODO: metadata's sample_rate should probably not be Optional |
| 84 | + assert sample_rate is not None # mypy. |
| 85 | + |
| 86 | + if first_pts < start_seconds: |
| 87 | + offset_beginning = round((start_seconds - first_pts) * sample_rate) |
| 88 | + output_pts_seconds = start_seconds |
| 89 | + else: |
| 90 | + # In normal cases we'll have first_pts <= start_pts, but in some |
| 91 | + # edge cases it's possible to have first_pts > start_seconds, |
| 92 | + # typically if the stream's first frame's pts isn't exactly 0. |
| 93 | + offset_beginning = 0 |
| 94 | + output_pts_seconds = first_pts |
| 95 | + |
| 96 | + num_samples = frames.shape[1] |
| 97 | + last_pts = first_pts + num_samples / self.metadata.sample_rate |
| 98 | + if stop_seconds is not None and stop_seconds < last_pts: |
| 99 | + offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate) |
| 100 | + else: |
| 101 | + offset_end = num_samples |
| 102 | + |
| 103 | + return AudioSamples( |
| 104 | + data=frames[:, offset_beginning:offset_end], |
| 105 | + pts_seconds=output_pts_seconds, |
| 106 | + sample_rate=sample_rate, |
| 107 | + ) |
0 commit comments