Skip to content

Commit de023c1

Browse files
larryliu0820
authored and pytorchbot committed
[multimodal] Allow float32 image input (#14359)
Letting the `Image` class support both `uint8_t` and `float` data types, and changing the `MultimodalPrefiller` class to support text, image, and audio modalities with error checking and modularity. **Image Data Handling and Type Safety:** * Refactored the `Image` class in `image.h` from a simple struct to a class that uses a `std::variant` to support both `uint8_t` and `float` image data, providing type-safe accessors and a `toTensor` method for conversion to tensors. * Updated `load_image` in Llava `main.cpp` to construct `Image` objects using the new class interface and move semantics, ensuring correct data layout and encapsulation. * Added a runtime check in `LlavaImagePrefiller` to ensure only `uint8_t` images are processed, using the new type-checking methods. **Multimodal Prefill Logic and Flexibility:** * Updated the `MultimodalPrefiller` class in `multimodal_prefiller.h` to dynamically check input types, validate tensor types against model expectations, and handle encoder/decoder execution with improved error handling and modularity. (cherry picked from commit bc18834)
1 parent 573f30d commit de023c1

File tree

6 files changed

+171
-99
lines changed

6 files changed

+171
-99
lines changed

examples/models/llava/main.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,20 @@ void load_image(const std::string& image_path, Image& image) {
7575
new_height,
7676
0,
7777
channels);
78-
// transpose to CHW
79-
image.data.resize(channels * new_width * new_height);
78+
std::vector<uint8_t> chw_data(channels * new_width * new_height);
8079
for (int i = 0; i < new_width * new_height; ++i) {
8180
for (int c = 0; c < channels; ++c) {
82-
image.data[c * new_width * new_height + i] =
83-
resized_data[i * channels + c];
81+
chw_data[c * new_width * new_height + i] = resized_data[i * channels + c];
8482
}
8583
}
86-
image.width = new_width;
87-
image.height = new_height;
88-
image.channels = channels;
84+
image = Image(std::move(chw_data), new_width, new_height, channels);
8985
// convert to tensor
9086
ET_LOG(
9187
Info,
9288
"image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32,
93-
image.channels,
94-
image.height,
95-
image.width);
89+
image.channels(),
90+
image.height(),
91+
image.width());
9692
stbi_image_free(data);
9793
}
9894

