Skip to content

Commit cc4c2e5

Browse files
committed
First commit
1 parent 378c700 commit cc4c2e5

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

examples/models/llava/export_llava.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,20 @@ def __init__(self, llava):
7777
super().__init__()
7878
self.text_model = llava.text_model
7979

80-
def forward(self, input_pos, embeddings):
81-
return self.text_model(None, {"input_pos": input_pos}, embeddings)
80+
def forward(self, cache_positions, embeddings):
81+
return self.text_model(None, {"input_pos": cache_positions[:1]}, embeddings)
8282

8383
llava_text_model = LlavaTextModel(llava)
84-
8584
text_model_em = LLMEdgeManager(
8685
model=llava_text_model,
8786
modelname="llava_text_model",
8887
max_seq_len=llava.text_model_args.max_seq_len,
8988
dtype=DType.fp32,
9089
use_kv_cache=True,
91-
example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
90+
example_inputs=(
91+
torch.tensor(list(range(embeddings.shape[1])), dtype=torch.int64),
92+
embeddings,
93+
),
9294
dynamic_shapes=dynamic_shapes,
9395
)
9496

examples/models/llava/main.cpp

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/examples/models/llava/runner/llava_runner.h>
9+
#include <executorch/extension/llm/runner/image.h>
10+
#include <executorch/extension/llm/runner/multimodal_input.h>
11+
#include <executorch/extension/llm/runner/multimodal_runner.h>
1012
#include <gflags/gflags.h>
13+
#include <pytorch/tokenizers/llama2c_tokenizer.h>
1114
#define STB_IMAGE_IMPLEMENTATION
1215
#include <stb_image.h>
1316
#define STB_IMAGE_RESIZE_IMPLEMENTATION
@@ -44,7 +47,10 @@ DEFINE_int32(
4447
-1,
4548
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
4649

47-
using executorch::extension::llm::Image;
50+
using ::executorch::extension::llm::Image;
51+
using ::executorch::extension::llm::make_image_input;
52+
using ::executorch::extension::llm::make_text_input;
53+
using ::executorch::extension::llm::MultimodalInput;
4854

4955
void load_image(const std::string& image_path, Image& image) {
5056
int width, height, channels;
@@ -127,14 +133,53 @@ int32_t main(int32_t argc, char** argv) {
127133
->_unsafe_reset_threadpool(num_performant_cores);
128134
}
129135
#endif
130-
// create llama runner
131-
example::LlavaRunner runner(model_path, tokenizer_path, temperature);
136+
// Load tokenizer
137+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
138+
std::make_unique<tokenizers::Llama2cTokenizer>();
139+
tokenizer->load(tokenizer_path);
140+
if (tokenizer == nullptr) {
141+
ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path);
142+
return 1;
143+
}
144+
145+
// Create multimodal runner
146+
std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner =
147+
::executorch::extension::llm::create_multimodal_runner(
148+
model_path, std::move(tokenizer));
149+
if (runner == nullptr) {
150+
ET_LOG(Error, "Failed to create multimodal runner");
151+
return 1;
152+
}
132153

154+
// Load runner
155+
auto load_error = runner->load();
156+
if (load_error != ::executorch::runtime::Error::Ok) {
157+
ET_LOG(Error, "Failed to load multimodal runner");
158+
return 1;
159+
}
160+
161+
// Prepare inputs
162+
static const char* kPresetPrompt =
163+
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
133164
Image image;
134165
load_image(image_path, image);
135-
std::vector<Image> images = {image};
166+
std::vector<MultimodalInput> inputs = {
167+
make_text_input(std::string(kPresetPrompt)),
168+
make_image_input(image),
169+
make_text_input(std::string(prompt)),
170+
};
171+
172+
::executorch::extension::llm::GenerationConfig config;
173+
config.temperature = temperature;
174+
175+
// Generate
176+
ET_LOG(Info, "Starting generation...");
177+
auto error = runner->generate(inputs, config);
178+
if (error != ::executorch::runtime::Error::Ok) {
179+
ET_LOG(Error, "Failed to generate with multimodal runner");
180+
return 1;
181+
}
136182

137-
// generate
138-
runner.generate(std::move(images), prompt, seq_len);
183+
printf("\n");
139184
return 0;
140185
}

examples/models/llava/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,5 +405,5 @@ def _get_image_dynamic_shapes(self):
405405

406406
def _get_prompt_dynamic_shapes(self):
407407
dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
408-
text_model_dynamic_shapes = ({0: 1}, {1: dim})
408+
text_model_dynamic_shapes = ({0: dim}, {1: dim})
409409
return text_model_dynamic_shapes

0 commit comments

Comments (0)