Skip to content

Commit 6afb221

Browse files
committed
Addressed feedback
1 parent 899206b commit 6afb221

File tree

13 files changed

+134
-268
lines changed

13 files changed

+134
-268
lines changed

backends/qualcomm/scripts/build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ if [ "$BUILD_AARCH64" = true ]; then
8181
-DCMAKE_BUILD_TYPE=$BUILD_TYPE \
8282
-DEXECUTORCH_BUILD_QNN=ON \
8383
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
84+
-DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \
8485
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
8586
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
8687
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
@@ -150,6 +151,7 @@ if [ "$BUILD_X86_64" = true ]; then
150151
-DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
151152
-DEXECUTORCH_BUILD_QNN=ON \
152153
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
154+
-DEXECUTORCH_BUILD_EXTENSION_AUDIO=ON \
153155
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
154156
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
155157
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \

examples/qualcomm/oss_scripts/whisper/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ set(_qnn_whisper_runner__srcs
1414
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
1515
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
1616
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
17-
${EXECUTORCH_ROOT}/extension/llm/runner/asr_runner.h
17+
${EXECUTORCH_ROOT}/extension/audio/runner/asr_runner.h
1818
)
1919

2020
# build qnn whisper runner

examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
*/
1515

1616
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/runner.h>
17+
#include <executorch/extension/llm/runner/audio.h>
1718
#include <executorch/runtime/platform/log.h>
1819
#include <gflags/gflags.h>
1920
#include <fstream>
@@ -110,7 +111,14 @@ int main(int argc, char** argv) {
110111
}
111112
};
112113
// generate tokens
113-
runner.transcribe(FLAGS_seq_len, multi_turns_input_buffers[iter], callback);
114+
executorch::extension::llm::Audio audio{
115+
std::vector<uint8_t>(
116+
multi_turns_input_buffers[iter][0].begin(),
117+
multi_turns_input_buffers[iter][0].end()),
118+
1,
119+
80,
120+
3000};
121+
runner.transcribe(FLAGS_seq_len, audio, callback);
114122
auto output_file_name =
115123
FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt";
116124
std::ofstream fout(output_file_name);

