1313#include < unordered_map>
1414#include < vector>
1515
16- #include < executorch/examples/qualcomm/oss_scripts/whisper/ runner/runner .h>
16+ #include < executorch/extension/llm/ runner/asr_runner .h>
1717#include < executorch/runtime/platform/log.h>
1818#include < executorch/runtime/platform/platform.h>
1919#include < executorch/runtime/platform/runtime.h>
2323#include < executorch/extension/threadpool/threadpool.h>
2424#endif
2525
26+ #if defined(EXECUTORCH_BUILD_QNN)
27+ #include < executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
28+ #endif
29+
2630#include < fbjni/ByteBuffer.h>
2731#include < fbjni/fbjni.h>
2832
@@ -67,14 +71,14 @@ std::string token_buffer;
6771
6872namespace executorch_jni {
6973
70- class ExecuTorchWhisperCallbackJni
71- : public facebook::jni::JavaClass<ExecuTorchWhisperCallbackJni > {
74+ class ExecuTorchASRCallbackJni
75+ : public facebook::jni::JavaClass<ExecuTorchASRCallbackJni > {
7276 public:
7377 constexpr static const char * kJavaDescriptor =
74- " Lorg/pytorch/executorch/extension/audio/WhisperCallback ;" ;
78+ " Lorg/pytorch/executorch/extension/audio/ASRCallback ;" ;
7579
7680 void onResult (std::string result) const {
77- static auto cls = ExecuTorchWhisperCallbackJni ::javaClassStatic ();
81+ static auto cls = ExecuTorchASRCallbackJni ::javaClassStatic ();
7882 static const auto method =
7983 cls->getMethod <void (facebook::jni::local_ref<jstring>)>(" onResult" );
8084
@@ -91,15 +95,14 @@ class ExecuTorchWhisperCallbackJni
9195 }
9296};
9397
94- class ExecuTorchWhisperJni
95- : public facebook::jni::HybridClass<ExecuTorchWhisperJni> {
98+ class ExecuTorchASRJni : public facebook ::jni::HybridClass<ExecuTorchASRJni> {
9699 private:
97100 friend HybridBase;
98- std::unique_ptr<example::WhisperRunner > runner_;
101+ std::unique_ptr<::executorch::extension::llm::ASRRunner > runner_;
99102
100103 public:
101104 constexpr static auto kJavaDescriptor =
102- " Lorg/pytorch/executorch/extension/audio/WhisperModule ;" ;
105+ " Lorg/pytorch/executorch/extension/audio/ASRModule ;" ;
103106
104107 static facebook::jni::local_ref<jhybriddata> initHybrid (
105108 facebook::jni::alias_ref<jclass>,
@@ -108,7 +111,7 @@ class ExecuTorchWhisperJni
108111 return makeCxxInstance (model_path, tokenizer_path);
109112 }
110113
111- ExecuTorchWhisperJni (
114+ ExecuTorchASRJni (
112115 facebook::jni::alias_ref<jstring> model_path,
113116 facebook::jni::alias_ref<jstring> tokenizer_path) {
114117#if defined(ET_USE_THREADPOOL)
@@ -121,17 +124,18 @@ class ExecuTorchWhisperJni
121124 ->_unsafe_reset_threadpool(num_performant_cores);
122125 }
123126#endif
124-
127+ # if defined(EXECUTORCH_BUILD_QNN)
125128 // create runner
126129 runner_ = std::make_unique<example::WhisperRunner>(
127130 model_path->toStdString (), tokenizer_path->toStdString ());
131+ #endif
128132 }
129133
130134 jint transcribe (
131135 jint seq_len,
132136 facebook::jni::alias_ref<
133137 facebook::jni::JArrayClass<jbyteArray>::javaobject> inputs,
134- facebook::jni::alias_ref<ExecuTorchWhisperCallbackJni > callback) {
138+ facebook::jni::alias_ref<ExecuTorchASRCallbackJni > callback) {
135139 // Convert Java byte[][] to C++ vector<vector<char>>
136140 std::vector<std::vector<char >> cppData;
137141 auto input_size = inputs->size ();
@@ -162,15 +166,15 @@ class ExecuTorchWhisperJni
162166
163167 static void registerNatives () {
164168 registerHybrid ({
165- makeNativeMethod (" initHybrid" , ExecuTorchWhisperJni ::initHybrid),
166- makeNativeMethod (" transcribe" , ExecuTorchWhisperJni ::transcribe),
167- makeNativeMethod (" load" , ExecuTorchWhisperJni ::load),
169+ makeNativeMethod (" initHybrid" , ExecuTorchASRJni ::initHybrid),
170+ makeNativeMethod (" transcribe" , ExecuTorchASRJni ::transcribe),
171+ makeNativeMethod (" load" , ExecuTorchASRJni ::load),
168172 });
169173 }
170174};
171175
172176} // namespace executorch_jni
173177
174- void register_natives_for_whisper () {
175- executorch_jni::ExecuTorchWhisperJni ::registerNatives ();
178+ void register_natives_for_asr () {
179+ executorch_jni::ExecuTorchASRJni ::registerNatives ();
176180}
0 commit comments