Skip to content

Commit 0baae40

Browse files
committed
make_image_input take tensor
1 parent 70b37df commit 0baae40

File tree

3 files changed

+53
-20
lines changed

3 files changed

+53
-20
lines changed

CMakeLists.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,10 +650,6 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
650650
list(APPEND _executorch_extensions tokenizers)
651651
endif()
652652

653-
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
654-
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
655-
list(APPEND _executorch_extensions extension_llm_runner)
656-
endif()
657653

658654
if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
659655
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
@@ -904,6 +900,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
904900
list(APPEND _executorch_extensions extension_training)
905901
endif()
906902

903+
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
904+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
905+
list(APPEND _executorch_extensions extension_llm_runner)
906+
endif()
907+
907908
if(EXECUTORCH_BUILD_KERNELS_LLM)
908909
# TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
909910
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)

extension/llm/runner/CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,13 @@ if(EXECUTORCH_BUILD_PYBIND)
8787
_llm_runner SHARED ${CMAKE_CURRENT_SOURCE_DIR}/pybindings.cpp
8888
)
8989

90+
find_package_torch()
91+
find_library(
92+
TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib"
93+
)
9094
# Link with the extension_llm_runner library and its dependencies
9195
target_link_libraries(
92-
_llm_runner PRIVATE extension_llm_runner executorch_core extension_module
93-
extension_tensor tokenizers::tokenizers
96+
_llm_runner PRIVATE extension_llm_runner tokenizers::tokenizers portable_lib
9497
)
9598

9699
# Set properties for the Python extension
@@ -102,7 +105,7 @@ if(EXECUTORCH_BUILD_PYBIND)
102105
)
103106

104107
# Add include directories
105-
target_include_directories(_llm_runner PRIVATE ${_common_include_directories})
108+
target_include_directories(_llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS})
106109

107110
install(TARGETS _llm_runner
108111
LIBRARY DESTINATION executorch/extension/llm/runner

extension/llm/runner/pybindings.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <pybind11/numpy.h>
1111
#include <pybind11/pybind11.h>
1212
#include <pybind11/stl.h>
13+
#include <torch/python.h>
1314

1415
#include <executorch/extension/llm/runner/llm_runner_helper.h>
1516
#include <executorch/extension/llm/runner/multimodal_input.h>
@@ -271,27 +272,55 @@ PYBIND11_MODULE(_llm_runner, m) {
271272

272273
m.def(
273274
"make_image_input",
274-
[](py::array_t<uint8_t> image_array) -> MultimodalInput {
275-
// Get image dimensions
276-
py::buffer_info buf = image_array.request();
275+
[](torch::Tensor image_tensor) -> MultimodalInput {
276+
if (image_tensor.dim() == 4) {
277+
if (image_tensor.size(0) != 1) {
278+
throw std::runtime_error(
279+
"Batch size for 4D image tensor must be 1");
280+
}
281+
image_tensor = image_tensor.squeeze(0);
282+
}
277283

278-
if (buf.ndim != 3) {
284+
285+
if (image_tensor.dim() != 3) {
279286
throw std::runtime_error(
280-
"Image array must be 3-dimensional (H, W, C)");
287+
"Image tensor must be 3-dimensional (H, W, C) or 4-dimensional (1, H, W, C)");
281288
}
282289

283-
size_t height = buf.shape[0];
284-
size_t width = buf.shape[1];
285-
size_t channels = buf.shape[2];
290+
int64_t height, width, channels;
291+
// Check for memory format and permute to CHW if necessary
292+
if (image_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
293+
// Input is HWC, permute to CHW
294+
height = image_tensor.size(0);
295+
width = image_tensor.size(1);
296+
channels = image_tensor.size(2);
297+
image_tensor = image_tensor.permute({2, 0, 1});
298+
} else if (image_tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
299+
// Input is CHW
300+
channels = image_tensor.size(0);
301+
height = image_tensor.size(1);
302+
width = image_tensor.size(2);
303+
} else {
304+
throw std::runtime_error(
305+
"Image tensor must be contiguous in either channels last (H, W, C) or contiguous (C, H, W) format.");
306+
}
286307

287308
if (channels != 3 && channels != 4) {
288309
throw std::runtime_error(
289310
"Image must have 3 (RGB) or 4 (RGBA) channels");
290311
}
291312

292-
// Create Image object from numpy array
293-
uint8_t* data = static_cast<uint8_t*>(buf.ptr);
294-
std::vector<uint8_t> image_data(data, data + height * width * channels);
313+
if (image_tensor.scalar_type() != torch::kUInt8) {
314+
if (image_tensor.max().item<double>() <= 1.0) {
315+
image_tensor = (image_tensor * 255).to(torch::kUInt8);
316+
} else {
317+
image_tensor = image_tensor.to(torch::kUInt8);
318+
}
319+
}
320+
321+
image_tensor = image_tensor.contiguous();
322+
uint8_t* data = image_tensor.data_ptr<uint8_t>();
323+
std::vector<uint8_t> image_data(data, data + image_tensor.numel());
295324

296325
Image image;
297326
image.data = std::move(image_data);
@@ -300,8 +329,8 @@ PYBIND11_MODULE(_llm_runner, m) {
300329
image.channels = static_cast<int32_t>(channels);
301330
return MultimodalInput(std::move(image));
302331
},
303-
"Create an image input from a numpy array (H, W, C)",
304-
py::arg("image_array"));
332+
"Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)",
333+
py::arg("image_tensor"));
305334

306335
// Bind PyMultimodalRunner
307336
py::class_<PyMultimodalRunner>(m, "MultimodalRunner")

0 commit comments

Comments
 (0)