44#include < map>
55#include < stdexcept>
66
7- typedef enum _color_space_e
8- {
9- axdl_color_space_unknown,
10- axdl_color_space_nv12,
11- axdl_color_space_nv21,
12- axdl_color_space_bgr,
13- axdl_color_space_rgb,
14- } ax_color_space_e;
15-
16- typedef struct _image_t
17- {
18- unsigned long long int pPhy;
19- void *pVir;
20- unsigned int nSize;
21- unsigned int nWidth;
22- unsigned int nHeight;
23- ax_color_space_e eDtype;
24- union
25- {
26- int tStride_H, tStride_W, tStride_C;
27- };
28- } ax_image_t ;
29-
307typedef struct
318{
329 std::string sName ;
@@ -43,115 +20,93 @@ class ax_runner_base
4320 std::vector<ax_runner_tensor_t > moutput_tensors;
4421 std::vector<ax_runner_tensor_t > minput_tensors;
4522
23+ // Group tensors
4624 std::vector<std::vector<ax_runner_tensor_t >> mgroup_output_tensors;
4725 std::vector<std::vector<ax_runner_tensor_t >> mgroup_input_tensors;
4826
27+ // Lookup maps
4928 std::map<std::string, ax_runner_tensor_t > map_output_tensors;
5029 std::map<std::string, ax_runner_tensor_t > map_input_tensors;
5130
5231 std::map<std::string, std::vector<ax_runner_tensor_t >> map_group_output_tensors;
5332 std::map<std::string, std::vector<ax_runner_tensor_t >> map_group_input_tensors;
5433
34+ // 辅助函数:初始化完成后构建映射表,提高后续查找速度
35+ void build_tensor_maps () {
36+ map_input_tensors.clear ();
37+ for (const auto & t : minput_tensors) map_input_tensors[t.sName ] = t;
38+
39+ map_output_tensors.clear ();
40+ for (const auto & t : moutput_tensors) map_output_tensors[t.sName ] = t;
41+
42+ map_group_input_tensors.clear ();
43+ for (const auto & grp : mgroup_input_tensors) {
44+ for (const auto & t : grp) map_group_input_tensors[t.sName ].push_back (t);
45+ }
46+
47+ map_group_output_tensors.clear ();
48+ for (const auto & grp : mgroup_output_tensors) {
49+ for (const auto & t : grp) map_group_output_tensors[t.sName ].push_back (t);
50+ }
51+ }
52+
5553public:
54+ virtual ~ax_runner_base () {} // 【重要】添加虚析构函数
55+
5656 virtual int init (const char *model_file, bool use_mmap = false ) = 0;
5757 virtual int init (char *model_buffer, size_t model_size) = 0;
58-
5958 virtual void deinit () = 0;
6059
6160 int get_num_inputs () { return minput_tensors.size (); };
6261 int get_num_outputs () { return moutput_tensors.size (); };
63-
6462 int get_num_input_groups () { return mgroup_input_tensors.size (); };
6563 int get_num_output_groups () { return mgroup_output_tensors.size (); };
6664
6765 const ax_runner_tensor_t &get_input (int idx) { return minput_tensors[idx]; }
6866 const ax_runner_tensor_t *get_inputs_ptr () { return minput_tensors.data (); }
69- const ax_runner_tensor_t &get_input (std::string name)
67+
68+ const ax_runner_tensor_t &get_input (const std::string& name)
7069 {
71- if (map_input_tensors.size () == 0 )
72- {
73- for (size_t i = 0 ; i < minput_tensors.size (); i++)
74- {
75- map_input_tensors[minput_tensors[i].sName ] = minput_tensors[i];
76- }
77- }
78- if (map_input_tensors.find (name) == map_input_tensors.end ())
79- {
80- throw std::runtime_error (" input tensor not found: " + name);
81- }
82-
83- return map_input_tensors[name];
70+ auto it = map_input_tensors.find (name);
71+ if (it == map_input_tensors.end ()) throw std::runtime_error (" input tensor not found: " + name);
72+ return it->second ;
8473 }
8574
8675 const ax_runner_tensor_t &get_input (int grpid, int idx) { return mgroup_input_tensors[grpid][idx]; }
8776 const ax_runner_tensor_t *get_inputs_ptr (int grpid) { return mgroup_input_tensors[grpid].data (); }
88- const ax_runner_tensor_t &get_input (int grpid, std::string name)
77+
78+ const ax_runner_tensor_t &get_input (int grpid, const std::string& name)
8979 {
90- if (map_group_input_tensors.size () == 0 )
91- {
92- for (size_t i = 0 ; i < mgroup_input_tensors.size (); i++)
93- {
94- for (size_t j = 0 ; j < mgroup_input_tensors[i].size (); j++)
95- {
96- map_group_input_tensors[mgroup_input_tensors[i][j].sName ].push_back (mgroup_input_tensors[i][j]);
97- }
98- }
99- }
100- if (map_group_input_tensors.find (name) == map_group_input_tensors.end ())
101- {
102- throw std::runtime_error (" input tensor not found: " + name);
103- }
104- return map_group_input_tensors[name][grpid];
105- // return map_input_tensors[name];
80+ auto it = map_group_input_tensors.find (name);
81+ if (it == map_group_input_tensors.end ()) throw std::runtime_error (" input tensor not found: " + name);
82+ // 简单的越界检查
83+ if (grpid < 0 || grpid >= (int )it->second .size ()) throw std::runtime_error (" group id out of range for: " + name);
84+ return it->second [grpid];
10685 }
10786
10887 const ax_runner_tensor_t &get_output (int idx) { return moutput_tensors[idx]; }
10988 const ax_runner_tensor_t *get_outputs_ptr () { return moutput_tensors.data (); }
110- const ax_runner_tensor_t &get_output (std::string name)
89+
90+ const ax_runner_tensor_t &get_output (const std::string& name)
11191 {
112- if (map_output_tensors.size () == 0 )
113- {
114- for (size_t i = 0 ; i < moutput_tensors.size (); i++)
115- {
116- map_output_tensors[moutput_tensors[i].sName ] = moutput_tensors[i];
117- }
118- }
119- if (map_output_tensors.find (name) == map_output_tensors.end ())
120- {
121- throw std::runtime_error (" output tensor not found: " + name);
122- }
123-
124- return map_output_tensors[name];
92+ auto it = map_output_tensors.find (name);
93+ if (it == map_output_tensors.end ()) throw std::runtime_error (" output tensor not found: " + name);
94+ return it->second ;
12595 }
12696
12797 const ax_runner_tensor_t &get_output (int grpid, int idx) { return mgroup_output_tensors[grpid][idx]; }
12898 const ax_runner_tensor_t *get_outputs_ptr (int grpid) { return mgroup_output_tensors[grpid].data (); }
129- const ax_runner_tensor_t &get_output (int grpid, std::string name)
99+
100+ const ax_runner_tensor_t &get_output (int grpid, const std::string& name)
130101 {
131- if (map_group_output_tensors.size () == 0 )
132- {
133- for (size_t i = 0 ; i < mgroup_output_tensors.size (); i++)
134- {
135- for (size_t j = 0 ; j < mgroup_output_tensors[i].size (); j++)
136- {
137- map_group_output_tensors[mgroup_output_tensors[i][j].sName ].push_back (mgroup_output_tensors[i][j]);
138- }
139- }
140- }
141- if (map_group_output_tensors.find (name) == map_group_output_tensors.end ())
142- {
143- throw std::runtime_error (" input tensor not found: " + name);
144- }
145- return map_group_output_tensors[name][grpid];
102+ auto it = map_group_output_tensors.find (name);
103+ if (it == map_group_output_tensors.end ()) throw std::runtime_error (" output tensor not found: " + name);
104+ if (grpid < 0 || grpid >= (int )it->second .size ()) throw std::runtime_error (" group id out of range for: " + name);
105+ return it->second [grpid];
146106 }
147107
148108 virtual int inference () = 0;
149109 virtual int inference (int grpid) = 0;
150110
151- int operator ()()
152- {
153- return inference ();
154- }
155- };
156-
157- // int ax_cmmcpy(unsigned long long int dst, unsigned long long int src, int size);
111+ int operator ()() { return inference (); }
112+ };
0 commit comments