Skip to content

Commit 899206b

Browse files
committed
Changed API Whisper -> ASR
1 parent 66c9463 commit 899206b

File tree

12 files changed

+113
-43
lines changed

12 files changed

+113
-43
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,12 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
630630
list(APPEND _executorch_extensions extension_module_static)
631631
endif()
632632

633+
if(EXECUTORCH_BUILD_EXTENSION_AUDIO)
634+
message(STATUS "Audio/ASR extension enabled")
635+
endif()
636+
633637
if(EXECUTORCH_BUILD_EXTENSION_LLM)
638+
message(STATUS "LLM extension enabled")
634639
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
635640
set(SUPPORT_REGEX_LOOKAHEAD ON)
636641
# llama/runner/CMakeLists.txt builds a shared library libllama_runner.so

examples/qualcomm/oss_scripts/whisper/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ set(_qnn_whisper_runner__srcs
1414
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
1515
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
1616
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
17+
${EXECUTORCH_ROOT}/extension/llm/runner/asr_runner.h
1718
)
1819

1920
# build qnn whisper runner

examples/qualcomm/oss_scripts/whisper/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/decoder.h>
1515
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/encoder.h>
16+
#include <executorch/extension/llm/runner/asr_runner.h>
1617
#include <executorch/extension/llm/sampler/sampler.h>
1718
#include <executorch/runtime/core/error.h>
1819
#include <pytorch/tokenizers/tokenizer.h>
@@ -24,7 +25,7 @@
2425

