Skip to content

Commit 590fe1c

Browse files
authored
Remove streamIndex parameter from core and ops APIs (#509)
1 parent 9028793 commit 590fe1c

File tree

6 files changed

+39
-100
lines changed

6 files changed

+39
-100
lines changed

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def _get_clips_for_index_based_sampling(
242242
]
243243
frames, *_ = get_frames_at_indices(
244244
video_decoder,
245-
stream_index=metadata_json["bestVideoStreamIndex"],
246245
frame_indices=batch_indexes,
247246
)
248247
clips.append(frames)

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,23 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3939
m.def(
4040
"get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
4141
m.def(
42-
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
42+
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
4343
m.def(
44-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
44+
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4545
m.def(
46-
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
46+
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4747
m.def(
48-
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
48+
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
50-
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51-
m.def(
52-
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
50+
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51+
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
5352
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5453
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5554
m.def(
5655
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
5756
m.def("_get_json_ffmpeg_library_versions() -> str");
5857
m.def(
59-
"_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool");
58+
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
6059
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
6160
}
6261

@@ -245,18 +244,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
245244
return makeOpsFrameOutput(result);
246245
}
247246

248-
OpsFrameOutput get_frame_at_index(
249-
at::Tensor& decoder,
250-
[[maybe_unused]] int64_t stream_index,
251-
int64_t frame_index) {
247+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
252248
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
253249
auto result = videoDecoder->getFrameAtIndex(frame_index);
254250
return makeOpsFrameOutput(result);
255251
}
256252

257253
OpsFrameBatchOutput get_frames_at_indices(
258254
at::Tensor& decoder,
259-
[[maybe_unused]] int64_t stream_index,
260255
at::IntArrayRef frame_indices) {
261256
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
262257
std::vector<int64_t> frameIndicesVec(
@@ -267,7 +262,6 @@ OpsFrameBatchOutput get_frames_at_indices(
267262

268263
OpsFrameBatchOutput get_frames_in_range(
269264
at::Tensor& decoder,
270-
[[maybe_unused]] int64_t stream_index,
271265
int64_t start,
272266
int64_t stop,
273267
std::optional<int64_t> step) {
@@ -278,7 +272,6 @@ OpsFrameBatchOutput get_frames_in_range(
278272

279273
OpsFrameBatchOutput get_frames_by_pts(
280274
at::Tensor& decoder,
281-
[[maybe_unused]] int64_t stream_index,
282275
at::ArrayRef<double> timestamps) {
283276
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
284277
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
@@ -288,7 +281,6 @@ OpsFrameBatchOutput get_frames_by_pts(
288281

289282
OpsFrameBatchOutput get_frames_by_pts_in_range(
290283
at::Tensor& decoder,
291-
[[maybe_unused]] int64_t stream_index,
292284
double start_seconds,
293285
double stop_seconds) {
294286
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
@@ -321,17 +313,14 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
321313

322314
bool _test_frame_pts_equality(
323315
at::Tensor& decoder,
324-
[[maybe_unused]] int64_t stream_index,
325316
int64_t frame_index,
326317
double pts_seconds_to_test) {
327318
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
328319
return pts_seconds_to_test ==
329320
videoDecoder->getPtsSecondsForFrame(frame_index);
330321
}
331322

332-
torch::Tensor _get_key_frame_indices(
333-
at::Tensor& decoder,
334-
[[maybe_unused]] int64_t stream_index) {
323+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
335324
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
336325
return videoDecoder->getKeyFrameIndices();
337326
}

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
8585
// Return the frames at the given timestamps (pts).
8686
OpsFrameBatchOutput get_frames_by_pts(
8787
at::Tensor& decoder,
88-
int64_t stream_index,
8988
at::ArrayRef<double> timestamps);
9089

9190
// Return the frame that is visible at a given index in the video.
92-
OpsFrameOutput get_frame_at_index(
93-
at::Tensor& decoder,
94-
int64_t stream_index,
95-
int64_t frame_index);
91+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index);
9692

9793
// Get the next frame from the video as a tuple that has the frame data, pts and
9894
// duration as tensors.
@@ -101,14 +97,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder);
10197
// Return the frames at the given indices.
10298
OpsFrameBatchOutput get_frames_at_indices(
10399
at::Tensor& decoder,
104-
int64_t stream_index,
105100
at::IntArrayRef frame_indices);
106101

107102
// Return the frames inside a range as a single stacked Tensor. The range is
108103
// defined as [start, stop).
109104
OpsFrameBatchOutput get_frames_in_range(
110105
at::Tensor& decoder,
111-
int64_t stream_index,
112106
int64_t start,
113107
int64_t stop,
114108
std::optional<int64_t> step = std::nullopt);
@@ -118,7 +112,6 @@ OpsFrameBatchOutput get_frames_in_range(
118112
// order.
119113
OpsFrameBatchOutput get_frames_by_pts_in_range(
120114
at::Tensor& decoder,
121-
int64_t stream_index,
122115
double start_seconds,
123116
double stop_seconds);
124117

@@ -128,16 +121,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
128121
// We want to make sure that the value is preserved exactly, bit-for-bit, during
129122
// this process.
130123
//
131-
// Returns true if for the given decoder, in the stream stream_index, the pts
124+
// Returns true if for the given decoder, the pts
132125
// value when converted to seconds as a double is exactly pts_seconds_to_test.
133126
// Returns false otherwise.
134127
bool _test_frame_pts_equality(
135128
at::Tensor& decoder,
136-
int64_t stream_index,
137129
int64_t frame_index,
138130
double pts_seconds_to_test);
139131

140-
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
132+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder);
141133

142134
// Get the metadata from the video as a string.
143135
std::string get_json_metadata(at::Tensor& decoder);

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor:
152152
f"Index {key} is out of bounds; length is {self._num_frames}"
153153
)
154154

155-
frame_data, *_ = core.get_frame_at_index(
156-
self._decoder, frame_index=key, stream_index=self.stream_index
157-
)
155+
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
158156
return frame_data
159157

160158
def _getitem_slice(self, key: slice) -> Tensor:
@@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
163161
start, stop, step = key.indices(len(self))
164162
frame_data, *_ = core.get_frames_in_range(
165163
self._decoder,
166-
stream_index=self.stream_index,
167164
start=start,
168165
stop=stop,
169166
step=step,
@@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
189186
)
190187

191188
def _get_key_frame_indices(self) -> list[int]:
192-
return core._get_key_frame_indices(
193-
self._decoder, stream_index=self.stream_index
194-
)
189+
return core._get_key_frame_indices(self._decoder)
195190

196191
def get_frame_at(self, index: int) -> Frame:
197192
"""Return a single frame at the given index.
@@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame:
208203
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
209204
)
210205
data, pts_seconds, duration_seconds = core.get_frame_at_index(
211-
self._decoder, frame_index=index, stream_index=self.stream_index
206+
self._decoder, frame_index=index
212207
)
213208
return Frame(
214209
data=data,
@@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
234229
"""
235230

236231
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
237-
self._decoder, stream_index=self.stream_index, frame_indices=indices
232+
self._decoder, frame_indices=indices
238233
)
239234
return FrameBatch(
240235
data=data,
@@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
268263
raise IndexError(f"Step ({step}) must be greater than 0.")
269264
frames = core.get_frames_in_range(
270265
self._decoder,
271-
stream_index=self.stream_index,
272266
start=start,
273267
stop=stop,
274268
step=step,
@@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
316310
FrameBatch: The frames that are played at ``seconds``.
317311
"""
318312
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
319-
self._decoder, timestamps=seconds, stream_index=self.stream_index
313+
self._decoder, timestamps=seconds
320314
)
321315
return FrameBatch(
322316
data=data,
@@ -359,7 +353,6 @@ def get_frames_played_in_range(
359353
)
360354
frames = core.get_frames_by_pts_in_range(
361355
self._decoder,
362-
stream_index=self.stream_index,
363356
start_seconds=start_seconds,
364357
stop_seconds=stop_seconds,
365358
)

test/decoders/manual_smoke_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,5 @@
1616
)
1717
torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder)
1818
torchcodec.decoders._core.add_video_stream(decoder, stream_index=3)
19-
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(
20-
decoder, stream_index=3, frame_index=180
21-
)
19+
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(decoder, frame_index=180)
2220
write_png(frame, "frame180.png")

0 commit comments

Comments
 (0)