Skip to content

Commit e9fa5e3

Browse files
Qualcomm AI Engine Direct - GA model enablement (T5) (#12234)
### Summary - e2e script / test case for GA [T5](https://huggingface.co/google-t5/t5-small) model - perf: 16a8w avg encoding time: 4.09ms/inf, avg decoding time: 6ms/inf (SM8750) - acc: F1 Score ~= 76% in [SQuAD](https://www.kaggle.com/datasets/akashdesarda/squad-v11?select=SQuAD-v1.1.csv) - add QA dataset for Seq2SeqLM benchmarking ### Test plan ``` python -m examples.qualcomm.oss_scripts.t5.t5 -b build-android -m ${soc} -H ${host_id} -s ${device_id} -d ./SQuAD-v1.1.csv ``` cc: @haowhsu-quic,@cccclai
1 parent 0012ffa commit e9fa5e3

File tree

14 files changed

+1921
-8
lines changed

14 files changed

+1921
-8
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5022,6 +5022,40 @@ def test_swin_transformer(self):
50225022
self.assertGreaterEqual(msg["top_1"], 60)
50235023
self.assertGreaterEqual(msg["top_5"], 80)
50245024

5025+
def test_t5(self):
5026+
if not self.required_envs([self.qa_dataset]):
5027+
self.skipTest("missing required envs")
5028+
cmds = [
5029+
"python",
5030+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
5031+
"--dataset",
5032+
self.sentence_dataset,
5033+
"--artifact",
5034+
self.artifact_dir,
5035+
"--build_folder",
5036+
self.build_folder,
5037+
"--device",
5038+
self.device,
5039+
"--model",
5040+
self.model,
5041+
"--ip",
5042+
self.ip,
5043+
"--port",
5044+
str(self.port),
5045+
]
5046+
if self.host:
5047+
cmds.extend(["--host", self.host])
5048+
5049+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
5050+
with Listener((self.ip, self.port)) as listener:
5051+
conn = listener.accept()
5052+
p.communicate()
5053+
msg = json.loads(conn.recv())
5054+
if "Error" in msg:
5055+
self.fail(msg["Error"])
5056+
else:
5057+
self.assertGreaterEqual(msg["f1"], 0.7)
5058+
50255059
def test_whisper(self):
50265060
if not self.required_envs():
50275061
self.skipTest("missing required envs")

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
183183
executorch_root: str = ""
184184
artifact_dir: str = ""
185185
image_dataset: str = ""
186+
qa_dataset: str = ""
186187
sentence_dataset: str = ""
187188
pretrained_weight: str = ""
188189
enable_profile: bool = False

examples/qualcomm/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)
9090
# build qnn_mimi_decoder_runner
9191
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/moshi)
9292

