Skip to content

Commit cb2278a

Browse files
committed
Whisper JNI first commit
Added EXECUTORCH_BUILD_WHISPER_JNI flag
1 parent 3db27cd commit cb2278a

File tree

11 files changed

+339
-9
lines changed

11 files changed

+339
-9
lines changed

examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ std::vector<std::vector<std::vector<char>>> parse_input_list_file(
9797
int main(int argc, char** argv) {
9898
gflags::ParseCommandLineFlags(&argc, &argv, true);
9999
// create whisper runner
100-
example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_json_path);
100+
example::WhisperRunner runner(FLAGS_model_path, FLAGS_tokenizer_json_path);
101101

102102
std::vector<std::vector<std::vector<char>>> multi_turns_input_buffers =
103103
parse_input_list_file(FLAGS_input_list_path);

examples/qualcomm/oss_scripts/whisper/runner/runner.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ static constexpr auto kDecoderStartTokenId = "decoder_start_token_id";
2727
static constexpr auto kEosId = "get_eos_id";
2828
static constexpr auto kMaxContextLen = "get_max_context_len";
2929
} // namespace
30-
Runner::Runner(
30+
WhisperRunner::WhisperRunner(
3131
const std::string& model_path,
3232
const std::string& tokenizer_json_path)
3333
: tokenizer_json_path_(tokenizer_json_path) {
3434
encoder_ = std::make_unique<WhisperEncoder>(model_path);
3535
decoder_ = std::make_unique<WhisperDecoder>(model_path);
3636
tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
3737
}
38-
bool Runner::is_loaded() const {
38+
bool WhisperRunner::is_loaded() const {
3939
return encoder_->is_method_loaded() && decoder_->is_method_loaded() &&
4040
tokenizer_->is_loaded() && sampler_;
4141
}
4242

43-
Error Runner::load() {
43+
Error WhisperRunner::load() {
4444
if (is_loaded()) {
4545
return Error::Ok;
4646
}
@@ -108,12 +108,12 @@ Error Runner::load() {
108108

109109
return Error::Ok;
110110
}
111-
uint64_t Runner::logits_to_token(
111+
uint64_t WhisperRunner::logits_to_token(
112112
const executorch::aten::Tensor& logits_tensor) {
113113
return sampler_->sample(logits_tensor.data_ptr<float>());
114114
}
115115

116-
Error Runner::transcribe(
116+
Error WhisperRunner::transcribe(
117117
int32_t seq_len,
118118
std::vector<std::vector<char>>& inputs,
119119
std::function<void(const std::string&)> token_callback) {
@@ -184,7 +184,7 @@ Error Runner::transcribe(
184184
return Error::Ok;
185185
}
186186

187-
Error Runner::print_performance() {
187+
Error WhisperRunner::print_performance() {
188188
ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_);
189189

190190
ET_LOG(

examples/qualcomm/oss_scripts/whisper/runner/runner.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
namespace example {
2626

27-
class Runner {
27+
class WhisperRunner {
2828
public:
29-
explicit Runner(
29+
explicit WhisperRunner(
3030
const std::string& model_path,
3131
const std::string& tokenizer_json_path);
3232

extension/android/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,14 @@ set_target_properties(
6969

7070
executorch_target_link_options_shared_lib(executorch)
7171

72+
<<<<<<< HEAD
7273
add_library(
7374
executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp
7475
jni/jni_helper.cpp
7576
)
77+
=======
78+
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)
79+
>>>>>>> 37d0a6944a (Whisper JNI first commit)
7680

7781
set(link_libraries)
7882
list(
@@ -239,6 +243,25 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
239243
endif()
240244
endif()
241245

246+
if(EXECUTORCH_BUILD_WHISPER_JNI)
247+
target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp)
248+
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1)
249+
if(QNN_SDK_ROOT)
250+
target_sources(
251+
executorch_jni
252+
PRIVATE
253+
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp
254+
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp
255+
${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp
256+
)
257+
258+
target_include_directories(
259+
executorch_jni
260+
PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner
261+
)
262+
endif()
263+
endif()
264+
242265
target_include_directories(
243266
executorch_jni
244267
PRIVATE
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.audio;
10+
11+
import com.facebook.jni.annotations.DoNotStrip;
12+
import org.pytorch.executorch.annotations.Experimental;
13+
14+
/**
15+
* Callback interface for Whisper model. Users can implement this interface to receive the generated
16+
* tokens.
17+
*
18+
* <p>Warning: These APIs are experimental and subject to change without notice
19+
*/
20+
@Experimental
21+
public interface WhisperCallback {
22+
/**
23+
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
24+
* until transcribe() finishes.
25+
*
26+
* @param result Last generated token
27+
*/
28+
@DoNotStrip
29+
public void onResult(String result);
30+
31+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.audio;
10+
import com.facebook.jni.HybridData;
11+
import com.facebook.jni.annotations.DoNotStrip;
12+
import java.io.File;
13+
import org.pytorch.executorch.ExecuTorchRuntime;
14+
import org.pytorch.executorch.annotations.Experimental;
15+
16+
/**
17+
* WhisperModule is a wrapper around the ExecuTorch Whisper runner. It provides a simple interface
18+
* to transcribe inputs into text with the model.
19+
*
20+
* <p>Warning: These APIs are experimental and subject to change without notice
21+
*/
22+
@Experimental
23+
public class WhisperModule {
24+
25+
@DoNotStrip private final HybridData mHybridData;
26+
27+
@DoNotStrip
28+
private static native HybridData initHybrid(
29+
String modulePath, String tokenizerPath);
30+
31+
public WhisperModule(
32+
String modulePath, String tokenizerPath) {
33+
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
34+
35+
File modelFile = new File(modulePath);
36+
if (!modelFile.canRead() || !modelFile.isFile()) {
37+
throw new RuntimeException("Cannot load model path " + modulePath);
38+
}
39+
File tokenizerFile = new File(tokenizerPath);
40+
if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) {
41+
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
42+
}
43+
mHybridData = initHybrid(modulePath, tokenizerPath);
44+
}
45+
46+
public void resetNative() {
47+
mHybridData.resetNative();
48+
}
49+
50+
@DoNotStrip
51+
public native int transcribe(
52+
int seqLen,
53+
byte[][] inputs,
54+
WhisperCallback callback);
55+
56+
57+
/** Force loading the module. Otherwise the model is loaded during first transcribe(). */
58+
@DoNotStrip
59+
public native int load();
60+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/** Extension for audio related use cases for ExecuTorch Android Java/JNI package. */
2+
package org.pytorch.executorch.extension.audio;

extension/android/jni/BUCK

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,35 @@ non_fbcode_target(_kind = fb_android_cxx_library,
132132
],
133133
)
134134

135+
non_fbcode_target(_kind = fb_android_cxx_library,
136+
name = "executorch_whisper_jni",
137+
srcs = [
138+
"jni_layer.cpp",
139+
"jni_layer_whisper.cpp",
140+
"jni_layer_runtime.cpp",
141+
],
142+
allow_jni_merging = False,
143+
compiler_flags = ET_JNI_COMPILER_FLAGS + [
144+
"-DEXECUTORCH_BUILD_WHISPER_JNI",
145+
],
146+
soname = "libexecutorch.$(ext)",
147+
visibility = ["PUBLIC"],
148+
deps = [
149+
":jni_headers",
150+
":log_provider_static",
151+
"//fbandroid/libraries/fbjni:fbjni",
152+
"//fbandroid/native/fb:fb",
153+
"//third-party/glog:glog",
154+
"//xplat/executorch/backends/xnnpack:xnnpack_backend_static",
155+
"//xplat/executorch/examples/oss_scripts/qualcomm/whisper/runner:runner_static",
156+
"//xplat/executorch/extension/module:module_static",
157+
"//xplat/executorch/extension/runner_util:inputs_static",
158+
"//xplat/executorch/extension/tensor:tensor_static",
159+
"//xplat/executorch/extension/threadpool:cpuinfo_utils_static",
160+
"//xplat/executorch/extension/threadpool:threadpool_static",
161+
],
162+
)
163+
135164
non_fbcode_target(_kind = runtime.cxx_library,
136165
name = "log_provider",
137166
srcs = ["log.cpp"],

extension/android/jni/jni_layer.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
510510
};
511511
} // namespace executorch::extension
512512

513+
#ifdef EXECUTORCH_BUILD_WHISPER_JNI
514+
extern void register_natives_for_whisper();
515+
#else
516+
void register_natives_for_whisper() {}
517+
#endif
518+
513519
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
514520
extern void register_natives_for_llm();
515521
#else
@@ -528,6 +534,7 @@ void register_natives_for_training() {}
528534
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
529535
return facebook::jni::initialize(vm, [] {
530536
executorch::extension::ExecuTorchJni::registerNatives();
537+
register_natives_for_whisper();
531538
register_natives_for_llm();
532539
register_natives_for_runtime();
533540
register_natives_for_training();

0 commit comments

Comments
 (0)