
Commit 7676d45

JNI support for multiple ptd files
Differential Revision: D82072929
Pull Request resolved: #14168
1 parent d002ab1 commit 7676d45

2 files changed (+47, -15 lines)


extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 26 additions & 7 deletions
@@ -11,6 +11,7 @@
 import com.facebook.jni.HybridData;
 import com.facebook.jni.annotations.DoNotStrip;
 import java.io.File;
+import java.util.List;
 import org.pytorch.executorch.ExecuTorchRuntime;
 import org.pytorch.executorch.annotations.Experimental;

@@ -32,14 +33,22 @@ public class LlmModule {

   @DoNotStrip
   private static native HybridData initHybrid(
-      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
+      int modelType,
+      String modulePath,
+      String tokenizerPath,
+      float temperature,
+      List<String> dataFiles);

   /**
    * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
-   * data path.
+   * dataFiles.
    */
   public LlmModule(
-      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) {
+      int modelType,
+      String modulePath,
+      String tokenizerPath,
+      float temperature,
+      List<String> dataFiles) {
     ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();

     File modelFile = new File(modulePath);
@@ -50,25 +59,35 @@ public LlmModule(
     if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) {
       throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
     }
-    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath);
+
+    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles);
+  }
+
+  /**
+   * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
+   * data path.
+   */
+  public LlmModule(
+      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) {
+    this(modelType, modulePath, tokenizerPath, temperature, List.of(dataPath));
   }

   /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
   public LlmModule(String modulePath, String tokenizerPath, float temperature) {
-    this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
+    this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, List.of());
   }

   /**
    * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
    * path.
    */
   public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
-    this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
+    this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, List.of(dataPath));
   }

   /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
   public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
-    this(modelType, modulePath, tokenizerPath, temperature, null);
+    this(modelType, modulePath, tokenizerPath, temperature, List.of());
   }

   /** Constructs a LLM Module for a model with the given LlmModuleConfig */
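
For reference, a minimal usage sketch of the new constructor from application code. The device paths and file names below are hypothetical, and it is assumed that LlmModule.MODEL_TYPE_TEXT is accessible to callers. The existing single-path constructors keep working because they now wrap dataPath in List.of(dataPath); note that, unlike the previous null pass-through, List.of does not accept a null element, so callers with no .ptd file should use the constructors that omit the data path and pass an empty list.

import java.util.List;
import org.pytorch.executorch.extension.llm.LlmModule;

public class LlmModuleMultiPtdSketch {
  public static void createModules() {
    // New constructor: load one model together with several .ptd data files
    // (paths are hypothetical).
    LlmModule multi =
        new LlmModule(
            LlmModule.MODEL_TYPE_TEXT,
            "/data/local/tmp/llama/llama.pte",
            "/data/local/tmp/llama/tokenizer.bin",
            0.8f,
            List.of(
                "/data/local/tmp/llama/foundation_weights.ptd",
                "/data/local/tmp/llama/lora_weights.ptd"));

    // Existing constructor: a single data path, now forwarded as List.of(dataPath).
    LlmModule single =
        new LlmModule(
            "/data/local/tmp/llama/llama.pte",
            "/data/local/tmp/llama/tokenizer.bin",
            0.8f,
            "/data/local/tmp/llama/foundation_weights.ptd");
  }
}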

extension/android/jni/jni_layer_llama.cpp

Lines changed: 21 additions & 8 deletions
@@ -140,21 +140,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
       jfloat temperature,
-      facebook::jni::alias_ref<jstring> data_path) {
+      facebook::jni::alias_ref<jobject> data_files) {
     return makeCxxInstance(
         model_type_category,
         model_path,
         tokenizer_path,
         temperature,
-        data_path);
+        data_files);
   }

   ExecuTorchLlmJni(
       jint model_type_category,
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
       jfloat temperature,
-      facebook::jni::alias_ref<jstring> data_path = nullptr) {
+      facebook::jni::alias_ref<jobject> data_files = nullptr) {
     temperature_ = temperature;
 #if defined(ET_USE_THREADPOOL)
     // Reserve 1 thread for the main thread.
@@ -173,26 +173,39 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString().c_str(),
           llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      std::optional<const std::string> data_path_str = data_path
-          ? std::optional<const std::string>{data_path->toStdString()}
-          : std::nullopt;
+      std::vector<std::string> data_files_vector;
+      if (data_files != nullptr) {
+        // Convert Java List<String> to C++ std::vector<string>
+        auto list_class = facebook::jni::findClassStatic("java/util/List");
+        auto size_method = list_class->getMethod<jint()>("size");
+        auto get_method =
+            list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
+                "get");
+
+        jint size = size_method(data_files);
+        for (jint i = 0; i < size; ++i) {
+          auto str_obj = get_method(data_files, i);
+          auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
+          data_files_vector.push_back(jstr->toStdString());
+        }
+      }
       runner_ = executorch::extension::llm::create_text_llm_runner(
           model_path->toStdString(),
           llm::load_tokenizer(tokenizer_path->toStdString()),
-          data_path_str);
+          data_files_vector);
 #if defined(EXECUTORCH_BUILD_QNN)
     } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
       std::unique_ptr<executorch::extension::Module> module = std::make_unique<
           executorch::extension::Module>(
           model_path->toStdString().c_str(),
+          data_files_set,
           executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
       std::string decoder_model = "llama3"; // use llama3 for now
       runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
           std::move(module),
           decoder_model.c_str(),
           model_path->toStdString().c_str(),
           tokenizer_path->toStdString().c_str(),
-          data_path->toStdString().c_str(),
           "");
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
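
On the native side, the hunk above walks the incoming java.util.List via fbjni and builds a std::vector<std::string> that is handed to create_text_llm_runner, so Java callers only ever build a plain List<String>. As a hedged illustration of where such a list might come from at runtime, the sketch below gathers every .ptd file found in a model directory; the directory layout, file names, and use of LlmModule.MODEL_TYPE_TEXT are assumptions for illustration, not part of this change.

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.pytorch.executorch.extension.llm.LlmModule;

public class PtdDiscoverySketch {
  /** Collects every .ptd file next to the model and passes the whole set to LlmModule. */
  public static LlmModule loadWithAllPtdFiles(File modelDir, float temperature) {
    List<String> dataFiles = new ArrayList<>();
    File[] entries = modelDir.listFiles((dir, name) -> name.endsWith(".ptd"));
    if (entries != null) {
      for (File f : entries) {
        dataFiles.add(f.getAbsolutePath());
      }
    }
    // Model and tokenizer file names are hypothetical placeholders.
    return new LlmModule(
        LlmModule.MODEL_TYPE_TEXT,
        new File(modelDir, "model.pte").getAbsolutePath(),
        new File(modelDir, "tokenizer.bin").getAbsolutePath(),
        temperature,
        dataFiles);
  }
}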
