Skip to content

Commit a11e359

Browse files
authored
Refactor rknn code (#2079)
1 parent 8e51a97 commit a11e359

File tree

6 files changed

+218
-451
lines changed

6 files changed

+218
-451
lines changed

sherpa-onnx/csrc/online-recognizer-impl.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
9292
template <typename Manager>
9393
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
9494
Manager *mgr, const OnlineRecognizerConfig &config) {
95+
if (config.model_config.provider_config.provider == "rknn") {
96+
#if SHERPA_ONNX_ENABLE_RKNN
97+
// Currently, only zipformer v1 is suported for rknn
98+
if (config.model_config.transducer.encoder.empty() &&
99+
config.model_config.zipformer2_ctc.model.empty()) {
100+
SHERPA_ONNX_LOGE(
101+
"Only Zipformer transducers and CTC models are currently supported "
102+
"by rknn. Fallback to CPU");
103+
} else if (!config.model_config.transducer.encoder.empty()) {
104+
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(mgr, config);
105+
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
106+
return std::make_unique<OnlineRecognizerCtcRknnImpl>(mgr, config);
107+
}
108+
#else
109+
SHERPA_ONNX_LOGE(
110+
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
111+
"want to use rknn. Fallback to CPU");
112+
#endif
113+
}
114+
95115
if (!config.model_config.transducer.encoder.empty()) {
96116
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
97117

sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc

Lines changed: 14 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
4242
Init(buf.data(), buf.size());
4343
}
4444

45-
int32_t ret = RKNN_SUCC;
46-
switch (config_.num_threads) {
47-
case 1:
48-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
49-
break;
50-
case 0:
51-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
52-
break;
53-
case -1:
54-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
55-
break;
56-
case -2:
57-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
58-
break;
59-
case -3:
60-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
61-
break;
62-
case -4:
63-
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
64-
break;
65-
default:
66-
SHERPA_ONNX_LOGE(
67-
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
68-
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
69-
config_.num_threads);
70-
break;
71-
}
72-
if (ret != RKNN_SUCC) {
73-
SHERPA_ONNX_LOGE(
74-
"Failed to select npu core to run the model (You can ignore it if "
75-
"you "
76-
"are not using RK3588.");
45+
SetCoreMask(ctx_, config_.num_threads);
46+
}
47+
48+
template <typename Manager>
49+
Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
50+
{
51+
auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
52+
Init(buf.data(), buf.size());
7753
}
54+
55+
SetCoreMask(ctx_, config_.num_threads);
7856
}
7957

8058
// TODO(fangjun): Support Android
@@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {
209187

210188
private:
211189
void Init(void *model_data, size_t model_data_length) {
212-
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
213-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
214-
config_.zipformer2_ctc.model.c_str());
215-
216-
if (config_.debug) {
217-
rknn_sdk_version v;
218-
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
219-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
220-
221-
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
222-
v.drv_version);
223-
}
224-
225-
rknn_input_output_num io_num;
226-
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
227-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
228-
229-
if (config_.debug) {
230-
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
231-
static_cast<int32_t>(io_num.n_input),
232-
static_cast<int32_t>(io_num.n_output));
233-
}
234-
235-
input_attrs_.resize(io_num.n_input);
236-
output_attrs_.resize(io_num.n_output);
237-
238-
int32_t i = 0;
239-
for (auto &attr : input_attrs_) {
240-
memset(&attr, 0, sizeof(attr));
241-
attr.index = i;
242-
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
243-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
244-
i += 1;
245-
}
246-
247-
if (config_.debug) {
248-
std::ostringstream os;
249-
std::string sep;
250-
for (auto &attr : input_attrs_) {
251-
os << sep << ToString(attr);
252-
sep = "\n";
253-
}
254-
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
255-
os.str().c_str());
256-
}
257-
258-
i = 0;
259-
for (auto &attr : output_attrs_) {
260-
memset(&attr, 0, sizeof(attr));
261-
attr.index = i;
262-
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
263-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
264-
i += 1;
265-
}
190+
InitContext(model_data, model_data_length, config_.debug, &ctx_);
266191

267-
if (config_.debug) {
268-
std::ostringstream os;
269-
std::string sep;
270-
for (auto &attr : output_attrs_) {
271-
os << sep << ToString(attr);
272-
sep = "\n";
273-
}
274-
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
275-
os.str().c_str());
276-
}
192+
InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
277193

278-
rknn_custom_string custom_string;
279-
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
280-
sizeof(custom_string));
281-
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
282-
if (config_.debug) {
283-
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
284-
}
285-
auto meta = Parse(custom_string);
194+
rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
286195

287-
if (config_.debug) {
288-
for (const auto &p : meta) {
289-
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
290-
}
291-
}
196+
auto meta = Parse(custom_string, config_.debug);
292197

293198
if (meta.count("T")) {
294199
T_ = atoi(meta.at("T").c_str());

0 commit comments

Comments
 (0)