
Commit e5b2eef

Transforms bridge between Python and C++ (#948)
1 parent b9f77b5 commit e5b2eef


5 files changed: +130 -46 lines


benchmarks/decoders/gpu_benchmark.py

Lines changed: 5 additions & 6 deletions
@@ -29,18 +29,17 @@ def decode_full_video(video_path, decode_device_string, resize_device_string):
     num_threads = None
     if "cuda" in decode_device_string:
         num_threads = 1
-    width = None
-    height = None
+
+    resize_spec = ""
     if "native" in resize_device_string:
-        width = RESIZED_WIDTH
-        height = RESIZED_HEIGHT
+        resize_spec = f"resize, {RESIZED_HEIGHT}, {RESIZED_WIDTH}"
+
     torchcodec._core._add_video_stream(
         decoder,
         stream_index=-1,
         device=decode_device_string,
         num_threads=num_threads,
-        width=width,
-        height=height,
+        transform_specs=resize_spec,
     )
 
     start_time = time.time()
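
The benchmark now communicates the optional native resize as a single spec string instead of separate width/height arguments. A minimal sketch of the two strings this branch can produce (the RESIZED_HEIGHT and RESIZED_WIDTH values below are placeholders for illustration, not the benchmark's actual constants):

    RESIZED_HEIGHT, RESIZED_WIDTH = 270, 480  # placeholder values for illustration

    # No native resize requested: an empty spec string means "apply no transforms".
    resize_spec = ""

    # Native resize requested: note that height comes before width in the spec.
    resize_spec = f"resize, {RESIZED_HEIGHT}, {RESIZED_WIDTH}"
    print(resize_spec)  # -> "resize, 270, 480"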

src/torchcodec/_core/custom_ops.cpp

Lines changed: 71 additions & 25 deletions
@@ -43,9 +43,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
   m.def(
-      "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
+      "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
   m.def(
-      "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
+      "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
   m.def(
       "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
   m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -183,6 +183,69 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
   }
 }
 
+int checkedToPositiveInt(const std::string& str) {
+  int ret = 0;
+  try {
+    ret = std::stoi(str);
+  } catch (const std::invalid_argument&) {
+    TORCH_CHECK(false, "String cannot be converted to an int:" + str);
+  } catch (const std::out_of_range&) {
+    TORCH_CHECK(false, "String would become integer out of range:" + str);
+  }
+  TORCH_CHECK(ret > 0, "String must be a positive integer:" + str);
+  return ret;
+}
+
+// Resize transform specs take the form:
+//
+//   "resize, <height>, <width>"
+//
+// Where "resize" is the string literal and <height> and <width> are positive
+// integers.
+Transform* makeResizeTransform(
+    const std::vector<std::string>& resizeTransformSpec) {
+  TORCH_CHECK(
+      resizeTransformSpec.size() == 3,
+      "resizeTransformSpec must have 3 elements including its name");
+  int height = checkedToPositiveInt(resizeTransformSpec[1]);
+  int width = checkedToPositiveInt(resizeTransformSpec[2]);
+  return new ResizeTransform(FrameDims(height, width));
+}
+
+std::vector<std::string> split(const std::string& str, char delimiter) {
+  std::vector<std::string> tokens;
+  std::string token;
+  std::istringstream tokenStream(str);
+  while (std::getline(tokenStream, token, delimiter)) {
+    tokens.push_back(token);
+  }
+  return tokens;
+}
+
+// The transformSpecsRaw string is always in the format:
+//
+//   "name1, param1, param2, ...; name2, param1, param2, ...; ..."
+//
+// Where "nameX" is the name of the transform, and "paramX" are the parameters.
+std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
+  std::vector<Transform*> transforms;
+  std::vector<std::string> transformSpecs = split(transformSpecsRaw, ';');
+  for (const std::string& transformSpecRaw : transformSpecs) {
+    std::vector<std::string> transformSpec = split(transformSpecRaw, ',');
+    TORCH_CHECK(
+        transformSpec.size() >= 1,
+        "Invalid transform spec: " + transformSpecRaw);
+
+    auto name = transformSpec[0];
+    if (name == "resize") {
+      transforms.push_back(makeResizeTransform(transformSpec));
+    } else {
+      TORCH_CHECK(false, "Invalid transform name: " + name);
+    }
+  }
+  return transforms;
+}
+
 } // namespace
 
 // ==============================
@@ -252,36 +315,18 @@ at::Tensor _create_from_file_like(
 
 void _add_video_stream(
     at::Tensor& decoder,
-    std::optional<int64_t> width = std::nullopt,
-    std::optional<int64_t> height = std::nullopt,
     std::optional<int64_t> num_threads = std::nullopt,
     std::optional<std::string_view> dimension_order = std::nullopt,
     std::optional<int64_t> stream_index = std::nullopt,
     std::string_view device = "cpu",
     std::string_view device_variant = "default",
+    std::string_view transform_specs = "",
    std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
         custom_frame_mappings = std::nullopt,
     std::optional<std::string_view> color_conversion_library = std::nullopt) {
   VideoStreamOptions videoStreamOptions;
   videoStreamOptions.ffmpegThreadCount = num_threads;
 
-  // TODO: Eliminate this temporary bridge code. This exists because we have
-  // not yet exposed the transforms API on the Python side. We also want
-  // to remove the `width` and `height` arguments from the Python API.
-  //
-  // TEMPORARY BRIDGE CODE START
-  TORCH_CHECK(
-      width.has_value() == height.has_value(),
-      "width and height must both be set or unset.");
-  std::vector<Transform*> transforms;
-  if (width.has_value()) {
-    transforms.push_back(
-        new ResizeTransform(FrameDims(height.value(), width.value())));
-    width.reset();
-    height.reset();
-  }
-  // TEMPORARY BRIDGE CODE END
-
   if (dimension_order.has_value()) {
     std::string stdDimensionOrder{dimension_order.value()};
     TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW");
@@ -309,6 +354,9 @@ void _add_video_stream(
   videoStreamOptions.device = torch::Device(std::string(device));
   videoStreamOptions.deviceVariant = device_variant;
 
+  std::vector<Transform*> transforms =
+      makeTransforms(std::string(transform_specs));
+
   std::optional<SingleStreamDecoder::FrameMappings> converted_mappings =
       custom_frame_mappings.has_value()
       ? std::make_optional(makeFrameMappings(custom_frame_mappings.value()))
@@ -324,24 +372,22 @@
 // Add a new video stream at `stream_index` using the provided options.
 void add_video_stream(
     at::Tensor& decoder,
-    std::optional<int64_t> width = std::nullopt,
-    std::optional<int64_t> height = std::nullopt,
     std::optional<int64_t> num_threads = std::nullopt,
     std::optional<std::string_view> dimension_order = std::nullopt,
     std::optional<int64_t> stream_index = std::nullopt,
     std::string_view device = "cpu",
     std::string_view device_variant = "default",
+    std::string_view transform_specs = "",
     const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
         custom_frame_mappings = std::nullopt) {
   _add_video_stream(
       decoder,
-      width,
-      height,
      num_threads,
       dimension_order,
       stream_index,
       device,
       device_variant,
+      transform_specs,
       custom_frame_mappings);
 }
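
The helpers added above define the grammar the Python side must speak: specs are separated by ';', tokens within a spec by ',', the first token is the transform name, and for "resize" the two remaining tokens are a positive height and width. A minimal Python sketch of that grammar, for illustration only (it is not part of torchcodec; the real parsing happens in the C++ above):

    def parse_transform_specs(transform_specs: str) -> list[tuple[str, int, int]]:
        # Illustrative mirror of makeTransforms/makeResizeTransform: split on ';'
        # into individual specs, then on ',' into tokens.
        parsed = []
        if not transform_specs:
            return parsed  # an empty string means "no transforms"
        for spec in transform_specs.split(";"):
            tokens = [token.strip() for token in spec.split(",")]
            if tokens == [""]:
                raise ValueError(f"Invalid transform spec: {spec!r}")
            name = tokens[0]
            if name == "resize":
                if len(tokens) != 3:
                    raise ValueError("resize spec must have 3 elements including its name")
                height, width = int(tokens[1]), int(tokens[2])
                if height <= 0 or width <= 0:
                    raise ValueError("height and width must be positive integers")
                parsed.append((name, height, width))
            else:
                raise ValueError(f"Invalid transform name: {name}")
        return parsed

    print(parse_transform_specs("resize, 270, 480"))  # [('resize', 270, 480)]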

src/torchcodec/_core/ops.py

Lines changed: 2 additions & 4 deletions
@@ -299,13 +299,12 @@ def create_from_tensor_abstract(
 def _add_video_stream_abstract(
     decoder: torch.Tensor,
     *,
-    width: Optional[int] = None,
-    height: Optional[int] = None,
     num_threads: Optional[int] = None,
     dimension_order: Optional[str] = None,
     stream_index: Optional[int] = None,
     device: str = "cpu",
     device_variant: str = "default",
+    transform_specs: str = "",
     custom_frame_mappings: Optional[
         tuple[torch.Tensor, torch.Tensor, torch.Tensor]
     ] = None,
@@ -318,13 +317,12 @@ def _add_video_stream_abstract(
 def add_video_stream_abstract(
     decoder: torch.Tensor,
     *,
-    width: Optional[int] = None,
-    height: Optional[int] = None,
     num_threads: Optional[int] = None,
     dimension_order: Optional[str] = None,
     stream_index: Optional[int] = None,
     device: str = "cpu",
     device_variant: str = "default",
+    transform_specs: str = "",
     custom_frame_mappings: Optional[
         tuple[torch.Tensor, torch.Tensor, torch.Tensor]
     ] = None,
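
With width and height gone from the op signatures, a resize is requested through the transform_specs string. A minimal usage sketch, assuming these ops are importable from torchcodec._core as the benchmark and tests in this commit use them, and with "video.mp4" as a placeholder path:

    from torchcodec._core import add_video_stream, create_from_file, get_next_frame

    decoder = create_from_file("video.mp4")  # placeholder path
    # Spec order is height then width: frames are decoded resized to 270x480.
    add_video_stream(decoder, transform_specs="resize, 270, 480")
    frame, pts, duration = get_next_frame(decoder)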

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 1 addition & 2 deletions
@@ -147,8 +147,7 @@ def forward(self, video_data: Tensor) -> Union[List[Any]]:
         scan_all_streams_to_update_metadata(video_decoder)
         add_video_stream(
             video_decoder,
-            width=target_width,
-            height=target_height,
+            transform_specs=f"resize, {target_height}, {target_width}",
             num_threads=self.decoder_args.num_threads,
         )
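
The sampler builds the spec inline from its existing target_height and target_width values; the only subtlety is the argument order, height before width. A tiny illustrative helper (hypothetical, not part of torchcodec) makes that explicit:

    def resize_spec(height: int, width: int) -> str:
        # Hypothetical helper: the resize spec is "resize, <height>, <width>".
        return f"resize, {height}, {width}"

    assert resize_spec(270, 480) == "resize, 270, 480"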

test/test_ops.py

Lines changed: 51 additions & 9 deletions
@@ -631,17 +631,15 @@ def test_color_conversion_library_with_scaling(
         filtergraph_decoder = create_from_file(str(input_video.path))
         _add_video_stream(
             filtergraph_decoder,
-            width=target_width,
-            height=target_height,
+            transform_specs=f"resize, {target_height}, {target_width}",
             color_conversion_library="filtergraph",
         )
         filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
 
         swscale_decoder = create_from_file(str(input_video.path))
         _add_video_stream(
             swscale_decoder,
-            width=target_width,
-            height=target_height,
+            transform_specs=f"resize, {target_height}, {target_width}",
             color_conversion_library="swscale",
         )
         swscale_frame0, _, _ = get_next_frame(swscale_decoder)
@@ -655,7 +653,53 @@ def test_scaling_on_cuda_fails(self):
             RuntimeError,
             match="Transforms are only supported for CPU devices.",
         ):
-            add_video_stream(decoder, device="cuda", width=100, height=100)
+            add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100")
+
+    def test_transform_fails(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        with pytest.raises(
+            RuntimeError,
+            match="Invalid transform spec",
+        ):
+            add_video_stream(decoder, transform_specs=";")
+
+        with pytest.raises(
+            RuntimeError,
+            match="Invalid transform name",
+        ):
+            add_video_stream(decoder, transform_specs="invalid, 1, 2")
+
+    def test_resize_transform_fails(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        with pytest.raises(
+            RuntimeError,
+            match="must have 3 elements",
+        ):
+            add_video_stream(decoder, transform_specs="resize, 100, 100, 100")
+
+        with pytest.raises(
+            RuntimeError,
+            match="must be a positive integer",
+        ):
+            add_video_stream(decoder, transform_specs="resize, -10, 100")
+
+        with pytest.raises(
+            RuntimeError,
+            match="must be a positive integer",
+        ):
+            add_video_stream(decoder, transform_specs="resize, 100, 0")
+
+        with pytest.raises(
+            RuntimeError,
+            match="cannot be converted to an int",
+        ):
+            add_video_stream(decoder, transform_specs="resize, blah, 100")
+
+        with pytest.raises(
+            RuntimeError,
+            match="out of range",
+        ):
+            add_video_stream(decoder, transform_specs="resize, 100, 1000000000000")
 
     @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
     @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
@@ -763,17 +807,15 @@ def test_color_conversion_library_with_generated_videos(
         filtergraph_decoder = create_from_file(str(video_path))
         _add_video_stream(
             filtergraph_decoder,
-            width=target_width,
-            height=target_height,
+            transform_specs=f"resize, {target_height}, {target_width}",
             color_conversion_library="filtergraph",
         )
         filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
 
         auto_decoder = create_from_file(str(video_path))
         add_video_stream(
             auto_decoder,
-            width=target_width,
-            height=target_height,
+            transform_specs=f"resize, {target_height}, {target_width}",
         )
         auto_frame0, _, _ = get_next_frame(auto_decoder)
         assert_frames_equal(filtergraph_frame0, auto_frame0)