2526
namespace example {
2627

27-
class WhisperRunner {
28+
class WhisperRunner : public executorch::extension::llm::ASRRunner {
2829
public:
2930
explicit WhisperRunner(
3031
const std::string& model_path,

extension/android/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
166166
)
167167
endif()
168168

169-
if(EXECUTORCH_BUILD_LLAMA_JNI)
169+
if(EXECUTORCH_BUILD_EXTENSION_LLM)
170170
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
171171
list(APPEND link_libraries extension_llm_runner)
172-
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
172+
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1)
173173

174174
if(QNN_SDK_ROOT)
175175
target_sources(
@@ -221,11 +221,10 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
221221
endif()
222222
endif()
223223

224-
if(EXECUTORCH_BUILD_WHISPER_JNI)
225-
target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp)
226-
target_compile_definitions(
227-
executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1
228-
)
224+
if(EXECUTORCH_BUILD_EXTENSION_AUDIO)
225+
target_sources(executorch_jni PRIVATE jni/jni_layer_asr.cpp jni/log.cpp)
226+
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1)
227+
229228
if(QNN_SDK_ROOT)
230229
target_sources(
231230
executorch_jni
@@ -239,6 +238,7 @@ if(EXECUTORCH_BUILD_WHISPER_JNI)
239238
executorch_jni
240239
PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner
241240
)
241+
target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
242242
endif()
243243
endif()
244244

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
* <p>Warning: These APIs are experimental and subject to change without notice
1919
*/
2020
@Experimental
21-
public interface WhisperCallback {
21+
public interface ASRCallback {
2222
/**
2323
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
2424
* until transcribe() finishes.
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,20 @@
1414
import org.pytorch.executorch.annotations.Experimental;
1515

1616
/**
17-
* WhisperModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text
18-
* from the model.
17+
* ASRModule is a wrapper around ExecuTorch ASR runners, such as the Whisper runner.
1918
*
2019
* <p>Warning: These APIs are experimental and subject to change without notice
2120
*/
2221
@Experimental
23-
public class WhisperModule {
22+
public class ASRModule {
2423

2524
@DoNotStrip private final HybridData mHybridData;
2625

2726
@DoNotStrip
2827
private static native HybridData initHybrid(
2928
String modulePath, String tokenizerPath);
3029

31-
public WhisperModule(
30+
public ASRModule(
3231
String modulePath, String tokenizerPath) {
3332
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
3433

@@ -51,7 +50,7 @@ public void resetNative() {
5150
public native int transcribe(
5251
int seqLen,
5352
byte[][] inputs,
54-
WhisperCallback callback);
53+
ASRCallback callback);
5554

5655

5756
/** Force loading the module. Otherwise the model is loaded during the first transcribe() call. */
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
/** Extension for LLM related use cases for ExecuTorch Android Java/JNI package. */
1+
/** Extension for ASR-related use cases for the ExecuTorch Android Java/JNI package. */
22
package org.pytorch.executorch.extension.audio;

extension/android/jni/BUCK

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
101101
srcs = ["jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"],
102102
allow_jni_merging = False,
103103
compiler_flags = ET_JNI_COMPILER_FLAGS + [
104-
"-DEXECUTORCH_BUILD_LLAMA_JNI",
104+
"-DEXECUTORCH_BUILD_EXTENSION_LLM",
105105
],
106106
soname = "libexecutorch.$(ext)",
107107
visibility = ["PUBLIC"],
@@ -122,15 +122,15 @@ non_fbcode_target(_kind = fb_android_cxx_library,
122122
)
123123

124124
non_fbcode_target(_kind = fb_android_cxx_library,
125-
name = "executorch_whisper_jni",
125+
name = "executorch_asr_jni",
126126
srcs = [
127127
"jni_layer.cpp",
128-
"jni_layer_whisper.cpp",
128+
"jni_layer_asr.cpp",
129129
"jni_layer_runtime.cpp",
130130
],
131131
allow_jni_merging = False,
132132
compiler_flags = ET_JNI_COMPILER_FLAGS + [
133-
"-DEXECUTORCH_BUILD_WHISPER_JNI",
133+
"-DEXECUTORCH_BUILD_EXTENSION_AUDIO",
134134
],
135135
soname = "libexecutorch.$(ext)",
136136
visibility = ["PUBLIC"],

extension/android/jni/jni_layer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,13 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
508508
};
509509
} // namespace executorch::extension
510510

511-
#ifdef EXECUTORCH_BUILD_WHISPER_JNI
512-
extern void register_natives_for_whisper();
511+
#ifdef EXECUTORCH_BUILD_EXTENSION_AUDIO
512+
extern void register_natives_for_asr();
513513
#else
514-
void register_natives_for_whisper() {}
514+
void register_natives_for_asr() {}
515515
#endif
516516

517-
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
517+
#ifdef EXECUTORCH_BUILD_EXTENSION_LLM
518518
extern void register_natives_for_llm();
519519
#else
520520
// No op if we don't build LLM
@@ -532,7 +532,7 @@ void register_natives_for_training() {}
532532
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
533533
return facebook::jni::initialize(vm, [] {
534534
executorch::extension::ExecuTorchJni::registerNatives();
535-
register_natives_for_whisper();
535+
register_natives_for_asr();
536536
register_natives_for_llm();
537537
register_natives_for_runtime();
538538
register_natives_for_training();

extension/android/jni/jni_layer_whisper.cpp renamed to extension/android/jni/jni_layer_asr.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16-
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
16+
#include <executorch/extension/llm/runner/asr_runner.h>
1717
#include <executorch/runtime/platform/log.h>
1818
#include <executorch/runtime/platform/platform.h>
1919
#include <executorch/runtime/platform/runtime.h>
@@ -23,6 +23,10 @@
2323
#include <executorch/extension/threadpool/threadpool.h>
2424
#endif
2525

26+
#if defined(EXECUTORCH_BUILD_QNN)
27+
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
28+
#endif
29+
2630
#include <fbjni/ByteBuffer.h>
2731
#include <fbjni/fbjni.h>
2832

@@ -67,14 +71,14 @@ std::string token_buffer;
6771

6872
namespace executorch_jni {
6973

70-
class ExecuTorchWhisperCallbackJni
71-
: public facebook::jni::JavaClass<ExecuTorchWhisperCallbackJni> {
74+
class ExecuTorchASRCallbackJni
75+
: public facebook::jni::JavaClass<ExecuTorchASRCallbackJni> {
7276
public:
7377
constexpr static const char* kJavaDescriptor =
74-
"Lorg/pytorch/executorch/extension/audio/WhisperCallback;";
78+
"Lorg/pytorch/executorch/extension/audio/ASRCallback;";
7579

7680
void onResult(std::string result) const {
77-
static auto cls = ExecuTorchWhisperCallbackJni::javaClassStatic();
81+
static auto cls = ExecuTorchASRCallbackJni::javaClassStatic();
7882
static const auto method =
7983
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
8084

@@ -91,15 +95,14 @@ class ExecuTorchWhisperCallbackJni
9195
}
9296
};
9397

94-
class ExecuTorchWhisperJni
95-
: public facebook::jni::HybridClass<ExecuTorchWhisperJni> {
98+
class ExecuTorchASRJni : public facebook::jni::HybridClass<ExecuTorchASRJni> {
9699
private:
97100
friend HybridBase;
98-
std::unique_ptr<example::WhisperRunner> runner_;
101+
std::unique_ptr<::executorch::extension::llm::ASRRunner> runner_;
99102

100103
public:
101104
constexpr static auto kJavaDescriptor =
102-
"Lorg/pytorch/executorch/extension/audio/WhisperModule;";
105+
"Lorg/pytorch/executorch/extension/audio/ASRModule;";
103106

104107
static facebook::jni::local_ref<jhybriddata> initHybrid(
105108
facebook::jni::alias_ref<jclass>,
@@ -108,7 +111,7 @@ class ExecuTorchWhisperJni
108111
return makeCxxInstance(model_path, tokenizer_path);
109112
}
110113

111-
ExecuTorchWhisperJni(
114+
ExecuTorchASRJni(
112115
facebook::jni::alias_ref<jstring> model_path,
113116
facebook::jni::alias_ref<jstring> tokenizer_path) {
114117
#if defined(ET_USE_THREADPOOL)
@@ -121,17 +124,18 @@ class ExecuTorchWhisperJni
121124
->_unsafe_reset_threadpool(num_performant_cores);
122125
}
123126
#endif
124-
127+
#if defined(EXECUTORCH_BUILD_QNN)
125128
// create runner
126129
runner_ = std::make_unique<example::WhisperRunner>(
127130
model_path->toStdString(), tokenizer_path->toStdString());
131+
#endif
128132
}
129133

130134
jint transcribe(
131135
jint seq_len,
132136
facebook::jni::alias_ref<
133137
facebook::jni::JArrayClass<jbyteArray>::javaobject> inputs,
134-
facebook::jni::alias_ref<ExecuTorchWhisperCallbackJni> callback) {
138+
facebook::jni::alias_ref<ExecuTorchASRCallbackJni> callback) {
135139
// Convert Java byte[][] to C++ vector<vector<char>>
136140
std::vector<std::vector<char>> cppData;
137141
auto input_size = inputs->size();
@@ -162,15 +166,15 @@ class ExecuTorchWhisperJni
162166

163167
static void registerNatives() {
164168
registerHybrid({
165-
makeNativeMethod("initHybrid", ExecuTorchWhisperJni::initHybrid),
166-
makeNativeMethod("transcribe", ExecuTorchWhisperJni::transcribe),
167-
makeNativeMethod("load", ExecuTorchWhisperJni::load),
169+
makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid),
170+
makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe),
171+
makeNativeMethod("load", ExecuTorchASRJni::load),
168172
});
169173
}
170174
};
171175

172176
} // namespace executorch_jni
173177

174-
void register_natives_for_whisper() {
175-
executorch_jni::ExecuTorchWhisperJni::registerNatives();
178+
void register_natives_for_asr() {
179+
executorch_jni::ExecuTorchASRJni::registerNatives();
176180
}

0 commit comments

Comments
 (0)