@@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
4242 Init (buf.data (), buf.size ());
4343 }
4444
45- int32_t ret = RKNN_SUCC;
46- switch (config_.num_threads ) {
47- case 1 :
48- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_AUTO);
49- break ;
50- case 0 :
51- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_0);
52- break ;
53- case -1 :
54- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_1);
55- break ;
56- case -2 :
57- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_2);
58- break ;
59- case -3 :
60- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_0_1);
61- break ;
62- case -4 :
63- ret = rknn_set_core_mask (ctx_, RKNN_NPU_CORE_0_1_2);
64- break ;
65- default :
66- SHERPA_ONNX_LOGE (
67- " Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
68- " 1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d" ,
69- config_.num_threads );
70- break ;
71- }
72- if (ret != RKNN_SUCC) {
73- SHERPA_ONNX_LOGE (
74- " Failed to select npu core to run the model (You can ignore it if "
75- " you "
76- " are not using RK3588." );
45+ SetCoreMask (ctx_, config_.num_threads );
46+ }
47+
48+ template <typename Manager>
49+ Impl (Manager *mgr, const OnlineModelConfig &config) : config_(config) {
50+ {
51+ auto buf = ReadFile (mgr, config.zipformer2_ctc .model );
52+ Init (buf.data (), buf.size ());
7753 }
54+
55+ SetCoreMask (ctx_, config_.num_threads );
7856 }
7957
8058 // TODO(fangjun): Support Android
@@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {
209187
210188 private:
211189 void Init (void *model_data, size_t model_data_length) {
212- auto ret = rknn_init (&ctx_, model_data, model_data_length, 0 , nullptr );
213- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to init model '%s'" ,
214- config_.zipformer2_ctc .model .c_str ());
215-
216- if (config_.debug ) {
217- rknn_sdk_version v;
218- ret = rknn_query (ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof (v));
219- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to get rknn sdk version" );
220-
221- SHERPA_ONNX_LOGE (" sdk api version: %s, driver version: %s" , v.api_version ,
222- v.drv_version );
223- }
224-
225- rknn_input_output_num io_num;
226- ret = rknn_query (ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof (io_num));
227- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to get I/O information for the model" );
228-
229- if (config_.debug ) {
230- SHERPA_ONNX_LOGE (" model: %d inputs, %d outputs" ,
231- static_cast <int32_t >(io_num.n_input ),
232- static_cast <int32_t >(io_num.n_output ));
233- }
234-
235- input_attrs_.resize (io_num.n_input );
236- output_attrs_.resize (io_num.n_output );
237-
238- int32_t i = 0 ;
239- for (auto &attr : input_attrs_) {
240- memset (&attr, 0 , sizeof (attr));
241- attr.index = i;
242- ret = rknn_query (ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof (attr));
243- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to get attr for model input %d" , i);
244- i += 1 ;
245- }
246-
247- if (config_.debug ) {
248- std::ostringstream os;
249- std::string sep;
250- for (auto &attr : input_attrs_) {
251- os << sep << ToString (attr);
252- sep = " \n " ;
253- }
254- SHERPA_ONNX_LOGE (" \n ----------Model inputs info----------\n %s" ,
255- os.str ().c_str ());
256- }
257-
258- i = 0 ;
259- for (auto &attr : output_attrs_) {
260- memset (&attr, 0 , sizeof (attr));
261- attr.index = i;
262- ret = rknn_query (ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof (attr));
263- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to get attr for model output %d" , i);
264- i += 1 ;
265- }
190+ InitContext (model_data, model_data_length, config_.debug , &ctx_);
266191
267- if (config_.debug ) {
268- std::ostringstream os;
269- std::string sep;
270- for (auto &attr : output_attrs_) {
271- os << sep << ToString (attr);
272- sep = " \n " ;
273- }
274- SHERPA_ONNX_LOGE (" \n ----------Model outputs info----------\n %s" ,
275- os.str ().c_str ());
276- }
192+ InitInputOutputAttrs (ctx_, config_.debug , &input_attrs_, &output_attrs_);
277193
278- rknn_custom_string custom_string;
279- ret = rknn_query (ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
280- sizeof (custom_string));
281- SHERPA_ONNX_RKNN_CHECK (ret, " Failed to read custom string from the model" );
282- if (config_.debug ) {
283- SHERPA_ONNX_LOGE (" customs string: %s" , custom_string.string );
284- }
285- auto meta = Parse (custom_string);
194+ rknn_custom_string custom_string = GetCustomString (ctx_, config_.debug );
286195
287- if (config_.debug ) {
288- for (const auto &p : meta) {
289- SHERPA_ONNX_LOGE (" %s: %s" , p.first .c_str (), p.second .c_str ());
290- }
291- }
196+ auto meta = Parse (custom_string, config_.debug );
292197
293198 if (meta.count (" T" )) {
294199 T_ = atoi (meta.at (" T" ).c_str ());
0 commit comments