@@ -43,9 +43,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
  m.def(
      "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
  m.def(
-       "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
+       "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
  m.def(
-       "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
+       "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
  m.def(
      "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
  m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -183,6 +183,69 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
  }
}

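+ // Parses `str` as an int via std::stoi and checks that the result is strictly
+ // positive; any parse failure or non-positive value raises a TORCH_CHECK error.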
+ int checkedToPositiveInt(const std::string& str) {
+   int ret = 0;
+   try {
+     ret = std::stoi(str);
+   } catch (const std::invalid_argument&) {
+     TORCH_CHECK(false, "String cannot be converted to an int:" + str);
+   } catch (const std::out_of_range&) {
+     TORCH_CHECK(false, "String would become integer out of range:" + str);
+   }
+   TORCH_CHECK(ret > 0, "String must be a positive integer:" + str);
+   return ret;
+ }
+
+ // Resize transform specs take the form:
+ //
+ // "resize, <height>, <width>"
+ //
+ // Where "resize" is the string literal and <height> and <width> are positive
+ // integers.
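+ // For example, "resize, 270, 480" produces a ResizeTransform with height 270
+ // and width 480.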
+ Transform* makeResizeTransform(
+     const std::vector<std::string>& resizeTransformSpec) {
+   TORCH_CHECK(
+       resizeTransformSpec.size() == 3,
+       "resizeTransformSpec must have 3 elements including its name");
+   int height = checkedToPositiveInt(resizeTransformSpec[1]);
+   int width = checkedToPositiveInt(resizeTransformSpec[2]);
+   return new ResizeTransform(FrameDims(height, width));
+ }
+
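+ // Splits str on the given delimiter. Tokens are returned as-is; no whitespace
+ // trimming is performed.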
+ std::vector<std::string> split(const std::string& str, char delimiter) {
+   std::vector<std::string> tokens;
+   std::string token;
+   std::istringstream tokenStream(str);
+   while (std::getline(tokenStream, token, delimiter)) {
+     tokens.push_back(token);
+   }
+   return tokens;
+ }
+
+ // The transformSpecsRaw string is always in the format:
+ //
+ // "name1, param1, param2, ...; name2, param1, param2, ...; ..."
+ //
+ // Where "nameX" is the name of the transform, and "paramX" are the parameters.
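+ // For example, "resize, 270, 480" parses to a single ResizeTransform;
+ // "resize" is currently the only recognized transform name.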
+ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
+   std::vector<Transform*> transforms;
+   std::vector<std::string> transformSpecs = split(transformSpecsRaw, ';');
+   for (const std::string& transformSpecRaw : transformSpecs) {
+     std::vector<std::string> transformSpec = split(transformSpecRaw, ',');
+     TORCH_CHECK(
+         transformSpec.size() >= 1,
+         "Invalid transform spec: " + transformSpecRaw);
+
+     auto name = transformSpec[0];
+     if (name == "resize") {
+       transforms.push_back(makeResizeTransform(transformSpec));
+     } else {
+       TORCH_CHECK(false, "Invalid transform name: " + name);
+     }
+   }
+   return transforms;
+ }
+
} // namespace

// ==============================
@@ -252,36 +315,18 @@ at::Tensor _create_from_file_like(

void _add_video_stream(
    at::Tensor& decoder,
-     std::optional<int64_t> width = std::nullopt,
-     std::optional<int64_t> height = std::nullopt,
    std::optional<int64_t> num_threads = std::nullopt,
    std::optional<std::string_view> dimension_order = std::nullopt,
    std::optional<int64_t> stream_index = std::nullopt,
    std::string_view device = "cpu",
    std::string_view device_variant = "default",
+     std::string_view transform_specs = "",
    std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
        custom_frame_mappings = std::nullopt,
    std::optional<std::string_view> color_conversion_library = std::nullopt) {
  VideoStreamOptions videoStreamOptions;
  videoStreamOptions.ffmpegThreadCount = num_threads;

-   // TODO: Eliminate this temporary bridge code. This exists because we have
-   // not yet exposed the transforms API on the Python side. We also want
-   // to remove the `width` and `height` arguments from the Python API.
-   //
-   // TEMPORARY BRIDGE CODE START
-   TORCH_CHECK(
-       width.has_value() == height.has_value(),
-       "width and height must both be set or unset.");
-   std::vector<Transform*> transforms;
-   if (width.has_value()) {
-     transforms.push_back(
-         new ResizeTransform(FrameDims(height.value(), width.value())));
-     width.reset();
-     height.reset();
-   }
-   // TEMPORARY BRIDGE CODE END
-
  if (dimension_order.has_value()) {
    std::string stdDimensionOrder{dimension_order.value()};
    TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW");
@@ -309,6 +354,9 @@ void _add_video_stream(
  videoStreamOptions.device = torch::Device(std::string(device));
  videoStreamOptions.deviceVariant = device_variant;

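+   // Parse the serialized transform_specs string into Transform objects.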
+   std::vector<Transform*> transforms =
+       makeTransforms(std::string(transform_specs));
+
  std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
      custom_frame_mappings.has_value()
      ? std::make_optional(makeFrameMappings(custom_frame_mappings.value()))
@@ -324,24 +372,22 @@ void _add_video_stream(
// Add a new video stream at `stream_index` using the provided options.
void add_video_stream(
    at::Tensor& decoder,
-     std::optional<int64_t> width = std::nullopt,
-     std::optional<int64_t> height = std::nullopt,
    std::optional<int64_t> num_threads = std::nullopt,
    std::optional<std::string_view> dimension_order = std::nullopt,
    std::optional<int64_t> stream_index = std::nullopt,
    std::string_view device = "cpu",
    std::string_view device_variant = "default",
+     std::string_view transform_specs = "",
    const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
        custom_frame_mappings = std::nullopt) {
  _add_video_stream(
      decoder,
-       width,
-       height,
      num_threads,
      dimension_order,
      stream_index,
      device,
      device_variant,
+       transform_specs,
      custom_frame_mappings);
}