Skip to content

Commit c381116

Browse files
committed
More changes
1 parent 071a7b3 commit c381116

File tree

5 files changed

+49
-22
lines changed

5 files changed

+49
-22
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,6 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
650650
list(APPEND _executorch_extensions tokenizers)
651651
endif()
652652

653-
654653
if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
655654
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
656655
endif()

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,12 @@ def export_all(llava_model: LlavaModel):
224224

225225
lowered_and_edge = to_edge_transform_and_lower(
226226
{
227-
"image_encoder": image_encoder_ep,
227+
"vision_encoder": image_encoder_ep,
228228
"token_embedding": token_embedding_ep,
229229
"text_decoder": text_model_ep,
230230
},
231231
partitioner={
232-
"image_encoder": [XnnpackPartitioner()],
232+
"vision_encoder": [XnnpackPartitioner()],
233233
"text_decoder": [
234234
# First partition the DQLinear nodes, then partition the rest of the nodes,
235235
# to avoid multiple DQLinear nodes in the same partition,
@@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
254254
],
255255
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
256256
sym_shape_eval_pass={
257-
"image_encoder": ConstraintBasedSymShapeEvalPass(),
257+
"vision_encoder": ConstraintBasedSymShapeEvalPass(),
258258
"text_decoder": ConstraintBasedSymShapeEvalPass(),
259259
"token_embedding": HintBasedSymShapeEvalPass(),
260260
},

extension/llm/runner/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ if(EXECUTORCH_BUILD_PYBIND)
9393
)
9494
# Link with the extension_llm_runner library and its dependencies
9595
target_link_libraries(
96-
_llm_runner PRIVATE extension_llm_runner tokenizers::tokenizers portable_lib
96+
_llm_runner PRIVATE extension_llm_runner tokenizers::tokenizers
97+
portable_lib
9798
)
9899

99100
# Set properties for the Python extension
@@ -105,7 +106,9 @@ if(EXECUTORCH_BUILD_PYBIND)
105106
)
106107

107108
# Add include directories
108-
target_include_directories(_llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS})
109+
target_include_directories(
110+
_llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS}
111+
)
109112

110113
install(TARGETS _llm_runner
111114
LIBRARY DESTINATION executorch/extension/llm/runner

extension/llm/runner/constants.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ inline constexpr auto kUseKVCache = "use_kv_cache";
2020
inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
2121

2222
// Multimodal method name conventions
23-
inline constexpr auto kImageEncoderMethod = "image_encoder";
23+
inline constexpr auto kImageEncoderMethod = "vision_encoder";
2424
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
2525
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
2626
inline constexpr auto kTextModelMethod = "text_decoder";

extension/llm/runner/pybindings.cpp

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,42 @@ PYBIND11_MODULE(_llm_runner, m) {
219219

220220
// Bind Image class
221221
py::class_<Image>(m, "Image")
222-
.def(py::init<>())
223-
.def_readwrite("data", &Image::data)
224-
.def_readwrite("width", &Image::width)
225-
.def_readwrite("height", &Image::height)
226-
.def_readwrite("channels", &Image::channels)
222+
.def(
223+
py::init<std::vector<uint8_t>&&, int32_t, int32_t, int32_t>(),
224+
py::arg("data"),
225+
py::arg("width"),
226+
py::arg("height"),
227+
py::arg("channels"))
228+
.def(
229+
py::init<std::vector<float>&&, int32_t, int32_t, int32_t>(),
230+
py::arg("data"),
231+
py::arg("width"),
232+
py::arg("height"),
233+
py::arg("channels"))
234+
.def("is_uint8", &Image::is_uint8)
235+
.def("is_float", &Image::is_float)
236+
.def_property_readonly("width", &Image::width)
237+
.def_property_readonly("height", &Image::height)
238+
.def_property_readonly("channels", &Image::channels)
239+
.def_property_readonly(
240+
"uint8_data",
241+
static_cast<const std::vector<uint8_t>& (Image::*)() const&>(
242+
&Image::get_uint8_data))
243+
.def_property_readonly(
244+
"float_data",
245+
static_cast<const std::vector<float>& (Image::*)() const&>(
246+
&Image::get_float_data))
227247
.def("__repr__", [](const Image& img) {
228-
return "<Image height=" + std::to_string(img.height) +
229-
" width=" + std::to_string(img.width) +
230-
" channels=" + std::to_string(img.channels) + ">";
248+
std::string dtype = "unknown";
249+
if (img.is_uint8()) {
250+
dtype = "uint8";
251+
} else if (img.is_float()) {
252+
dtype = "float32";
253+
}
254+
return "<Image height=" + std::to_string(img.height()) +
255+
" width=" + std::to_string(img.width()) +
256+
" channels=" + std::to_string(img.channels()) + " dtype=" + dtype +
257+
">";
231258
});
232259

233260
// Bind MultimodalInput
@@ -281,7 +308,6 @@ PYBIND11_MODULE(_llm_runner, m) {
281308
image_tensor = image_tensor.squeeze(0);
282309
}
283310

284-
285311
if (image_tensor.dim() != 3) {
286312
throw std::runtime_error(
287313
"Image tensor must be 3-dimensional (H, W, C) or 4-dimensional (1, H, W, C)");
@@ -322,12 +348,11 @@ PYBIND11_MODULE(_llm_runner, m) {
322348
uint8_t* data = image_tensor.data_ptr<uint8_t>();
323349
std::vector<uint8_t> image_data(data, data + image_tensor.numel());
324350

325-
Image image;
326-
image.data = std::move(image_data);
327-
image.width = static_cast<int32_t>(width);
328-
image.height = static_cast<int32_t>(height);
329-
image.channels = static_cast<int32_t>(channels);
330-
return MultimodalInput(std::move(image));
351+
return MultimodalInput(Image(
352+
std::move(image_data),
353+
static_cast<int32_t>(width),
354+
static_cast<int32_t>(height),
355+
static_cast<int32_t>(channels)));
331356
},
332357
"Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)",
333358
py::arg("image_tensor"));

0 commit comments

Comments
 (0)