Skip to content

Commit fb2cf2d

Browse files
committed
[multimodal] Allow float32 image input
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 108d29d commit fb2cf2d

File tree

4 files changed

+306
-64
lines changed

4 files changed

+306
-64
lines changed

examples/models/llava/main.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,24 +81,20 @@ void load_image(const std::string& image_path, Image& image) {
8181
new_height,
8282
0,
8383
channels);
84-
// transpose to CHW
85-
image.data.resize(channels * new_width * new_height);
84+
std::vector<uint8_t> chw_data(channels * new_width * new_height);
8685
for (int i = 0; i < new_width * new_height; ++i) {
8786
for (int c = 0; c < channels; ++c) {
88-
image.data[c * new_width * new_height + i] =
89-
resized_data[i * channels + c];
87+
chw_data[c * new_width * new_height + i] = resized_data[i * channels + c];
9088
}
9189
}
92-
image.width = new_width;
93-
image.height = new_height;
94-
image.channels = channels;
90+
image = Image(std::move(chw_data), new_width, new_height, channels);
9591
// convert to tensor
9692
ET_LOG(
9793
Info,
9894
"image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32,
99-
image.channels,
100-
image.height,
101-
image.width);
95+
image.channels(),
96+
image.height(),
97+
image.width());
10298
stbi_image_free(data);
10399
}
104100

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ class ET_EXPERIMENTAL LlavaImagePrefiller {
3333
inline ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
3434
::executorch::extension::llm::Image& image,
3535
int64_t& start_pos) {
36+
ET_CHECK_MSG(
37+
image.is_uint8(), "LlavaImagePrefiller only supports uint8_t images");
3638
auto image_tensor = executorch::extension::from_blob(
37-
image.data.data(),
38-
{3, image.height, image.width},
39+
image.get_uint8_data().data(),
40+
{3, image.height(), image.width()},
3941
::executorch::aten::ScalarType::Byte);
4042
// Run image encoder
4143
auto image_encoder_outputs =

extension/llm/runner/image.h

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,112 @@
1010

1111
#pragma once
1212
#include <executorch/runtime/platform/compiler.h>
13+
#include <cstddef>
1314
#include <cstdint>
15+
#include <variant>
1416
#include <vector>
1517

18+
#include <executorch/extension/tensor/tensor.h>
19+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
20+
1621
namespace executorch {
1722
namespace extension {
1823
namespace llm {
1924

20-
struct ET_EXPERIMENTAL Image {
25+
class ET_EXPERIMENTAL Image {
26+
public:
27+
// Default constructor
28+
Image() : width_(0), height_(0), channels_(0) {}
29+
30+
// Constructor for uint8_t data
31+
Image(
32+
std::vector<uint8_t>&& data,
33+
int32_t width,
34+
int32_t height,
35+
int32_t channels)
36+
: data_(std::move(data)),
37+
width_(width),
38+
height_(height),
39+
channels_(channels) {}
40+
41+
// Constructor for float data
42+
Image(
43+
std::vector<float>&& data,
44+
int32_t width,
45+
int32_t height,
46+
int32_t channels)
47+
: data_(std::move(data)),
48+
width_(width),
49+
height_(height),
50+
channels_(channels) {}
51+
52+
// Getters
53+
int32_t width() const {
54+
return width_;
55+
}
56+
int32_t height() const {
57+
return height_;
58+
}
59+
int32_t channels() const {
60+
return channels_;
61+
}
62+
63+
// Data access
64+
bool is_uint8() const {
65+
return std::holds_alternative<std::vector<uint8_t>>(data_);
66+
}
67+
68+
bool is_float() const {
69+
return std::holds_alternative<std::vector<float>>(data_);
70+
}
71+
72+
const std::vector<uint8_t>& get_uint8_data() const& {
73+
return std::get<std::vector<uint8_t>>(data_);
74+
}
75+
76+
std::vector<uint8_t>& get_uint8_data() & {
77+
return std::get<std::vector<uint8_t>>(data_);
78+
}
79+
80+
const std::vector<float>& get_float_data() const& {
81+
return std::get<std::vector<float>>(data_);
82+
}
83+
84+
std::vector<float>& get_float_data() & {
85+
return std::get<std::vector<float>>(data_);
86+
}
87+
88+
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
89+
bool with_batch = false) const {
90+
// Note: This creates a 3D tensor (CHW). The model might expect a 4D
91+
// tensor (NCHW). The caller should handle reshaping if needed.
92+
std::vector<executorch::aten::SizesType> sizes = {
93+
channels(), height(), width()};
94+
if (with_batch) {
95+
sizes.insert(sizes.begin(), 1);
96+
}
97+
if (is_float()) {
98+
return executorch::extension::from_blob(
99+
const_cast<float*>(get_float_data().data()),
100+
sizes,
101+
::executorch::aten::ScalarType::Float);
102+
} else if (is_uint8()) {
103+
return executorch::extension::from_blob(
104+
const_cast<uint8_t*>(get_uint8_data().data()),
105+
sizes,
106+
::executorch::aten::ScalarType::Byte);
107+
}
108+
ET_LOG(
109+
Error, "Image data is not initialized with uint8_t or float vector.");
110+
return ::executorch::runtime::Error::NotSupported;
111+
}
112+
113+
private:
21114
// Assuming NCHW format
22-
std::vector<uint8_t> data;
23-
int32_t width;
24-
int32_t height;
25-
int32_t channels;
115+
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
116+
int32_t width_;
117+
int32_t height_;
118+
int32_t channels_;
26119
};
27120

28121
} // namespace llm

0 commit comments

Comments
 (0)