33#include < algorithm>
44#include < sstream>
55
6+ #include " base_convertor.h"
67#include " mmdeploy/core/device.h"
78#include " mmdeploy/core/model.h"
89#include " mmdeploy/core/registry.h"
@@ -18,84 +19,9 @@ namespace mmdeploy::mmocr {
1819using std::string;
1920using std::vector;
2021
21- class AttnConvertor : public MMOCR {
22+ class AttnConvertor : public BaseConvertor {
2223 public:
23- explicit AttnConvertor (const Value& cfg) : MMOCR(cfg) {
24- auto model = cfg[" context" ][" model" ].get <Model>();
25- if (!cfg.contains (" params" )) {
26- MMDEPLOY_ERROR (" 'params' is required, but it's not in the config" );
27- throw_exception (eInvalidArgument);
28- }
29- // BaseConverter
30- auto & _cfg = cfg[" params" ];
31- if (_cfg.contains (" dict_file" )) {
32- auto filename = _cfg[" dict_file" ].get <std::string>();
33- auto content = model.ReadFile (filename).value ();
34- idx2char_ = SplitLines (content);
35- } else if (_cfg.contains (" dict_list" )) {
36- from_value (_cfg[" dict_list" ], idx2char_);
37- } else if (_cfg.contains (" dict_type" )) {
38- auto dict_type = _cfg[" dict_type" ].get <std::string>();
39- if (dict_type == " DICT36" ) {
40- idx2char_ = SplitChars (DICT36);
41- } else if (dict_type == " DICT90" ) {
42- idx2char_ = SplitChars (DICT90);
43- } else {
44- MMDEPLOY_ERROR (" unknown dict_type: {}" , dict_type);
45- throw_exception (eInvalidArgument);
46- }
47- } else {
48- MMDEPLOY_ERROR (" either dict_file, dict_list or dict_type must be specified" );
49- throw_exception (eInvalidArgument);
50- }
51- // Update Dictionary
52-
53- bool with_start = _cfg.value (" with_start" , false );
54- bool with_end = _cfg.value (" with_end" , false );
55- bool same_start_end = _cfg.value (" same_start_end" , false );
56- bool with_padding = _cfg.value (" with_padding" , false );
57- bool with_unknown = _cfg.value (" with_unknown" , false );
58- if (with_start && with_end && same_start_end) {
59- idx2char_.emplace_back (" <BOS/EOS>" );
60- start_idx_ = static_cast <int >(idx2char_.size ()) - 1 ;
61- end_idx_ = start_idx_;
62- } else {
63- if (with_start) {
64- idx2char_.emplace_back (" <BOS>" );
65- start_idx_ = static_cast <int >(idx2char_.size ()) - 1 ;
66- }
67- if (with_end) {
68- idx2char_.emplace_back (" <EOS>" );
69- end_idx_ = static_cast <int >(idx2char_.size ()) - 1 ;
70- }
71- }
72-
73- if (with_padding) {
74- idx2char_.emplace_back (" <PAD>" );
75- padding_idx_ = static_cast <int >(idx2char_.size ()) - 1 ;
76- }
77- if (with_unknown) {
78- idx2char_.emplace_back (" <UKN>" );
79- unknown_idx_ = static_cast <int >(idx2char_.size ()) - 1 ;
80- }
81-
82- vector<string> ignore_chars;
83- if (cfg.contains (" ignore_chars" )) {
84- for (int i = 0 ; i < cfg[" ignore_chars" ].size (); i++)
85- ignore_chars.emplace_back (cfg[" ignore_chars" ][i].get <string>());
86- } else {
87- ignore_chars.emplace_back (" padding" );
88- }
89- std::map<string, int > mapping_table = {
90- {" padding" , padding_idx_}, {" end" , end_idx_}, {" unknown" , unknown_idx_}};
91- for (int i = 0 ; i < ignore_chars.size (); i++) {
92- if (mapping_table.find (ignore_chars[i]) != mapping_table.end ()) {
93- ignore_indexes_.emplace_back (mapping_table.at (ignore_chars[i]));
94- }
95- }
96-
97- model_ = model;
98- }
24+ explicit AttnConvertor (const Value& cfg) : BaseConvertor(cfg) {}
9925
10026 Result<Value> operator ()(const Value& _data, const Value& _prob) {
10127 auto d_conf = _prob[" output" ].get <Tensor>();
@@ -115,11 +41,7 @@ class AttnConvertor : public MMOCR {
11541 auto w = static_cast <int >(shape[1 ]);
11642 auto c = static_cast <int >(shape[2 ]);
11743
118- float valid_ratio = 1 ;
119- if (_data[" img_metas" ].contains (" valid_ratio" )) {
120- valid_ratio = _data[" img_metas" ][" valid_ratio" ].get <float >();
121- }
122- auto [indexes, scores] = Tensor2Idx (data, w, c, valid_ratio);
44+ auto [indexes, scores] = Tensor2Idx (data, w, c);
12345
12446 auto text = Idx2Str (indexes);
12547 MMDEPLOY_DEBUG (" text: {}" , text);
@@ -129,8 +51,7 @@ class AttnConvertor : public MMOCR {
12951 return make_pointer (to_value (output));
13052 }
13153
132- std::pair<vector<int >, vector<float > > Tensor2Idx (const float * data, int w, int c,
133- float valid_ratio) {
54+ std::pair<vector<int >, vector<float > > Tensor2Idx (const float * data, int w, int c) {
13455 auto decode_len = w;
13556 vector<int > indexes;
13657 indexes.reserve (decode_len);
@@ -149,57 +70,6 @@ class AttnConvertor : public MMOCR {
14970 }
15071 return {indexes, scores};
15172 }
152-
153- string Idx2Str (const vector<int >& indexes) {
154- size_t count = 0 ;
155- for (const auto & idx : indexes) {
156- count += idx2char_[idx].size ();
157- }
158- std::string text;
159- text.reserve (count);
160- for (const auto & idx : indexes) {
161- text += idx2char_[idx];
162- }
163- return text;
164- }
165-
166- protected:
167- static vector<string> SplitLines (const string& s) {
168- std::istringstream is (s);
169- vector<string> ret;
170- string line;
171- while (std::getline (is, line)) {
172- ret.push_back (std::move (line));
173- }
174- return ret;
175- }
176-
177- static vector<string> SplitChars (const string& s) {
178- vector<string> ret;
179- ret.reserve (s.size ());
180- for (char c : s) {
181- ret.push_back ({c});
182- }
183- return ret;
184- }
185-
186- static constexpr const auto DICT36 = R"( 0123456789abcdefghijklmnopqrstuvwxyz)" ;
187- static constexpr const auto DICT90 = R"( 0123456789abcdefghijklmnopqrstuvwxyz)"
188- R"( ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
189- R"( *+,-./:;<=>?@[\]_`~)" ;
190-
191- static constexpr const auto kHost = Device(0 );
192-
193- Model model_;
194-
195- static constexpr const int blank_idx_{0 };
196- int padding_idx_{-1 };
197- int end_idx_{-1 };
198- int start_idx_{-1 };
199- int unknown_idx_{-1 };
200-
201- vector<int > ignore_indexes_;
202- vector<string> idx2char_;
20373};
20474
20575MMDEPLOY_REGISTER_CODEBASE_COMPONENT (MMOCR, AttnConvertor);
0 commit comments