Skip to content

Commit 274730d

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into audio_support
2 parents 224e18c + 590fe1c commit 274730d

File tree

6 files changed

+39
-100
lines changed

6 files changed

+39
-100
lines changed

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def _get_clips_for_index_based_sampling(
242242
]
243243
frames, *_ = get_frames_at_indices(
244244
video_decoder,
245-
stream_index=metadata_json["bestVideoStreamIndex"],
246245
frame_indices=batch_indexes,
247246
)
248247
clips.append(frames)

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,23 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4141
m.def(
4242
"get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
4343
m.def(
44-
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
44+
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
4545
m.def(
46-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
46+
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4747
m.def(
48-
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
48+
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4949
m.def(
50-
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
50+
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
5151
m.def(
52-
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
53-
m.def(
54-
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
52+
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
53+
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
5554
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5655
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5756
m.def(
5857
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
5958
m.def("_get_json_ffmpeg_library_versions() -> str");
6059
m.def(
61-
"_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool");
60+
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
6261
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
6362
}
6463

@@ -254,18 +253,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
254253
return makeOpsFrameOutput(result);
255254
}
256255

257-
OpsFrameOutput get_frame_at_index(
258-
at::Tensor& decoder,
259-
[[maybe_unused]] int64_t stream_index,
260-
int64_t frame_index) {
256+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
261257
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
262258
auto result = videoDecoder->getFrameAtIndex(frame_index);
263259
return makeOpsFrameOutput(result);
264260
}
265261

266262
OpsFrameBatchOutput get_frames_at_indices(
267263
at::Tensor& decoder,
268-
[[maybe_unused]] int64_t stream_index,
269264
at::IntArrayRef frame_indices) {
270265
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
271266
std::vector<int64_t> frameIndicesVec(
@@ -276,7 +271,6 @@ OpsFrameBatchOutput get_frames_at_indices(
276271

277272
OpsFrameBatchOutput get_frames_in_range(
278273
at::Tensor& decoder,
279-
[[maybe_unused]] int64_t stream_index,
280274
int64_t start,
281275
int64_t stop,
282276
std::optional<int64_t> step) {
@@ -287,7 +281,6 @@ OpsFrameBatchOutput get_frames_in_range(
287281

288282
OpsFrameBatchOutput get_frames_by_pts(
289283
at::Tensor& decoder,
290-
[[maybe_unused]] int64_t stream_index,
291284
at::ArrayRef<double> timestamps) {
292285
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
293286
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
@@ -297,7 +290,6 @@ OpsFrameBatchOutput get_frames_by_pts(
297290

298291
OpsFrameBatchOutput get_frames_by_pts_in_range(
299292
at::Tensor& decoder,
300-
[[maybe_unused]] int64_t stream_index,
301293
double start_seconds,
302294
double stop_seconds) {
303295
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
@@ -330,17 +322,14 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
330322

331323
bool _test_frame_pts_equality(
332324
at::Tensor& decoder,
333-
[[maybe_unused]] int64_t stream_index,
334325
int64_t frame_index,
335326
double pts_seconds_to_test) {
336327
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
337328
return pts_seconds_to_test ==
338329
videoDecoder->getPtsSecondsForFrame(frame_index);
339330
}
340331

341-
torch::Tensor _get_key_frame_indices(
342-
at::Tensor& decoder,
343-
[[maybe_unused]] int64_t stream_index) {
332+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
344333
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
345334
return videoDecoder->getKeyFrameIndices();
346335
}

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
8989
// Return the frames at given ptss for a given stream
9090
OpsFrameBatchOutput get_frames_by_pts(
9191
at::Tensor& decoder,
92-
int64_t stream_index,
9392
at::ArrayRef<double> timestamps);
9493

9594
// Return the frame that is visible at a given index in the video.
96-
OpsFrameOutput get_frame_at_index(
97-
at::Tensor& decoder,
98-
int64_t stream_index,
99-
int64_t frame_index);
95+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index);
10096

10197
// Get the next frame from the video as a tuple that has the frame data, pts and
10298
// duration as tensors.
@@ -105,14 +101,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder);
105101
// Return the frames at given indices for a given stream
106102
OpsFrameBatchOutput get_frames_at_indices(
107103
at::Tensor& decoder,
108-
int64_t stream_index,
109104
at::IntArrayRef frame_indices);
110105

111106
// Return the frames inside a range as a single stacked Tensor. The range is
112107
// defined as [start, stop).
113108
OpsFrameBatchOutput get_frames_in_range(
114109
at::Tensor& decoder,
115-
int64_t stream_index,
116110
int64_t start,
117111
int64_t stop,
118112
std::optional<int64_t> step = std::nullopt);
@@ -122,7 +116,6 @@ OpsFrameBatchOutput get_frames_in_range(
122116
// order.
123117
OpsFrameBatchOutput get_frames_by_pts_in_range(
124118
at::Tensor& decoder,
125-
int64_t stream_index,
126119
double start_seconds,
127120
double stop_seconds);
128121

@@ -132,16 +125,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
132125
// We want to make sure that the value is preserved exactly, bit-for-bit, during
133126
// this process.
134127
//
135-
// Returns true if for the given decoder, in the stream stream_index, the pts
128+
// Returns true if for the given decoder, the pts
136129
// value when converted to seconds as a double is exactly pts_seconds_to_test.
137130
// Returns false otherwise.
138131
bool _test_frame_pts_equality(
139132
at::Tensor& decoder,
140-
int64_t stream_index,
141133
int64_t frame_index,
142134
double pts_seconds_to_test);
143135

144-
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
136+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder);
145137

146138
// Get the metadata from the video as a string.
147139
std::string get_json_metadata(at::Tensor& decoder);

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor:
152152
f"Index {key} is out of bounds; length is {self._num_frames}"
153153
)
154154

155-
frame_data, *_ = core.get_frame_at_index(
156-
self._decoder, frame_index=key, stream_index=self.stream_index
157-
)
155+
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
158156
return frame_data
159157

160158
def _getitem_slice(self, key: slice) -> Tensor:
@@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
163161
start, stop, step = key.indices(len(self))
164162
frame_data, *_ = core.get_frames_in_range(
165163
self._decoder,
166-
stream_index=self.stream_index,
167164
start=start,
168165
stop=stop,
169166
step=step,
@@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
189186
)
190187

191188
def _get_key_frame_indices(self) -> list[int]:
192-
return core._get_key_frame_indices(
193-
self._decoder, stream_index=self.stream_index
194-
)
189+
return core._get_key_frame_indices(self._decoder)
195190

196191
def get_frame_at(self, index: int) -> Frame:
197192
"""Return a single frame at the given index.
@@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame:
208203
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
209204
)
210205
data, pts_seconds, duration_seconds = core.get_frame_at_index(
211-
self._decoder, frame_index=index, stream_index=self.stream_index
206+
self._decoder, frame_index=index
212207
)
213208
return Frame(
214209
data=data,
@@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
234229
"""
235230

236231
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
237-
self._decoder, stream_index=self.stream_index, frame_indices=indices
232+
self._decoder, frame_indices=indices
238233
)
239234
return FrameBatch(
240235
data=data,
@@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
268263
raise IndexError(f"Step ({step}) must be greater than 0.")
269264
frames = core.get_frames_in_range(
270265
self._decoder,
271-
stream_index=self.stream_index,
272266
start=start,
273267
stop=stop,
274268
step=step,
@@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
316310
FrameBatch: The frames that are played at ``seconds``.
317311
"""
318312
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
319-
self._decoder, timestamps=seconds, stream_index=self.stream_index
313+
self._decoder, timestamps=seconds
320314
)
321315
return FrameBatch(
322316
data=data,
@@ -359,7 +353,6 @@ def get_frames_played_in_range(
359353
)
360354
frames = core.get_frames_by_pts_in_range(
361355
self._decoder,
362-
stream_index=self.stream_index,
363356
start_seconds=start_seconds,
364357
stop_seconds=stop_seconds,
365358
)

test/decoders/manual_smoke_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,5 @@
1616
)
1717
torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder)
1818
torchcodec.decoders._core.add_video_stream(decoder, stream_index=3)
19-
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(
20-
decoder, stream_index=3, frame_index=180
21-
)
19+
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(decoder, frame_index=180)
2220
write_png(frame, "frame180.png")

0 commit comments

Comments
 (0)