1010#include < pybind11/numpy.h>
1111#include < pybind11/pybind11.h>
1212#include < pybind11/stl.h>
13+ #include < torch/python.h>
1314
1415#include < executorch/extension/llm/runner/llm_runner_helper.h>
1516#include < executorch/extension/llm/runner/multimodal_input.h>
@@ -271,27 +272,55 @@ PYBIND11_MODULE(_llm_runner, m) {
271272
272273 m.def (
273274 " make_image_input" ,
274- [](py::array_t <uint8_t > image_array) -> MultimodalInput {
275- // Get image dimensions
276- py::buffer_info buf = image_array.request ();
275+ [](torch::Tensor image_tensor) -> MultimodalInput {
276+ if (image_tensor.dim () == 4 ) {
277+ if (image_tensor.size (0 ) != 1 ) {
278+ throw std::runtime_error (
279+ " Batch size for 4D image tensor must be 1" );
280+ }
281+ image_tensor = image_tensor.squeeze (0 );
282+ }
277283
278- if (buf.ndim != 3 ) {
284+
285+ if (image_tensor.dim () != 3 ) {
279286 throw std::runtime_error (
280- " Image array must be 3-dimensional (H, W, C)" );
287+ " Image tensor must be 3-dimensional (H, W, C) or 4-dimensional (1, H, W, C)" );
281288 }
282289
283- size_t height = buf.shape [0 ];
284- size_t width = buf.shape [1 ];
285- size_t channels = buf.shape [2 ];
290+ int64_t height, width, channels;
291+ // Check for memory format and permute to CHW if necessary
292+ if (image_tensor.is_contiguous (at::MemoryFormat::ChannelsLast)) {
293+ // Input is HWC, permute to CHW
294+ height = image_tensor.size (0 );
295+ width = image_tensor.size (1 );
296+ channels = image_tensor.size (2 );
297+ image_tensor = image_tensor.permute ({2 , 0 , 1 });
298+ } else if (image_tensor.is_contiguous (at::MemoryFormat::Contiguous)) {
299+ // Input is CHW
300+ channels = image_tensor.size (0 );
301+ height = image_tensor.size (1 );
302+ width = image_tensor.size (2 );
303+ } else {
304+ throw std::runtime_error (
305+ " Image tensor must be contiguous in either channels last (H, W, C) or contiguous (C, H, W) format." );
306+ }
286307
287308 if (channels != 3 && channels != 4 ) {
288309 throw std::runtime_error (
289310 " Image must have 3 (RGB) or 4 (RGBA) channels" );
290311 }
291312
292- // Create Image object from numpy array
293- uint8_t * data = static_cast <uint8_t *>(buf.ptr );
294- std::vector<uint8_t > image_data (data, data + height * width * channels);
313+ if (image_tensor.scalar_type () != torch::kUInt8 ) {
314+ if (image_tensor.max ().item <double >() <= 1.0 ) {
315+ image_tensor = (image_tensor * 255 ).to (torch::kUInt8 );
316+ } else {
317+ image_tensor = image_tensor.to (torch::kUInt8 );
318+ }
319+ }
320+
321+ image_tensor = image_tensor.contiguous ();
322+ uint8_t * data = image_tensor.data_ptr <uint8_t >();
323+ std::vector<uint8_t > image_data (data, data + image_tensor.numel ());
295324
296325 Image image;
297326 image.data = std::move (image_data);
@@ -300,8 +329,8 @@ PYBIND11_MODULE(_llm_runner, m) {
300329 image.channels = static_cast <int32_t >(channels);
301330 return MultimodalInput (std::move (image));
302331 },
303- " Create an image input from a numpy array (H, W, C)" ,
304- py::arg (" image_array " ));
332+ " Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W )" ,
333+ py::arg (" image_tensor " ));
305334
306335 // Bind PyMultimodalRunner
307336 py::class_<PyMultimodalRunner>(m, " MultimodalRunner" )
0 commit comments