examples/qualcomm/oss_scripts/whisper/runner/runner.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,24 @@ uint64_t WhisperRunner::logits_to_token(
112112
const executorch::aten::Tensor& logits_tensor) {
113113
return sampler_->sample(logits_tensor.data_ptr<float>());
114114
}
115-
/**
116-
* @param inputs: A vector containing one element: a vector of bytes that
117-
* encodes a float tensor in little-endian byte order.
118-
*
119-
*/
120115
Error WhisperRunner::transcribe(
121116
int32_t seq_len,
122-
std::vector<std::vector<char>>& inputs,
123-
std::function<void(const std::string&)> token_callback) {
117+
executorch::extension::llm::Audio& audio,
118+
std::function<void(const std::string&)> token_callback,
119+
std::function<void(const executorch::extension::llm::Stats&)>
120+
stats_callback) {
124121
if (!is_loaded()) {
125122
stats_.model_load_start_ms = time_in_ms();
126123
ET_CHECK_OK_OR_RETURN_ERROR(load());
127124
stats_.model_load_end_ms = time_in_ms();
128125
}
129-
ET_CHECK_MSG(inputs.size() == 1, "The input size of whisper should be one.");
130-
131126
ET_LOG(Info, "Start Encoding");
132127
stats_.encoder_inference_start_ms = time_in_ms();
133128
auto input_features_tensor_ptr = from_blob(
134-
inputs[0].data(),
129+
audio.data.data(),
135130
// (1, processor.feature_extractor.feature_size,
136131
// processor.feature_extractor.nb_max_frames)
137-
{1, 80, 3000},
132+
{audio.batch_size, audio.n_bins, audio.n_frames}, // {1, 80, 3000}
138133
ScalarType::Float);
139134
Result<Tensor> encoder_out = encoder_->encode(input_features_tensor_ptr);
140135
auto encoder_out_tensor_ptr = make_tensor_ptr(encoder_out.get());

examples/qualcomm/oss_scripts/whisper/runner/runner.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/decoder.h>
1515
#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/encoder.h>
16-
#include <executorch/extension/llm/runner/asr_runner.h>
16+
#include <executorch/extension/audio/runner/asr_runner.h>
17+
#include <executorch/extension/llm/runner/audio.h>
18+
#include <executorch/extension/llm/runner/stats.h>
1719
#include <executorch/extension/llm/sampler/sampler.h>
1820
#include <executorch/runtime/core/error.h>
1921
#include <pytorch/tokenizers/tokenizer.h>
@@ -25,7 +27,7 @@
2527

2628
namespace example {
2729

28-
class WhisperRunner : public executorch::extension::llm::ASRRunner {
30+
class WhisperRunner : public executorch::extension::audio::ASRRunner {
2931
public:
3032
explicit WhisperRunner(
3133
const std::string& model_path,
@@ -52,8 +54,10 @@ class WhisperRunner : public executorch::extension::llm::ASRRunner {
5254
executorch::runtime::Error load();
5355
executorch::runtime::Error transcribe(
5456
int32_t seq_len,
55-
std::vector<std::vector<char>>& inputs,
56-
std::function<void(const std::string&)> token_callback = {});
57+
executorch::extension::llm::Audio& audio,
58+
std::function<void(const std::string&)> token_callback = {},
59+
std::function<void(const executorch::extension::llm::Stats&)>
60+
stats_callback = {});
5761

5862
private:
5963
executorch::runtime::Error print_performance();

extension/android/CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ endif()
169169
if(EXECUTORCH_BUILD_EXTENSION_LLM)
170170
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
171171
list(APPEND link_libraries extension_llm_runner)
172-
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1)
172+
target_compile_definitions(
173+
executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_LLM=1
174+
)
173175

174176
if(QNN_SDK_ROOT)
175177
target_sources(
@@ -222,8 +224,10 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
222224
endif()
223225

224226
if(EXECUTORCH_BUILD_EXTENSION_AUDIO)
225-
target_sources(executorch_jni PRIVATE jni/jni_layer_asr.cpp jni/log.cpp)
226-
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1)
227+
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
228+
target_compile_definitions(
229+
executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_AUDIO=1
230+
)
227231

228232
if(QNN_SDK_ROOT)
229233
target_sources(

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRCallback.java

Lines changed: 0 additions & 31 deletions
This file was deleted.

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/audio/ASRModule.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.facebook.jni.annotations.DoNotStrip;
1212
import java.io.File;
1313
import org.pytorch.executorch.ExecuTorchRuntime;
14+
import org.pytorch.executorch.extension.llm.LlmCallback;
1415
import org.pytorch.executorch.annotations.Experimental;
1516

1617
/**
@@ -50,8 +51,9 @@ public void resetNative() {
5051
public native int transcribe(
5152
int seqLen,
5253
byte[][] inputs,
53-
ASRCallback callback);
54-
54+
LlmCallback callback,
55+
int n_bins,
56+
int n_frames);
5557

5658
/** Force loading the module. Otherwise the model is loaded during first generate(). */
5759
@DoNotStrip
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
/** Extension for ASR related use cases for ExecuTorch Android Java/JNI package. */
1+
/** Extension for audio and ASR related use cases for ExecuTorch Android Java/JNI package. */
22
package org.pytorch.executorch.extension.audio;

extension/android/jni/BUCK

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -121,35 +121,6 @@ non_fbcode_target(_kind = fb_android_cxx_library,
121121
],
122122
)
123123

124-
non_fbcode_target(_kind = fb_android_cxx_library,
125-
name = "executorch_asr_jni",
126-
srcs = [
127-
"jni_layer.cpp",
128-
"jni_layer_asr.cpp",
129-
"jni_layer_runtime.cpp",
130-
],
131-
allow_jni_merging = False,
132-
compiler_flags = ET_JNI_COMPILER_FLAGS + [
133-
"-DEXECUTORCH_BUILD_EXTENSION_AUDIO",
134-
],
135-
soname = "libexecutorch.$(ext)",
136-
visibility = ["PUBLIC"],
137-
deps = [
138-
":jni_headers",
139-
":log_provider_static",
140-
"//fbandroid/libraries/fbjni:fbjni",
141-
"//fbandroid/native/fb:fb",
142-
"//third-party/glog:glog",
143-
"//xplat/executorch/backends/xnnpack:xnnpack_backend_static",
144-
"//xplat/executorch/examples/oss_scripts/qualcomm/whisper/runner:runner_static",
145-
"//xplat/executorch/extension/module:module_static",
146-
"//xplat/executorch/extension/runner_util:inputs_static",
147-
"//xplat/executorch/extension/tensor:tensor_static",
148-
"//xplat/executorch/extension/threadpool:cpuinfo_utils_static",
149-
"//xplat/executorch/extension/threadpool:threadpool_static",
150-
],
151-
)
152-
153124
non_fbcode_target(_kind = runtime.cxx_library,
154125
name = "log_provider",
155126
srcs = ["log.cpp"],

0 commit comments

Comments
 (0)