93+
# build qnn_t5_runner for t5
94+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/t5)
95+
9396
# build qnn_whisper_runner for whisper
9497
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/whisper)
9598

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
# preprocess qnn runner src files for t5
9+
set(_qnn_t5_runner__srcs
10+
${CMAKE_CURRENT_LIST_DIR}/qnn_t5_runner.cpp
11+
${CMAKE_CURRENT_LIST_DIR}/runner/decoder.cpp
12+
${CMAKE_CURRENT_LIST_DIR}/runner/decoder.h
13+
${CMAKE_CURRENT_LIST_DIR}/runner/encoder.cpp
14+
${CMAKE_CURRENT_LIST_DIR}/runner/encoder.h
15+
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
16+
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
17+
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
18+
)
19+
20+
# build qnn t5 runner
21+
add_executable(qnn_t5_runner ${_qnn_t5_runner__srcs})
22+
target_include_directories(
23+
qnn_t5_runner PUBLIC ${_common_include_directories}
24+
${EXECUTORCH_ROOT}/extension/llm/tokenizers/include
25+
)
26+
27+
28+
target_link_libraries(
29+
qnn_t5_runner
30+
qnn_executorch_backend
31+
executorch_core
32+
extension_data_loader
33+
extension_flat_tensor
34+
extension_module
35+
extension_tensor
36+
gflags
37+
tokenizers
38+
)
39+
40+
target_compile_options(
41+
qnn_t5_runner PUBLIC ${_common_compile_options}
42+
)
43+
set_target_properties(
44+
qnn_t5_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
45+
)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright (c) Qualcomm Innovation Center, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/**
10+
* @file
11+
*
12+
* This tool can run t5 with Qualcomm AI Engine Direct.
13+
*
14+
*/
15+
16+
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
17+
#include <executorch/examples/qualcomm/oss_scripts/t5/runner/runner.h>
18+
#include <executorch/runtime/platform/log.h>
19+
#include <gflags/gflags.h>
20+
#include <fstream>
21+
#include <vector>
22+
23+
DEFINE_string(
24+
model_path,
25+
"t5_qnn.pte",
26+
"t5 model serialized in flatbuffer format.");
27+
28+
DEFINE_string(
29+
tokenizer_model_path,
30+
"tokenizer.model",
31+
"The tokenizer is saved from T5Tokenize.save_pretrained for tokenizer.");
32+
DEFINE_string(
33+
input_list_path,
34+
"input_list.txt",
35+
"Input list storing file name of encoded results.");
36+
DEFINE_int32(
37+
seq_len,
38+
128,
39+
"Maximum sequence length for the generated output. Defaults to use the model's `max_cache_size` attribute. Will be truncated to maximal cache size if larger than `max_cache_size`.");
40+
41+
DEFINE_string(
42+
output_folder_path,
43+
"outputs",
44+
"Executorch inference data output path.");
45+
46+
std::vector<std::vector<std::vector<int64_t>>> parse_input_list_file(
47+
const std::string& input_list_path) {
48+
std::vector<std::vector<std::vector<int64_t>>> bufs;
49+
std::ifstream input_list(input_list_path);
50+
51+
auto split = [](std::string s, std::string delimiter) {
52+
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
53+
std::string token;
54+
std::vector<std::string> res;
55+
56+
while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
57+
token = s.substr(pos_start, pos_end - pos_start);
58+
pos_start = pos_end + delim_len;
59+
res.push_back(token);
60+
}
61+
res.push_back(s.substr(pos_start));
62+
return res;
63+
};
64+
65+
if (!input_list.is_open()) {
66+
ET_LOG(Error, "Unable to open file");
67+
return bufs;
68+
}
69+
70+
std::string file_path;
71+
while (std::getline(input_list, file_path)) {
72+
auto input_files = split(file_path, " ");
73+
int num_inputs = input_files.size();
74+
if (num_inputs == 0) {
75+
break;
76+
}
77+
78+
bufs.emplace_back();
79+
bufs.back().resize(num_inputs);
80+
for (int input_index = 0; input_index < num_inputs; ++input_index) {
81+
std::ifstream fin(input_files[input_index], std::ios::binary);
82+
if (!fin.is_open()) {
83+
ET_LOG(
84+
Error, "Could not open file %s", input_files[input_index].c_str());
85+
continue;
86+
}
87+
88+
fin.seekg(0, std::ios::end);
89+
size_t file_size = fin.tellg();
90+
fin.seekg(0, std::ios::beg);
91+
92+
size_t num_tokens = file_size / sizeof(int64_t);
93+
bufs.back()[input_index].resize(num_tokens);
94+
95+
if (!fin.read(
96+
reinterpret_cast<char*>(bufs.back()[input_index].data()),
97+
file_size)) {
98+
ET_LOG(
99+
Error, "Could not read file %s", input_files[input_index].c_str());
100+
continue;
101+
}
102+
103+
fin.close();
104+
}
105+
}
106+
107+
input_list.close();
108+
return bufs;
109+
}
110+
111+
int main(int argc, char** argv) {
112+
gflags::ParseCommandLineFlags(&argc, &argv, true);
113+
114+
std::vector<std::vector<std::vector<int64_t>>> multi_turns_input_buffers =
115+
parse_input_list_file(FLAGS_input_list_path);
116+
117+
for (int iter = 0; iter < multi_turns_input_buffers.size(); ++iter) {
118+
std::vector<char> bufs;
119+
bufs.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
120+
auto callback = [&](const std::string& piece) {
121+
for (const char c : piece) {
122+
bufs.push_back(c);
123+
}
124+
};
125+
126+
example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_model_path);
127+
// generate tokens
128+
runner.generate(FLAGS_seq_len, multi_turns_input_buffers[iter], callback);
129+
auto output_file_name =
130+
FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt";
131+
std::ofstream fout(output_file_name);
132+
fout.write(bufs.data(), bufs.size());
133+
fout.close();
134+
}
135+
136+
return 0;
137+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) Qualcomm Innovation Center, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/examples/qualcomm/oss_scripts/t5/runner/decoder.h>
10+
11+
using executorch::aten::Tensor;
12+
using executorch::extension::Module;
13+
using executorch::extension::TensorPtr;
14+
using executorch::runtime::Error;
15+
using executorch::runtime::Result;
16+
17+
namespace example {
18+
T5Decoder::T5Decoder(const std::string& model_path) {
19+
module_ = std::make_unique<Module>(
20+
model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
21+
ET_LOG(Info, "creating decoder module: model_path=%s", model_path.c_str());
22+
}
23+
24+
bool T5Decoder::is_method_loaded() const {
25+
return module_->is_method_loaded(kDecoderForwardName);
26+
}
27+
28+
Error T5Decoder::load() {
29+
if (is_method_loaded()) {
30+
return Error::Ok;
31+
}
32+
return module_->load_method(kDecoderForwardName);
33+
}
34+
Result<Tensor> T5Decoder::step(
35+
TensorPtr& input_ids,
36+
TensorPtr& attention_mask,
37+
TensorPtr& encoder_hidden_states,
38+
TensorPtr& encoder_attention_mask,
39+
TensorPtr& cache_position) {
40+
auto outputs_res = module_->execute(
41+
kDecoderForwardName,
42+
{input_ids,
43+
attention_mask,
44+
encoder_hidden_states,
45+
encoder_attention_mask,
46+
cache_position});
47+
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
48+
ET_CHECK_MSG(
49+
outputs_res.get().size() == 1,
50+
"More then one output returned from executing decoder.");
51+
ET_CHECK_MSG(
52+
outputs_res.get()[0].isTensor(),
53+
"Non Tensor Output returned from executing decoder");
54+
55+
// Return the logits tensor
56+
return outputs_res.get()[0].toTensor();
57+
}
58+
} // namespace example
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) Qualcomm Innovation Center, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <executorch/extension/module/module.h>
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/extension/tensor/tensor_ptr.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/evalue.h>
15+
#include <memory>
16+
#include <string>
17+
#include <unordered_set>
18+
#include <vector>
19+
20+
namespace example {
21+
22+
class T5Decoder {
23+
public:
24+
explicit T5Decoder(const std::string& model_path);
25+
26+
bool is_method_loaded() const;
27+
executorch::runtime::Error load();
28+
executorch::runtime::Result<executorch::aten::Tensor> step(
29+
executorch::extension::TensorPtr& input_ids,
30+
executorch::extension::TensorPtr& attention_mask,
31+
executorch::extension::TensorPtr& encoder_hidden_states,
32+
executorch::extension::TensorPtr& encoder_attention_mask,
33+
executorch::extension::TensorPtr& cache_position);
34+
executorch::runtime::Result<std::unordered_set<std::string>> method_names() {
35+
return module_->method_names();
36+
}
37+
executorch::runtime::Result<executorch::runtime::EValue> get(
38+
const std::string& method_name) {
39+
return module_->get(method_name);
40+
}
41+
42+
executorch::runtime::Result<std::vector<executorch::runtime::EValue>> execute(
43+
const std::string& method_name) {
44+
return module_->execute(method_name);
45+
}
46+
47+
private:
48+
std::unique_ptr<executorch::extension::Module> module_;
49+
static constexpr const char* kDecoderForwardName = "decoder";
50+
};
51+
52+
} // namespace example

0 commit comments

Comments
 (0)