@@ -58,9 +58,35 @@ class ax_runner_base
 
     int dev_id = 0;
 
+    // Helper: build the name-to-tensor lookup maps once after initialization to speed up later lookups by name
+    void build_tensor_maps()
+    {
+        map_input_tensors.clear();
+        for (const auto &t : minput_tensors)
+            map_input_tensors[t.sName] = t;
+
+        map_output_tensors.clear();
+        for (const auto &t : moutput_tensors)
+            map_output_tensors[t.sName] = t;
+
+        map_group_input_tensors.clear();
+        for (const auto &grp : mgroup_input_tensors)
+        {
+            for (const auto &t : grp)
+                map_group_input_tensors[t.sName].push_back(t);
+        }
+
+        map_group_output_tensors.clear();
+        for (const auto &grp : mgroup_output_tensors)
+        {
+            for (const auto &t : grp)
+                map_group_output_tensors[t.sName].push_back(t);
+        }
+    }
+
 public:
     virtual int init(const char *model_file, int devid) = 0;
-    virtual int init(char *model_buffer, size_t model_size) = 0;
+    virtual int init(char *model_buffer, size_t model_size, int devid) = 0;
 
     virtual void deinit() = 0;
 
@@ -74,83 +100,51 @@ class ax_runner_base
 
     const ax_runner_tensor_t &get_input(int idx) { return minput_tensors[idx]; }
     const ax_runner_tensor_t *get_inputs_ptr() { return minput_tensors.data(); }
-    const ax_runner_tensor_t &get_input(std::string name)
+
+    const ax_runner_tensor_t &get_input(const std::string &name)
     {
-        if (map_input_tensors.size() == 0)
-        {
-            for (size_t i = 0; i < minput_tensors.size(); i++)
-            {
-                map_input_tensors[minput_tensors[i].sName] = minput_tensors[i];
-            }
-        }
-        if (map_input_tensors.find(name) == map_input_tensors.end())
-        {
+        auto it = map_input_tensors.find(name);
+        if (it == map_input_tensors.end())
             throw std::runtime_error("input tensor not found: " + name);
-        }
-
-        return map_input_tensors[name];
+        return it->second;
     }
 
     const ax_runner_tensor_t &get_input(int grpid, int idx) { return mgroup_input_tensors[grpid][idx]; }
     const ax_runner_tensor_t *get_inputs_ptr(int grpid) { return mgroup_input_tensors[grpid].data(); }
-    const ax_runner_tensor_t &get_input(int grpid, std::string name)
+
+    const ax_runner_tensor_t &get_input(int grpid, const std::string &name)
     {
-        if (map_group_input_tensors.size() == 0)
-        {
-            for (size_t i = 0; i < mgroup_input_tensors.size(); i++)
-            {
-                for (size_t j = 0; j < mgroup_input_tensors[i].size(); j++)
-                {
-                    map_group_input_tensors[mgroup_input_tensors[i][j].sName].push_back(mgroup_input_tensors[i][j]);
-                }
-            }
-        }
-        if (map_group_input_tensors.find(name) == map_group_input_tensors.end())
-        {
+        auto it = map_group_input_tensors.find(name);
+        if (it == map_group_input_tensors.end())
             throw std::runtime_error("input tensor not found: " + name);
-        }
-        return map_group_input_tensors[name][grpid];
-        // return map_input_tensors[name];
+        // simple bounds check on the group id
+        if (grpid < 0 || grpid >= (int)it->second.size())
+            throw std::runtime_error("group id out of range for: " + name);
+        return it->second[grpid];
     }
 
     const ax_runner_tensor_t &get_output(int idx) { return moutput_tensors[idx]; }
     const ax_runner_tensor_t *get_outputs_ptr() { return moutput_tensors.data(); }
-    const ax_runner_tensor_t &get_output(std::string name)
+
+    const ax_runner_tensor_t &get_output(const std::string &name)
     {
-        if (map_output_tensors.size() == 0)
-        {
-            for (size_t i = 0; i < moutput_tensors.size(); i++)
-            {
-                map_output_tensors[moutput_tensors[i].sName] = moutput_tensors[i];
-            }
-        }
-        if (map_output_tensors.find(name) == map_output_tensors.end())
-        {
+        auto it = map_output_tensors.find(name);
+        if (it == map_output_tensors.end())
             throw std::runtime_error("output tensor not found: " + name);
-        }
-
-        return map_output_tensors[name];
+        return it->second;
     }
 
     const ax_runner_tensor_t &get_output(int grpid, int idx) { return mgroup_output_tensors[grpid][idx]; }
     const ax_runner_tensor_t *get_outputs_ptr(int grpid) { return mgroup_output_tensors[grpid].data(); }
-    const ax_runner_tensor_t &get_output(int grpid, std::string name)
+
+    const ax_runner_tensor_t &get_output(int grpid, const std::string &name)
     {
-        if (map_group_output_tensors.size() == 0)
-        {
-            for (size_t i = 0; i < mgroup_output_tensors.size(); i++)
-            {
-                for (size_t j = 0; j < mgroup_output_tensors[i].size(); j++)
-                {
-                    map_group_output_tensors[mgroup_output_tensors[i][j].sName].push_back(mgroup_output_tensors[i][j]);
-                }
-            }
-        }
-        if (map_group_output_tensors.find(name) == map_group_output_tensors.end())
-        {
-            throw std::runtime_error("input tensor not found: " + name);
-        }
-        return map_group_output_tensors[name][grpid];
+        auto it = map_group_output_tensors.find(name);
+        if (it == map_group_output_tensors.end())
+            throw std::runtime_error("output tensor not found: " + name);
+        if (grpid < 0 || grpid >= (int)it->second.size())
+            throw std::runtime_error("group id out of range for: " + name);
+        return it->second[grpid];
     }
 
     virtual int get_algo_width() = 0;
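
The refactor above replaces the lazy map-filling inside each getter with a single `build_tensor_maps()` pass and name lookups that simply throw on a miss. The sketch below is a minimal, self-contained illustration of that pattern; the `tensor_t` struct, the free functions, and `main()` are illustrative stand-ins rather than the real `ax_runner_tensor_t` / `ax_runner_base` API, and it assumes `build_tensor_maps()` is intended to be called once after the tensor vectors are filled (e.g. at the end of a derived runner's `init()`).

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-in for ax_runner_tensor_t: only the name field matters here.
struct tensor_t
{
    std::string sName;
    size_t nSize;
};

static std::vector<tensor_t> inputs;                           // like minput_tensors
static std::map<std::string, tensor_t> input_map;              // like map_input_tensors
static std::vector<std::vector<tensor_t>> group_inputs;        // like mgroup_input_tensors
static std::map<std::string, std::vector<tensor_t>> group_map; // like map_group_input_tensors

// Build the lookup maps once, after the vectors are filled (mirrors build_tensor_maps()).
static void build_maps()
{
    input_map.clear();
    for (const auto &t : inputs)
        input_map[t.sName] = t;

    group_map.clear();
    for (const auto &grp : group_inputs)
        for (const auto &t : grp)
            group_map[t.sName].push_back(t);
}

// Name lookup that throws on a miss, like get_input(const std::string &).
static const tensor_t &get_input(const std::string &name)
{
    auto it = input_map.find(name);
    if (it == input_map.end())
        throw std::runtime_error("input tensor not found: " + name);
    return it->second;
}

// Grouped lookup with the added bounds check on the group id.
static const tensor_t &get_input(int grpid, const std::string &name)
{
    auto it = group_map.find(name);
    if (it == group_map.end())
        throw std::runtime_error("input tensor not found: " + name);
    if (grpid < 0 || grpid >= (int)it->second.size())
        throw std::runtime_error("group id out of range for: " + name);
    return it->second[grpid];
}

int main()
{
    inputs = {{"images", 3 * 640 * 640}};
    group_inputs = {{{"images", 3 * 640 * 640}}, {{"images", 3 * 320 * 320}}};
    build_maps(); // build once, then look up by name as often as needed

    std::cout << get_input("images").nSize << "\n";    // flat lookup
    std::cout << get_input(1, "images").nSize << "\n"; // group 1 variant

    try
    {
        get_input("missing");
    }
    catch (const std::runtime_error &e)
    {
        std::cout << e.what() << "\n"; // input tensor not found: missing
    }
    return 0;
}
```

Building the maps once up front avoids the first-call branching the old getters carried, and returning `it->second` avoids the second map traversal that `operator[]` performed after the `find()` check.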