Skip to content

Commit ffd07d4

Browse files
committed
Remove streamIndex from core and ops APIs
1 parent 0f50aba commit ffd07d4

File tree

4 files changed

+38
-96
lines changed

4 files changed

+38
-96
lines changed

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,23 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3939
m.def(
4040
"get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
4141
m.def(
42-
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
42+
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
4343
m.def(
44-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
44+
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4545
m.def(
46-
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
46+
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4747
m.def(
48-
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
48+
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
50-
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51-
m.def(
52-
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
50+
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51+
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
5352
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5453
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5554
m.def(
5655
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
5756
m.def("_get_json_ffmpeg_library_versions() -> str");
5857
m.def(
59-
"_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool");
58+
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
6059
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
6160
}
6261

@@ -251,18 +250,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
251250
return makeOpsFrameOutput(result);
252251
}
253252

254-
OpsFrameOutput get_frame_at_index(
255-
at::Tensor& decoder,
256-
[[maybe_unused]] int64_t stream_index,
257-
int64_t frame_index) {
253+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
258254
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
259255
auto result = videoDecoder->getFrameAtIndex(frame_index);
260256
return makeOpsFrameOutput(result);
261257
}
262258

263259
OpsFrameBatchOutput get_frames_at_indices(
264260
at::Tensor& decoder,
265-
[[maybe_unused]] int64_t stream_index,
266261
at::IntArrayRef frame_indices) {
267262
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
268263
std::vector<int64_t> frameIndicesVec(
@@ -273,7 +268,6 @@ OpsFrameBatchOutput get_frames_at_indices(
273268

274269
OpsFrameBatchOutput get_frames_in_range(
275270
at::Tensor& decoder,
276-
[[maybe_unused]] int64_t stream_index,
277271
int64_t start,
278272
int64_t stop,
279273
std::optional<int64_t> step) {
@@ -284,7 +278,6 @@ OpsFrameBatchOutput get_frames_in_range(
284278

285279
OpsFrameBatchOutput get_frames_by_pts(
286280
at::Tensor& decoder,
287-
[[maybe_unused]] int64_t stream_index,
288281
at::ArrayRef<double> timestamps) {
289282
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
290283
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
@@ -294,7 +287,6 @@ OpsFrameBatchOutput get_frames_by_pts(
294287

295288
OpsFrameBatchOutput get_frames_by_pts_in_range(
296289
at::Tensor& decoder,
297-
[[maybe_unused]] int64_t stream_index,
298290
double start_seconds,
299291
double stop_seconds) {
300292
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
@@ -327,17 +319,14 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
327319

328320
bool _test_frame_pts_equality(
329321
at::Tensor& decoder,
330-
[[maybe_unused]] int64_t stream_index,
331322
int64_t frame_index,
332323
double pts_seconds_to_test) {
333324
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
334325
return pts_seconds_to_test ==
335326
videoDecoder->getPtsSecondsForFrame(frame_index);
336327
}
337328

338-
torch::Tensor _get_key_frame_indices(
339-
at::Tensor& decoder,
340-
[[maybe_unused]] int64_t stream_index) {
329+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
341330
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
342331
return videoDecoder->getKeyFrameIndices();
343332
}

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
8585
// Return the frames at given ptss for a given stream
8686
OpsFrameBatchOutput get_frames_by_pts(
8787
at::Tensor& decoder,
88-
int64_t stream_index,
8988
at::ArrayRef<double> timestamps);
9089

9190
// Return the frame that is visible at a given index in the video.
92-
OpsFrameOutput get_frame_at_index(
93-
at::Tensor& decoder,
94-
int64_t stream_index,
95-
int64_t frame_index);
91+
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index);
9692

9793
// Get the next frame from the video as a tuple that has the frame data, pts and
9894
// duration as tensors.
@@ -101,14 +97,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder);
10197
// Return the frames at given indices for a given stream
10298
OpsFrameBatchOutput get_frames_at_indices(
10399
at::Tensor& decoder,
104-
int64_t stream_index,
105100
at::IntArrayRef frame_indices);
106101

107102
// Return the frames inside a range as a single stacked Tensor. The range is
108103
// defined as [start, stop).
109104
OpsFrameBatchOutput get_frames_in_range(
110105
at::Tensor& decoder,
111-
int64_t stream_index,
112106
int64_t start,
113107
int64_t stop,
114108
std::optional<int64_t> step = std::nullopt);
@@ -118,7 +112,6 @@ OpsFrameBatchOutput get_frames_in_range(
118112
// order.
119113
OpsFrameBatchOutput get_frames_by_pts_in_range(
120114
at::Tensor& decoder,
121-
int64_t stream_index,
122115
double start_seconds,
123116
double stop_seconds);
124117

@@ -128,16 +121,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
128121
// We want to make sure that the value is preserved exactly, bit-for-bit, during
129122
// this process.
130123
//
131-
// Returns true if for the given decoder, in the stream stream_index, the pts
124+
// Returns true if for the given decoder, the pts
132125
// value when converted to seconds as a double is exactly pts_seconds_to_test.
133126
// Returns false otherwise.
134127
bool _test_frame_pts_equality(
135128
at::Tensor& decoder,
136-
int64_t stream_index,
137129
int64_t frame_index,
138130
double pts_seconds_to_test);
139131

140-
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
132+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder);
141133

142134
// Get the metadata from the video as a string.
143135
std::string get_json_metadata(at::Tensor& decoder);

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor:
152152
f"Index {key} is out of bounds; length is {self._num_frames}"
153153
)
154154

155-
frame_data, *_ = core.get_frame_at_index(
156-
self._decoder, frame_index=key, stream_index=self.stream_index
157-
)
155+
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
158156
return frame_data
159157

160158
def _getitem_slice(self, key: slice) -> Tensor:
@@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
163161
start, stop, step = key.indices(len(self))
164162
frame_data, *_ = core.get_frames_in_range(
165163
self._decoder,
166-
stream_index=self.stream_index,
167164
start=start,
168165
stop=stop,
169166
step=step,
@@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
189186
)
190187

191188
def _get_key_frame_indices(self) -> list[int]:
192-
return core._get_key_frame_indices(
193-
self._decoder, stream_index=self.stream_index
194-
)
189+
return core._get_key_frame_indices(self._decoder)
195190

196191
def get_frame_at(self, index: int) -> Frame:
197192
"""Return a single frame at the given index.
@@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame:
208203
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
209204
)
210205
data, pts_seconds, duration_seconds = core.get_frame_at_index(
211-
self._decoder, frame_index=index, stream_index=self.stream_index
206+
self._decoder, frame_index=index
212207
)
213208
return Frame(
214209
data=data,
@@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
234229
"""
235230

236231
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
237-
self._decoder, stream_index=self.stream_index, frame_indices=indices
232+
self._decoder, frame_indices=indices
238233
)
239234
return FrameBatch(
240235
data=data,
@@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
268263
raise IndexError(f"Step ({step}) must be greater than 0.")
269264
frames = core.get_frames_in_range(
270265
self._decoder,
271-
stream_index=self.stream_index,
272266
start=start,
273267
stop=stop,
274268
step=step,
@@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
316310
FrameBatch: The frames that are played at ``seconds``.
317311
"""
318312
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
319-
self._decoder, timestamps=seconds, stream_index=self.stream_index
313+
self._decoder, timestamps=seconds
320314
)
321315
return FrameBatch(
322316
data=data,
@@ -359,7 +353,6 @@ def get_frames_played_in_range(
359353
)
360354
frames = core.get_frames_by_pts_in_range(
361355
self._decoder,
362-
stream_index=self.stream_index,
363356
start_seconds=start_seconds,
364357
stop_seconds=stop_seconds,
365358
)

0 commit comments

Comments
 (0)