@@ -277,6 +277,14 @@ PYBIND11_MODULE(_llm_runner, m) {
             }
             return py::none();
           })
+      .def(
+          "get_image",
+          [](const MultimodalInput& input) -> py::object {
+            if (input.is_image()) {
+              return py::cast(input.get_image());
+            }
+            return py::none();
+          })
       .def("__repr__", [](const MultimodalInput& input) -> std::string {
         if (input.is_text()) {
           return "<MultimodalInput type=text content=\"" +
@@ -336,23 +344,27 @@ PYBIND11_MODULE(_llm_runner, m) {
               "Image must have 3 (RGB) or 4 (RGBA) channels");
         }
 
-        if (image_tensor.scalar_type() != torch::kUInt8) {
-          if (image_tensor.max().item<double>() <= 1.0) {
-            image_tensor = (image_tensor * 255).to(torch::kUInt8);
-          } else {
-            image_tensor = image_tensor.to(torch::kUInt8);
-          }
-        }
-
         image_tensor = image_tensor.contiguous();
-        uint8_t* data = image_tensor.data_ptr<uint8_t>();
-        std::vector<uint8_t> image_data(data, data + image_tensor.numel());
-
-        return MultimodalInput(Image(
-            std::move(image_data),
-            static_cast<int32_t>(width),
-            static_cast<int32_t>(height),
-            static_cast<int32_t>(channels)));
+        if (image_tensor.scalar_type() == torch::kUInt8) {
+          uint8_t* data = image_tensor.data_ptr<uint8_t>();
+          std::vector<uint8_t> image_data(data, data + image_tensor.numel());
+          return MultimodalInput(Image(
+              std::move(image_data),
+              static_cast<int32_t>(width),
+              static_cast<int32_t>(height),
+              static_cast<int32_t>(channels)));
+        } else if (image_tensor.scalar_type() == torch::kFloat) {
+          float* data = image_tensor.data_ptr<float>();
+          std::vector<float> image_data(data, data + image_tensor.numel());
+          return MultimodalInput(Image(
+              std::move(image_data),
+              static_cast<int32_t>(width),
+              static_cast<int32_t>(height),
+              static_cast<int32_t>(channels)));
+        } else {
+          throw std::runtime_error(
+              "Unsupported image tensor dtype. Only uint8 and float32 are supported.");
+        }
       },
       "Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)",
       py::arg("image_tensor"));