extension/android/jni/jni_layer_llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
268268
for (int i = 0; i < image_size; i++) {
269269
image_data[i] = image_data_jint[i];
270270
}
271-
llm::Image image_runner{image_data, width, height, channels};
271+
llm::Image image_runner{std::move(image_data), width, height, channels};
272272
prefill_inputs_.emplace_back(
273273
llm::MultimodalInput{std::move(image_runner)});
274274
}

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ - (BOOL)generate:(NSArray<ExecuTorchLLMMultimodalInput *> *)inputs
172172
case ExecuTorchLLMMultimodalInputTypeImage: {
173173
ExecuTorchLLMImage *image = input.image;
174174
std::vector<uint8_t> data((uint8_t *)image.data.bytes, (uint8_t *)image.data.bytes + image.data.length);
175-
nativeInputs.emplace_back(llm::MultimodalInput(llm::Image{
176-
.data = std::move(data),
177-
.width = (int32_t)image.width,
178-
.height = (int32_t)image.height,
179-
.channels = (int32_t)image.channels
180-
}));
175+
nativeInputs.emplace_back(llm::MultimodalInput(llm::Image(
176+
std::move(data),
177+
(int32_t)image.width,
178+
(int32_t)image.height,
179+
(int32_t)image.channels
180+
)));
181181
break;
182182
}
183183
default: {

extension/llm/runner/image.h

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,112 @@
1010

1111
#pragma once
1212
#include <executorch/runtime/platform/compiler.h>
13+
#include <cstddef>
1314
#include <cstdint>
15+
#include <variant>
1416
#include <vector>
1517

18+
#include <executorch/extension/tensor/tensor.h>
19+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
20+
1621
namespace executorch {
1722
namespace extension {
1823
namespace llm {
1924

20-
struct ET_EXPERIMENTAL Image {
25+
class ET_EXPERIMENTAL Image {
26+
public:
27+
// Default constructor
28+
Image() : width_(0), height_(0), channels_(0) {}
29+
30+
// Constructor for uint8_t data
31+
Image(
32+
std::vector<uint8_t>&& data,
33+
int32_t width,
34+
int32_t height,
35+
int32_t channels)
36+
: data_(std::move(data)),
37+
width_(width),
38+
height_(height),
39+
channels_(channels) {}
40+
41+
// Constructor for float data
42+
Image(
43+
std::vector<float>&& data,
44+
int32_t width,
45+
int32_t height,
46+
int32_t channels)
47+
: data_(std::move(data)),
48+
width_(width),
49+
height_(height),
50+
channels_(channels) {}
51+
52+
// Getters
53+
int32_t width() const {
54+
return width_;
55+
}
56+
int32_t height() const {
57+
return height_;
58+
}
59+
int32_t channels() const {
60+
return channels_;
61+
}
62+
63+
// Data access
64+
bool is_uint8() const {
65+
return std::holds_alternative<std::vector<uint8_t>>(data_);
66+
}
67+
68+
bool is_float() const {
69+
return std::holds_alternative<std::vector<float>>(data_);
70+
}
71+
72+
const std::vector<uint8_t>& get_uint8_data() const& {
73+
return std::get<std::vector<uint8_t>>(data_);
74+
}
75+
76+
std::vector<uint8_t>& get_uint8_data() & {
77+
return std::get<std::vector<uint8_t>>(data_);
78+
}
79+
80+
const std::vector<float>& get_float_data() const& {
81+
return std::get<std::vector<float>>(data_);
82+
}
83+
84+
std::vector<float>& get_float_data() & {
85+
return std::get<std::vector<float>>(data_);
86+
}
87+
88+
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
89+
bool with_batch = false) const {
90+
// Note: This creates a 3D tensor (CHW). The model might expect a 4D
91+
// tensor (NCHW). The caller should handle reshaping if needed.
92+
std::vector<executorch::aten::SizesType> sizes = {
93+
channels(), height(), width()};
94+
if (with_batch) {
95+
sizes.insert(sizes.begin(), 1);
96+
}
97+
if (is_float()) {
98+
return executorch::extension::from_blob(
99+
const_cast<float*>(get_float_data().data()),
100+
sizes,
101+
::executorch::aten::ScalarType::Float);
102+
} else if (is_uint8()) {
103+
return executorch::extension::from_blob(
104+
const_cast<uint8_t*>(get_uint8_data().data()),
105+
sizes,
106+
::executorch::aten::ScalarType::Byte);
107+
}
108+
ET_LOG(
109+
Error, "Image data is not initialized with uint8_t or float vector.");
110+
return ::executorch::runtime::Error::NotSupported;
111+
}
112+
113+
private:
21114
// Assuming NCHW format
22-
std::vector<uint8_t> data;
23-
int32_t width;
24-
int32_t height;
25-
int32_t channels;
115+
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
116+
int32_t width_;
117+
int32_t height_;
118+
int32_t channels_;
26119
};
27120

28121
} // namespace llm

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ Result<uint64_t> MultimodalPrefiller::prefill(
4343
Image image = input.get_image();
4444

4545
auto method_meta = ET_UNWRAP(
46-
module_->method_meta(kVisionEncoderMethod),
46+
module_->method_meta(kImageEncoderMethod),
4747
"Failed to get method_meta for %s",
48-
kVisionEncoderMethod);
48+
kImageEncoderMethod);
4949

5050
ET_CHECK_MSG(
5151
method_meta.num_inputs() > 0,

0 commit comments

Comments (0)