Skip to content

Commit d7be54c

Browse files
committed
More changes
1 parent f2c4b43 commit d7be54c

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

examples/models/llava/main.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ int32_t main(int32_t argc, char** argv) {
131131
#endif
132132
// Load tokenizer
133133
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
134-
std::make_unique<tokenizers::Llama2cTokenizer>();
135-
tokenizer->load(tokenizer_path);
134+
::executorch::extension::llm::load_tokenizer(tokenizer_path);
136135
if (tokenizer == nullptr) {
137136
ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path);
138137
return 1;

extension/llm/runner/pybindings.cpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,14 @@ PYBIND11_MODULE(_llm_runner, m) {
277277
}
278278
return py::none();
279279
})
280+
.def(
281+
"get_image",
282+
[](const MultimodalInput& input) -> py::object {
283+
if (input.is_image()) {
284+
return py::cast(input.get_image());
285+
}
286+
return py::none();
287+
})
280288
.def("__repr__", [](const MultimodalInput& input) -> std::string {
281289
if (input.is_text()) {
282290
return "<MultimodalInput type=text content=\"" +
@@ -336,23 +344,27 @@ PYBIND11_MODULE(_llm_runner, m) {
336344
"Image must have 3 (RGB) or 4 (RGBA) channels");
337345
}
338346

339-
if (image_tensor.scalar_type() != torch::kUInt8) {
340-
if (image_tensor.max().item<double>() <= 1.0) {
341-
image_tensor = (image_tensor * 255).to(torch::kUInt8);
342-
} else {
343-
image_tensor = image_tensor.to(torch::kUInt8);
344-
}
345-
}
346-
347347
image_tensor = image_tensor.contiguous();
348-
uint8_t* data = image_tensor.data_ptr<uint8_t>();
349-
std::vector<uint8_t> image_data(data, data + image_tensor.numel());
350-
351-
return MultimodalInput(Image(
352-
std::move(image_data),
353-
static_cast<int32_t>(width),
354-
static_cast<int32_t>(height),
355-
static_cast<int32_t>(channels)));
348+
if (image_tensor.scalar_type() == torch::kUInt8) {
349+
uint8_t* data = image_tensor.data_ptr<uint8_t>();
350+
std::vector<uint8_t> image_data(data, data + image_tensor.numel());
351+
return MultimodalInput(Image(
352+
std::move(image_data),
353+
static_cast<int32_t>(width),
354+
static_cast<int32_t>(height),
355+
static_cast<int32_t>(channels)));
356+
} else if (image_tensor.scalar_type() == torch::kFloat) {
357+
float* data = image_tensor.data_ptr<float>();
358+
std::vector<float> image_data(data, data + image_tensor.numel());
359+
return MultimodalInput(Image(
360+
std::move(image_data),
361+
static_cast<int32_t>(width),
362+
static_cast<int32_t>(height),
363+
static_cast<int32_t>(channels)));
364+
} else {
365+
throw std::runtime_error(
366+
"Unsupported image tensor dtype. Only uint8 and float32 are supported.");
367+
}
356368
},
357369
"Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)",
358370
py::arg("image_tensor"));

0 commit comments

Comments
 (0)