@@ -33,7 +33,9 @@ namespace sherpa_onnx {
3333OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel (
3434 const OnlineModelConfig &config)
3535 : env_(ORT_LOGGING_LEVEL_WARNING),
36- sess_opts_ (GetSessionOptions(config)),
36+ encoder_sess_opts_ (GetSessionOptions(config)),
37+ decoder_sess_opts_(GetSessionOptions(config, " decoder" )),
38+ joiner_sess_opts_(GetSessionOptions(config, " joiner" )),
3739 config_(config),
3840 allocator_{} {
3941 {
@@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
5759 AAssetManager *mgr, const OnlineModelConfig &config)
5860 : env_(ORT_LOGGING_LEVEL_WARNING),
5961 config_(config),
60- sess_opts_(GetSessionOptions(config)),
62+ encoder_sess_opts_(GetSessionOptions(config)),
63+ decoder_sess_opts_(GetSessionOptions(config)),
64+ joiner_sess_opts_(GetSessionOptions(config)),
6165 allocator_{} {
6266 {
6367 auto buf = ReadFile (mgr, config.transducer .encoder );
@@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
7983void OnlineZipformer2TransducerModel::InitEncoder (void *model_data,
8084 size_t model_data_length) {
8185 encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
82- model_data_length, sess_opts_ );
86+ model_data_length, encoder_sess_opts_ );
8387
8488 GetInputNames (encoder_sess_.get (), &encoder_input_names_,
8589 &encoder_input_names_ptr_);
@@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
132136void OnlineZipformer2TransducerModel::InitDecoder (void *model_data,
133137 size_t model_data_length) {
134138 decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
135- model_data_length, sess_opts_ );
139+ model_data_length, decoder_sess_opts_ );
136140
137141 GetInputNames (decoder_sess_.get (), &decoder_input_names_,
138142 &decoder_input_names_ptr_);
@@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
157161void OnlineZipformer2TransducerModel::InitJoiner (void *model_data,
158162 size_t model_data_length) {
159163 joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
160- model_data_length, sess_opts_ );
164+ model_data_length, joiner_sess_opts_ );
161165
162166 GetInputNames (joiner_sess_.get (), &joiner_input_names_,
163167 &joiner_input_names_ptr_);
0 commit comments