1616#include < vector>
1717
1818#include < executorch/examples/models/llama2/runner/runner.h>
19+ #include < executorch/examples/models/llava/runner/llava_runner.h>
20+ #include < executorch/extension/llm/runner/image.h>
1921#include < executorch/runtime/platform/log.h>
2022#include < executorch/runtime/platform/platform.h>
2123#include < executorch/runtime/platform/runtime.h>
@@ -90,21 +92,29 @@ class ExecuTorchLlamaJni
9092 : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
9193 private:
9294 friend HybridBase;
95+ int model_type_category_;
9396 std::unique_ptr<Runner> runner_;
97+ std::unique_ptr<MultimodalRunner> multi_modal_runner_;
9498
9599 public:
96100 constexpr static auto kJavaDescriptor =
97101 " Lorg/pytorch/executorch/LlamaModule;" ;
98102
103+ constexpr static int MODEL_TYPE_CATEGORY_LLM = 1 ;
104+ constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2 ;
105+
99106 static facebook::jni::local_ref<jhybriddata> initHybrid (
100107 facebook::jni::alias_ref<jclass>,
108+ jint model_type_category,
101109 facebook::jni::alias_ref<jstring> model_path,
102110 facebook::jni::alias_ref<jstring> tokenizer_path,
103111 jfloat temperature) {
104- return makeCxxInstance (model_path, tokenizer_path, temperature);
112+ return makeCxxInstance (
113+ model_type_category, model_path, tokenizer_path, temperature);
105114 }
106115
107116 ExecuTorchLlamaJni (
117+ jint model_type_category,
108118 facebook::jni::alias_ref<jstring> model_path,
109119 facebook::jni::alias_ref<jstring> tokenizer_path,
110120 jfloat temperature) {
@@ -119,30 +129,72 @@ class ExecuTorchLlamaJni
119129 }
120130#endif
121131
122- runner_ = std::make_unique<Runner>(
123- model_path->toStdString ().c_str (),
124- tokenizer_path->toStdString ().c_str (),
125- temperature);
132+ model_type_category_ = model_type_category;
133+ if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
134+ multi_modal_runner_ = std::make_unique<LlavaRunner>(
135+ model_path->toStdString ().c_str (),
136+ tokenizer_path->toStdString ().c_str (),
137+ temperature);
138+ } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
139+ runner_ = std::make_unique<Runner>(
140+ model_path->toStdString ().c_str (),
141+ tokenizer_path->toStdString ().c_str (),
142+ temperature);
143+ }
126144 }
127145
128146 jint generate (
147+ facebook::jni::alias_ref<jintArray> image,
148+ jint width,
149+ jint height,
150+ jint channels,
129151 facebook::jni::alias_ref<jstring> prompt,
130152 jint seq_len,
131153 facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
132- runner_->generate (
133- prompt->toStdString (),
134- seq_len,
135- [callback](std::string result) { callback->onResult (result); },
136- [callback](const Stats& result) { callback->onStats (result); });
154+ if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
155+ auto image_size = image->size ();
156+ std::vector<Image> images;
157+ if (image_size != 0 ) {
158+ std::vector<jint> image_data_jint (image_size);
159+ std::vector<uint8_t > image_data (image_size);
160+ image->getRegion (0 , image_size, image_data_jint.data ());
161+ for (int i = 0 ; i < image_size; i++) {
162+ image_data[i] = image_data_jint[i];
163+ }
164+ Image image_runner{image_data, width, height, channels};
165+ images.push_back (image_runner);
166+ }
167+ multi_modal_runner_->generate (
168+ images,
169+ prompt->toStdString (),
170+ seq_len,
171+ [callback](std::string result) { callback->onResult (result); },
172+ [callback](const Stats& result) { callback->onStats (result); });
173+ } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
174+ runner_->generate (
175+ prompt->toStdString (),
176+ seq_len,
177+ [callback](std::string result) { callback->onResult (result); },
178+ [callback](const Stats& result) { callback->onStats (result); });
179+ }
137180 return 0 ;
138181 }
139182
140183 void stop () {
141- runner_->stop ();
184+ if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
185+ multi_modal_runner_->stop ();
186+ } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
187+ runner_->stop ();
188+ }
142189 }
143190
144191 jint load () {
145- return static_cast <jint>(runner_->load ());
192+ if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
193+ return static_cast <jint>(multi_modal_runner_->load ());
194+ } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
195+ return static_cast <jint>(runner_->load ());
196+ }
197+ return static_cast <jint>(Error::InvalidArgument);
146198 }
147199
148200 static void registerNatives () {
0 commit comments