Skip to content

Commit d5ed11f

Browse files
committed
Changed API Whisper -> ASR
1 parent 3b8ef4e commit d5ed11f

File tree

11 files changed

+104
-38
lines changed

11 files changed

+104
-38
lines changed

examples/qualcomm/oss_scripts/whisper/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ set(_qnn_whisper_runner__srcs
1414
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
1515
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
1616
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
17+
${EXECUTORCH_ROOT}/extension/llm/runner/asr_runner.h
1718
)
1819

1920
# build qnn whisper runner

examples/qualcomm/oss_scripts/whisper/runner/runner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/decoder.h>
1515
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/encoder.h>
16+
#include <executorch/extension/llm/runner/asr_runner.h>
1617
#include <executorch/extension/llm/sampler/sampler.h>
1718
#include <executorch/runtime/core/error.h>
1819
#include <pytorch/tokenizers/tokenizer.h>
@@ -24,7 +25,7 @@
2425

2526
namespace example {
2627

27-
class WhisperRunner {
28+
class WhisperRunner : public executorch::extension::llm::ASRRunner {
2829
public:
2930
explicit WhisperRunner(
3031
const std::string& model_path,

extension/android/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,10 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
239239
endif()
240240
endif()
241241

242-
if(EXECUTORCH_BUILD_WHISPER_JNI)
243-
target_sources(executorch_jni PRIVATE jni/jni_layer_whisper.cpp jni/log.cpp)
244-
target_compile_definitions(
245-
executorch_jni PUBLIC EXECUTORCH_BUILD_WHISPER_JNI=1
246-
)
242+
if(EXECUTORCH_BUILD_ASR_JNI)
243+
target_sources(executorch_jni PRIVATE jni/jni_layer_asr.cpp jni/log.cpp)
244+
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_ASR_JNI=1)
245+
247246
if(QNN_SDK_ROOT)
248247
target_sources(
249248
executorch_jni
@@ -257,6 +256,7 @@ if(EXECUTORCH_BUILD_WHISPER_JNI)
257256
executorch_jni
258257
PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/whisper/runner
259258
)
259+
target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
260260
endif()
261261
endif()
262262

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
* <p>Warning: These APIs are experimental and subject to change without notice
1919
*/
2020
@Experimental
21-
public interface WhisperCallback {
21+
public interface ASRCallback {
2222
/**
2323
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
2424
* until transcribe() finishes.
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,20 @@
1414
import org.pytorch.executorch.annotations.Experimental;
1515

1616
/**
17-
* WhisperModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text
18-
* from the model.
17+
* ASRModule is a wrapper around ExecuTorch ASR runners, such as the Whisper runner.
1918
*
2019
* <p>Warning: These APIs are experimental and subject to change without notice
2120
*/
2221
@Experimental
23-
public class WhisperModule {
22+
public class ASRModule {
2423

2524
@DoNotStrip private final HybridData mHybridData;
2625

2726
@DoNotStrip
2827
private static native HybridData initHybrid(
2928
String modulePath, String tokenizerPath);
3029

31-
public WhisperModule(
30+
public ASRModule(
3231
String modulePath, String tokenizerPath) {
3332
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
3433

@@ -51,7 +50,7 @@ public void resetNative() {
5150
public native int transcribe(
5251
int seqLen,
5352
byte[][] inputs,
54-
WhisperCallback callback);
53+
ASRCallback callback);
5554

5655

5756
/** Force loading the module. Otherwise the model is loaded during the first transcribe() call. */
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
/** Extension for LLM related use cases for ExecuTorch Android Java/JNI package. */
1+
/** Extension for ASR related use cases for ExecuTorch Android Java/JNI package. */
22
package org.pytorch.executorch.extension.audio;

extension/android/jni/BUCK

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ non_fbcode_target(_kind = fb_android_cxx_library,
123123
)
124124

125125
non_fbcode_target(_kind = fb_android_cxx_library,
126-
name = "executorch_whisper_jni",
126+
name = "executorch_asr_jni",
127127
srcs = [
128128
"jni_layer.cpp",
129-
"jni_layer_whisper.cpp",
129+
"jni_layer_asr.cpp",
130130
"jni_layer_runtime.cpp",
131131
],
132132
allow_jni_merging = False,
133133
compiler_flags = ET_JNI_COMPILER_FLAGS + [
134-
"-DEXECUTORCH_BUILD_WHISPER_JNI",
134+
"-DEXECUTORCH_BUILD_ASR_JNI",
135135
],
136136
soname = "libexecutorch.$(ext)",
137137
visibility = ["PUBLIC"],

extension/android/jni/jni_layer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,10 +508,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
508508
};
509509
} // namespace executorch::extension
510510

511-
#ifdef EXECUTORCH_BUILD_WHISPER_JNI
512-
extern void register_natives_for_whisper();
511+
#ifdef EXECUTORCH_BUILD_ASR_JNI
512+
extern void register_natives_for_asr();
513513
#else
514-
void register_natives_for_whisper() {}
514+
void register_natives_for_asr() {}
515515
#endif
516516

517517
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
@@ -532,7 +532,7 @@ void register_natives_for_training() {}
532532
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
533533
return facebook::jni::initialize(vm, [] {
534534
executorch::extension::ExecuTorchJni::registerNatives();
535-
register_natives_for_whisper();
535+
register_natives_for_asr();
536536
register_natives_for_llm();
537537
register_natives_for_runtime();
538538
register_natives_for_training();

extension/android/jni/jni_layer_whisper.cpp renamed to extension/android/jni/jni_layer_asr.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16-
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
16+
#include <executorch/extension/llm/runner/asr_runner.h>
1717
#include <executorch/runtime/platform/log.h>
1818
#include <executorch/runtime/platform/platform.h>
1919
#include <executorch/runtime/platform/runtime.h>
@@ -23,6 +23,10 @@
2323
#include <executorch/extension/threadpool/threadpool.h>
2424
#endif
2525

26+
#if defined(EXECUTORCH_BUILD_QNN)
27+
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
28+
#endif
29+
2630
#include <fbjni/ByteBuffer.h>
2731
#include <fbjni/fbjni.h>
2832

@@ -67,14 +71,14 @@ std::string token_buffer;
6771

6872
namespace executorch_jni {
6973

70-
class ExecuTorchWhisperCallbackJni
71-
: public facebook::jni::JavaClass<ExecuTorchWhisperCallbackJni> {
74+
class ExecuTorchASRCallbackJni
75+
: public facebook::jni::JavaClass<ExecuTorchASRCallbackJni> {
7276
public:
7377
constexpr static const char* kJavaDescriptor =
74-
"Lorg/pytorch/executorch/extension/audio/WhisperCallback;";
78+
"Lorg/pytorch/executorch/extension/audio/ASRCallback;";
7579

7680
void onResult(std::string result) const {
77-
static auto cls = ExecuTorchWhisperCallbackJni::javaClassStatic();
81+
static auto cls = ExecuTorchASRCallbackJni::javaClassStatic();
7882
static const auto method =
7983
cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
8084

@@ -91,15 +95,14 @@ class ExecuTorchWhisperCallbackJni
9195
}
9296
};
9397

94-
class ExecuTorchWhisperJni
95-
: public facebook::jni::HybridClass<ExecuTorchWhisperJni> {
98+
class ExecuTorchASRJni : public facebook::jni::HybridClass<ExecuTorchASRJni> {
9699
private:
97100
friend HybridBase;
98-
std::unique_ptr<example::WhisperRunner> runner_;
101+
std::unique_ptr<::executorch::extension::llm::ASRRunner> runner_;
99102

100103
public:
101104
constexpr static auto kJavaDescriptor =
102-
"Lorg/pytorch/executorch/extension/audio/WhisperModule;";
105+
"Lorg/pytorch/executorch/extension/audio/ASRModule;";
103106

104107
static facebook::jni::local_ref<jhybriddata> initHybrid(
105108
facebook::jni::alias_ref<jclass>,
@@ -108,7 +111,7 @@ class ExecuTorchWhisperJni
108111
return makeCxxInstance(model_path, tokenizer_path);
109112
}
110113

111-
ExecuTorchWhisperJni(
114+
ExecuTorchASRJni(
112115
facebook::jni::alias_ref<jstring> model_path,
113116
facebook::jni::alias_ref<jstring> tokenizer_path) {
114117
#if defined(ET_USE_THREADPOOL)
@@ -121,17 +124,18 @@ class ExecuTorchWhisperJni
121124
->_unsafe_reset_threadpool(num_performant_cores);
122125
}
123126
#endif
124-
127+
#if defined(EXECUTORCH_BUILD_QNN)
125128
// create runner
126129
runner_ = std::make_unique<example::WhisperRunner>(
127130
model_path->toStdString(), tokenizer_path->toStdString());
131+
#endif
128132
}
129133

130134
jint transcribe(
131135
jint seq_len,
132136
facebook::jni::alias_ref<
133137
facebook::jni::JArrayClass<jbyteArray>::javaobject> inputs,
134-
facebook::jni::alias_ref<ExecuTorchWhisperCallbackJni> callback) {
138+
facebook::jni::alias_ref<ExecuTorchASRCallbackJni> callback) {
135139
// Convert Java byte[][] to C++ vector<vector<char>>
136140
std::vector<std::vector<char>> cppData;
137141
auto input_size = inputs->size();
@@ -162,15 +166,15 @@ class ExecuTorchWhisperJni
162166

163167
static void registerNatives() {
164168
registerHybrid({
165-
makeNativeMethod("initHybrid", ExecuTorchWhisperJni::initHybrid),
166-
makeNativeMethod("transcribe", ExecuTorchWhisperJni::transcribe),
167-
makeNativeMethod("load", ExecuTorchWhisperJni::load),
169+
makeNativeMethod("initHybrid", ExecuTorchASRJni::initHybrid),
170+
makeNativeMethod("transcribe", ExecuTorchASRJni::transcribe),
171+
makeNativeMethod("load", ExecuTorchASRJni::load),
168172
});
169173
}
170174
};
171175

172176
} // namespace executorch_jni
173177

174-
void register_natives_for_whisper() {
175-
executorch_jni::ExecuTorchWhisperJni::registerNatives();
178+
void register_natives_for_asr() {
179+
executorch_jni::ExecuTorchASRJni::registerNatives();
176180
}

extension/llm/runner/asr_runner.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Interface for audio-to-text (ASR) model runners. Currently only used for
// supporting the QNN Whisper runner.

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector> // std::vector is used in the transcribe() signature below.

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>

namespace executorch {
namespace extension {
namespace llm {

/**
 * Abstract base class for automatic-speech-recognition runners.
 *
 * All operations are pure virtual; concrete runners provide the
 * implementation. The destructor is virtual so that instances can be owned
 * and destroyed through a pointer to this base class.
 */
class ET_EXPERIMENTAL ASRRunner {
 public:
  virtual ~ASRRunner() = default;

  /**
   * Check if the runner is loaded and ready for inference.
   *
   * @return true if the runner is loaded, false otherwise
   */
  virtual bool is_loaded() const = 0;

  /**
   * Load the model and prepare for inference.
   *
   * @return Error::Ok if successful, an error otherwise
   */
  virtual runtime::Error load() = 0;

  /**
   * Generate text from raw audio.
   *
   * @param seq_len Length of input sequence
   * @param inputs A vector containing one element: a vector of bytes that
   * encodes a float tensor in little-endian byte order
   * @param token_callback Callback function called for each generated token.
   * Defaults to an empty std::function.
   * @return Error::Ok if successful, an error otherwise
   */
  virtual runtime::Error transcribe(
      int32_t seq_len,
      std::vector<std::vector<char>>& inputs,
      std::function<void(const std::string&)> token_callback = {}) = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch

0 commit comments

Comments
 (0)