Skip to content

Commit fb86e61

Browse files
authored
Add Echo parameter to llama runner, jni+java layer, and demo app
Differential Revision: D62247137 Pull Request resolved: #5011
1 parent f55ce1f commit fb86e61

File tree

6 files changed

+51
-10
lines changed

6 files changed

+51
-10
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,15 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7373

7474
@Override
7575
public void onResult(String result) {
76-
mResultMessage.appendText(result);
77-
run();
76+
if (result.equals("\n\n")) {
77+
if (!mResultMessage.getText().isEmpty()) {
78+
mResultMessage.appendText(result);
79+
run();
80+
}
81+
} else {
82+
mResultMessage.appendText(result);
83+
run();
84+
}
7885
}
7986

8087
@Override
@@ -614,6 +621,7 @@ public void run() {
614621
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
615622
prompt,
616623
ModelUtils.VISION_MODEL_SEQ_LEN,
624+
false,
617625
MainActivity.this);
618626
} else {
619627
// no image selected, we pass in empty int array
@@ -624,10 +632,12 @@ public void run() {
624632
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
625633
prompt,
626634
ModelUtils.VISION_MODEL_SEQ_LEN,
635+
false,
627636
MainActivity.this);
628637
}
629638
} else {
630-
mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);
639+
mModule.generate(
640+
prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
631641
}
632642

633643
long generateDuration = System.currentTimeMillis() - generateStartTime;

examples/models/llama2/export_llama_lib.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,6 @@ def build_args_parser() -> argparse.ArgumentParser:
313313

314314

315315
def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
316-
317316
path = str(path)
318317

319318
if verbose_export():

examples/models/llama2/runner/runner.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ Error Runner::generate(
143143
const std::string& prompt,
144144
int32_t seq_len,
145145
std::function<void(const std::string&)> token_callback,
146-
std::function<void(const Stats&)> stats_callback) {
146+
std::function<void(const Stats&)> stats_callback,
147+
bool echo) {
147148
// Prepare the inputs.
148149
// Use ones-initialized inputs.
149150
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -208,7 +209,9 @@ Error Runner::generate(
208209
// after the prompt. After that we will enter generate loop.
209210

210211
// print prompts
211-
wrapped_callback(prompt);
212+
if (echo) {
213+
wrapped_callback(prompt);
214+
}
212215
int64_t pos = 0;
213216
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
214217
stats_.first_token_ms = util::time_in_ms();

examples/models/llama2/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class Runner {
4040
const std::string& prompt,
4141
int32_t seq_len = 128,
4242
std::function<void(const std::string&)> token_callback = {},
43-
std::function<void(const Stats&)> stats_callback = {});
43+
std::function<void(const Stats&)> stats_callback = {},
44+
bool echo = true);
4445
void stop();
4546

4647
private:

extension/android/jni/jni_layer_llama.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class ExecuTorchLlamaJni
150150
jint channels,
151151
facebook::jni::alias_ref<jstring> prompt,
152152
jint seq_len,
153+
jboolean echo,
153154
facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
154155
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
155156
auto image_size = image->size();
@@ -175,7 +176,8 @@ class ExecuTorchLlamaJni
175176
prompt->toStdString(),
176177
seq_len,
177178
[callback](std::string result) { callback->onResult(result); },
178-
[callback](const Stats& result) { callback->onStats(result); });
179+
[callback](const Stats& result) { callback->onStats(result); },
180+
echo);
179181
}
180182
return 0;
181183
}

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class LlamaModule {
3333

3434
private final HybridData mHybridData;
3535
private static final int DEFAULT_SEQ_LEN = 128;
36+
private static final boolean DEFAULT_ECHO = true;
3637

3738
@DoNotStrip
3839
private static native HybridData initHybrid(
@@ -59,7 +60,7 @@ public void resetNative() {
5960
* @param llamaCallback callback object to receive results.
6061
*/
6162
public int generate(String prompt, LlamaCallback llamaCallback) {
62-
return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
63+
return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
6364
}
6465

6566
/**
@@ -70,7 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
7071
* @param llamaCallback callback object to receive results.
7172
*/
7273
public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
73-
return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback);
74+
return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
75+
}
76+
77+
/**
78+
* Start generating tokens from the module.
79+
*
80+
* @param prompt Input prompt
81+
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
82+
* @param llamaCallback callback object to receive results.
83+
*/
84+
public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
85+
return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
86+
}
87+
88+
/**
89+
* Start generating tokens from the module.
90+
*
91+
* @param prompt Input prompt
92+
* @param seqLen sequence length
93+
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
94+
* @param llamaCallback callback object to receive results.
95+
*/
96+
public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
97+
return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
7498
}
7599

76100
/**
@@ -82,6 +106,7 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
82106
* @param channels Input image number of channels
83107
* @param prompt Input prompt
84108
* @param seqLen sequence length
109+
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
85110
* @param llamaCallback callback object to receive results.
86111
*/
87112
@DoNotStrip
@@ -92,6 +117,7 @@ public native int generate(
92117
int channels,
93118
String prompt,
94119
int seqLen,
120+
boolean echo,
95121
LlamaCallback llamaCallback);
96122

97123
/**

0 commit comments

Comments (0)