diff --git a/examples/models/llava/main.cpp b/examples/models/llava/main.cpp index 6cb84aa088e..3946a629ade 100644 --- a/examples/models/llava/main.cpp +++ b/examples/models/llava/main.cpp @@ -81,24 +81,20 @@ void load_image(const std::string& image_path, Image& image) { new_height, 0, channels); - // transpose to CHW - image.data.resize(channels * new_width * new_height); + std::vector chw_data(channels * new_width * new_height); for (int i = 0; i < new_width * new_height; ++i) { for (int c = 0; c < channels; ++c) { - image.data[c * new_width * new_height + i] = - resized_data[i * channels + c]; + chw_data[c * new_width * new_height + i] = resized_data[i * channels + c]; } } - image.width = new_width; - image.height = new_height; - image.channels = channels; + image = Image(std::move(chw_data), new_width, new_height, channels); // convert to tensor ET_LOG( Info, "image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32, - image.channels, - image.height, - image.width); + image.channels(), + image.height(), + image.width()); stbi_image_free(data); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 23686f01ee7..cabf30c42e4 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { for (int i = 0; i < image_size; i++) { image_data[i] = image_data_jint[i]; } - llm::Image image_runner{image_data, width, height, channels}; + llm::Image image_runner{std::move(image_data), width, height, channels}; prefill_inputs_.emplace_back( llm::MultimodalInput{std::move(image_runner)}); } diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm index dcc5dc98806..b95e480aded 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm @@ -172,12 +172,12 @@ - (BOOL)generate:(NSArray *)inputs case ExecuTorchLLMMultimodalInputTypeImage: { ExecuTorchLLMImage *image = input.image; std::vector data((uint8_t *)image.data.bytes, (uint8_t *)image.data.bytes + image.data.length); - nativeInputs.emplace_back(llm::MultimodalInput(llm::Image{ - .data = std::move(data), - .width = (int32_t)image.width, - .height = (int32_t)image.height, - .channels = (int32_t)image.channels - })); + nativeInputs.emplace_back(llm::MultimodalInput(llm::Image( + std::move(data), + (int32_t)image.width, + (int32_t)image.height, + (int32_t)image.channels + ))); break; } default: { diff --git a/extension/llm/runner/image.h b/extension/llm/runner/image.h index 67fb8939518..dbdba273536 100644 --- a/extension/llm/runner/image.h +++ b/extension/llm/runner/image.h @@ -10,19 +10,112 @@ #pragma once #include +#include #include +#include #include +#include +#include + namespace executorch { namespace extension { namespace llm { -struct ET_EXPERIMENTAL Image { +class ET_EXPERIMENTAL Image { + public: + // Default constructor + Image() : width_(0), height_(0), channels_(0) {} + + // Constructor for uint8_t data + Image( + std::vector&& data, + int32_t width, + int32_t height, + int32_t channels) + : data_(std::move(data)), + width_(width), + height_(height), + channels_(channels) {} + + // Constructor for float data + Image( + std::vector&& data, + int32_t width, + int32_t height, + int32_t channels) + : data_(std::move(data)), + width_(width), + height_(height), + channels_(channels) {} + + // Getters + int32_t width() const { + return width_; + } + int32_t height() const { + return height_; + } + int32_t channels() const { + return channels_; + } + + // Data access + bool is_uint8() const { + return std::holds_alternative>(data_); + } + + bool is_float() const { + return std::holds_alternative>(data_); + } + + const std::vector& get_uint8_data() const& { + return std::get>(data_); + } + + std::vector& get_uint8_data() & { + return std::get>(data_); + } + + const std::vector& get_float_data() const& { + return std::get>(data_); + } + + std::vector& get_float_data() & { + return std::get>(data_); + } + + executorch::runtime::Result toTensor( + bool with_batch = false) const { + // Note: This creates a 3D tensor (CHW). The model might expect a 4D + // tensor (NCHW). The caller should handle reshaping if needed. + std::vector sizes = { + channels(), height(), width()}; + if (with_batch) { + sizes.insert(sizes.begin(), 1); + } + if (is_float()) { + return executorch::extension::from_blob( + const_cast(get_float_data().data()), + sizes, + ::executorch::aten::ScalarType::Float); + } else if (is_uint8()) { + return executorch::extension::from_blob( + const_cast(get_uint8_data().data()), + sizes, + ::executorch::aten::ScalarType::Byte); + } + ET_LOG( + Error, "Image data is not initialized with uint8_t or float vector."); + return ::executorch::runtime::Error::NotSupported; + } + + private: // Assuming NCHW format - std::vector data; - int32_t width; - int32_t height; - int32_t channels; + std::variant, std::vector> data_; + int32_t width_; + int32_t height_; + int32_t channels_; }; } // namespace llm diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp index 2705a9eadff..3f8777d4acf 100644 --- a/extension/llm/runner/multimodal_prefiller.cpp +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -41,10 +41,42 @@ Result MultimodalPrefiller::prefill( ::executorch::runtime::EValue encoder_output; if (input.is_image()) { Image image = input.get_image(); - auto image_tensor = executorch::extension::from_blob( - image.data.data(), - {3, image.height, image.width}, - ::executorch::aten::ScalarType::Byte); + + auto method_meta = ET_UNWRAP( + module_->method_meta(kImageEncoderMethod), + "Failed to get method_meta for %s", + kImageEncoderMethod); + + ET_CHECK_MSG( + method_meta.num_inputs() > 0, + "Image encoder should have at least 1 input"); + auto input_meta = ET_UNWRAP( + method_meta.input_tensor_meta(0), + "Cannot get input tensor meta at index 0"); + auto expected_dtype = input_meta.scalar_type(); + + if (expected_dtype == ::executorch::aten::ScalarType::Float) { + ET_CHECK_MSG( + image.is_float(), + "Model expects float image data, but image has uint8_t data."); + } else if (expected_dtype == ::executorch::aten::ScalarType::Byte) { + ET_CHECK_MSG( + image.is_uint8(), + "Model expects uint8_t image data, but image has float data."); + } else { + ET_LOG( + Error, + "Unsupported image encoder input dtype: %s", + ::executorch::runtime::toString(expected_dtype)); + return ::executorch::runtime::Error::NotSupported; + } + + // The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D + // tensor (CHW). Add a batch dimension of 1 if needed. + auto expected_dims = input_meta.sizes(); + auto image_tensor = ET_UNWRAP( + image.toTensor(/*with_batch*/ expected_dims.size() == 4), + "Failed to convert image to tensor"); // Run image encoder auto image_encoder_outputs = diff --git a/extension/llm/runner/test/test_multimodal_input.cpp b/extension/llm/runner/test/test_multimodal_input.cpp index 97b9cc1379e..486515175e8 100644 --- a/extension/llm/runner/test/test_multimodal_input.cpp +++ b/extension/llm/runner/test/test_multimodal_input.cpp @@ -16,7 +16,6 @@ using executorch::extension::llm::make_image_input; using executorch::extension::llm::make_text_input; using executorch::extension::llm::MultimodalInput; -namespace { class MultimodalInputTest : public Test { protected: std::string createTestText() { @@ -28,21 +27,13 @@ class MultimodalInputTest : public Test { } Image createTestImage() { - Image img; - img.width = 224; - img.height = 224; - img.channels = 3; - img.data = std::vector(224 * 224 * 3, 128); // Fill with gray - return img; + std::vector data(224 * 224 * 3, 128); // Fill with gray + return Image(std::move(data), 224, 224, 3); } Image createTestImageSmall() { - Image img; - img.width = 32; - img.height = 32; - img.channels = 1; - img.data = std::vector(32 * 32, 255); // Fill with white - return img; + std::vector data(32 * 32, 255); // Fill with white + return Image(std::move(data), 32, 32, 1); } }; @@ -76,28 +67,28 @@ TEST_F(MultimodalInputTest, ImageConstructorFromImage) { EXPECT_FALSE(input.is_text()); EXPECT_TRUE(input.is_image()); EXPECT_EQ(input.get_type(), MultimodalInput::Type::IMAGE); - EXPECT_EQ(input.get_image().width, 224); - EXPECT_EQ(input.get_image().height, 224); - EXPECT_EQ(input.get_image().channels, 3); - EXPECT_EQ(input.get_image().data.size(), 224 * 224 * 3); + EXPECT_EQ(input.get_image().width(), 224); + EXPECT_EQ(input.get_image().height(), 224); + EXPECT_EQ(input.get_image().channels(), 3); + EXPECT_EQ(input.get_image().get_uint8_data().size(), 224 * 224 * 3); } TEST_F(MultimodalInputTest, ImageConstructorFromRvalueImage) { Image img = createTestImage(); - int width = img.width; - int height = img.height; - int channels = img.channels; - size_t data_size = img.data.size(); + int width = img.width(); + int height = img.height(); + int channels = img.channels(); + size_t data_size = img.get_uint8_data().size(); MultimodalInput input(std::move(img)); EXPECT_FALSE(input.is_text()); EXPECT_TRUE(input.is_image()); EXPECT_EQ(input.get_type(), MultimodalInput::Type::IMAGE); - EXPECT_EQ(input.get_image().width, width); - EXPECT_EQ(input.get_image().height, height); - EXPECT_EQ(input.get_image().channels, channels); - EXPECT_EQ(input.get_image().data.size(), data_size); + EXPECT_EQ(input.get_image().width(), width); + EXPECT_EQ(input.get_image().height(), height); + EXPECT_EQ(input.get_image().channels(), channels); + EXPECT_EQ(input.get_image().get_uint8_data().size(), data_size); } // Test copy constructor and assignment @@ -129,10 +120,10 @@ TEST_F(MultimodalInputTest, CopyConstructorImage) { MultimodalInput copy(original); EXPECT_TRUE(copy.is_image()); - EXPECT_EQ(copy.get_image().width, 224); - EXPECT_EQ(copy.get_image().height, 224); - EXPECT_EQ(copy.get_image().channels, 3); - EXPECT_EQ(original.get_image().width, 224); // Original should be unchanged + EXPECT_EQ(copy.get_image().width(), 224); + EXPECT_EQ(copy.get_image().height(), 224); + EXPECT_EQ(copy.get_image().channels(), 3); + EXPECT_EQ(original.get_image().width(), 224); // Original should be unchanged } TEST_F(MultimodalInputTest, CopyAssignmentImage) { @@ -143,10 +134,10 @@ TEST_F(MultimodalInputTest, CopyAssignmentImage) { copy = original; EXPECT_TRUE(copy.is_image()); - EXPECT_EQ(copy.get_image().width, 224); - EXPECT_EQ(copy.get_image().height, 224); - EXPECT_EQ(copy.get_image().channels, 3); - EXPECT_EQ(original.get_image().width, 224); // Original should be unchanged + EXPECT_EQ(copy.get_image().width(), 224); + EXPECT_EQ(copy.get_image().height(), 224); + EXPECT_EQ(copy.get_image().channels(), 3); + EXPECT_EQ(original.get_image().width(), 224); // Original should be unchanged } // Test move constructor and assignment @@ -174,32 +165,32 @@ TEST_F(MultimodalInputTest, MoveAssignmentText) { TEST_F(MultimodalInputTest, MoveConstructorImage) { Image img = createTestImage(); - int width = img.width; - int height = img.height; - int channels = img.channels; + int width = img.width(); + int height = img.height(); + int channels = img.channels(); MultimodalInput original(std::move(img)); MultimodalInput moved(std::move(original)); EXPECT_TRUE(moved.is_image()); - EXPECT_EQ(moved.get_image().width, width); - EXPECT_EQ(moved.get_image().height, height); - EXPECT_EQ(moved.get_image().channels, channels); + EXPECT_EQ(moved.get_image().width(), width); + EXPECT_EQ(moved.get_image().height(), height); + EXPECT_EQ(moved.get_image().channels(), channels); } TEST_F(MultimodalInputTest, MoveAssignmentImage) { Image img = createTestImage(); - int width = img.width; - int height = img.height; - int channels = img.channels; + int width = img.width(); + int height = img.height(); + int channels = img.channels(); MultimodalInput original(std::move(img)); MultimodalInput moved(createTestText()); // Start with different type moved = std::move(original); EXPECT_TRUE(moved.is_image()); - EXPECT_EQ(moved.get_image().width, width); - EXPECT_EQ(moved.get_image().height, height); - EXPECT_EQ(moved.get_image().channels, channels); + EXPECT_EQ(moved.get_image().width(), width); + EXPECT_EQ(moved.get_image().height(), height); + EXPECT_EQ(moved.get_image().channels(), channels); } // Test getter methods with correct types @@ -227,16 +218,13 @@ TEST_F(MultimodalInputTest, GetImageWithImageInput) { // Test const lvalue reference version const MultimodalInput& const_input = input; - EXPECT_EQ(const_input.get_image().width, 224); - - // Test mutable lvalue reference version - Image& mutable_image = input.get_image(); - mutable_image.width = 448; - EXPECT_EQ(input.get_image().width, 448); + EXPECT_EQ(const_input.get_image().width(), 224); + EXPECT_EQ(const_input.get_image().height(), 224); + EXPECT_EQ(const_input.get_image().channels(), 3); // Test rvalue reference version Image moved_image = std::move(input).get_image(); - EXPECT_EQ(moved_image.width, 448); + EXPECT_EQ(moved_image.width(), 224); } // Test getter methods with wrong types (should throw) @@ -296,18 +284,14 @@ TEST_F(MultimodalInputTest, TryGetImageWithImageInput) { const MultimodalInput& const_input = input; const Image* image_ptr = const_input.try_get_image(); ASSERT_NE(image_ptr, nullptr); - EXPECT_EQ(image_ptr->width, 224); - EXPECT_EQ(image_ptr->height, 224); - EXPECT_EQ(image_ptr->channels, 3); + EXPECT_EQ(image_ptr->width(), 224); + EXPECT_EQ(image_ptr->height(), 224); + EXPECT_EQ(image_ptr->channels(), 3); // Test mutable version Image* mutable_image_ptr = input.try_get_image(); ASSERT_NE(mutable_image_ptr, nullptr); - EXPECT_EQ(mutable_image_ptr->width, 224); - - // Modify through pointer - mutable_image_ptr->width = 448; - EXPECT_EQ(input.get_image().width, 448); + EXPECT_EQ(mutable_image_ptr->width(), 224); } TEST_F(MultimodalInputTest, TryGetImageWithTextInput) { @@ -344,22 +328,22 @@ TEST_F(MultimodalInputTest, MakeImageInputFromImage) { MultimodalInput input = make_image_input(img); EXPECT_TRUE(input.is_image()); - EXPECT_EQ(input.get_image().width, 224); - EXPECT_EQ(input.get_image().height, 224); - EXPECT_EQ(input.get_image().channels, 3); + EXPECT_EQ(input.get_image().width(), 224); + EXPECT_EQ(input.get_image().height(), 224); + EXPECT_EQ(input.get_image().channels(), 3); } TEST_F(MultimodalInputTest, MakeImageInputFromRvalueImage) { Image img = createTestImage(); - int width = img.width; - int height = img.height; - int channels = img.channels; + int width = img.width(); + int height = img.height(); + int channels = img.channels(); MultimodalInput input = make_image_input(std::move(img)); EXPECT_TRUE(input.is_image()); - EXPECT_EQ(input.get_image().width, width); - EXPECT_EQ(input.get_image().height, height); - EXPECT_EQ(input.get_image().channels, channels); + EXPECT_EQ(input.get_image().width(), width); + EXPECT_EQ(input.get_image().height(), height); + EXPECT_EQ(input.get_image().channels(), channels); } // Test with different image sizes @@ -368,10 +352,10 @@ TEST_F(MultimodalInputTest, DifferentImageSizes) { MultimodalInput input(small_img); EXPECT_TRUE(input.is_image()); - EXPECT_EQ(input.get_image().width, 32); - EXPECT_EQ(input.get_image().height, 32); - EXPECT_EQ(input.get_image().channels, 1); - EXPECT_EQ(input.get_image().data.size(), 32 * 32); + EXPECT_EQ(input.get_image().width(), 32); + EXPECT_EQ(input.get_image().height(), 32); + EXPECT_EQ(input.get_image().channels(), 1); + EXPECT_EQ(input.get_image().get_uint8_data().size(), 32 * 32); } // Test with empty text @@ -424,11 +408,10 @@ TEST_F(MultimodalInputTest, AssignmentBetweenTypes) { // Assign image to text input input = MultimodalInput(img); EXPECT_TRUE(input.is_image()); - EXPECT_EQ(input.get_image().width, 224); + EXPECT_EQ(input.get_image().width(), 224); // Assign text back to image input input = MultimodalInput(text); EXPECT_TRUE(input.is_text()); EXPECT_EQ(input.get_text(), text); } -} // namespace