Skip to content

Commit 3b62b21

Browse files
authored
add attention convertor (#2064)
1 parent 1c7749d commit 3b62b21

File tree

2 files changed

+212
-2
lines changed

2 files changed

+212
-2
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <algorithm>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "mmdeploy/core/device.h"
#include "mmdeploy/core/model.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/core/value.h"
#include "mmdeploy/experimental/module_adapter.h"
#include "mmocr.h"

namespace mmdeploy::mmocr {

using std::string;
using std::vector;

// Post-processor that decodes the logits produced by attention-based text
// recognition heads (mmocr's `AttentionPostprocessor`) into text strings.
// Registered as the "AttnConvertor" MMOCR codebase component.
class AttnConvertor : public MMOCR {
 public:
  /// Builds the index->character table and resolves the special-token
  /// indices from the deploy config.
  ///
  /// Expected config layout:
  ///   cfg["params"]       - dictionary spec: exactly one of `dict_file`,
  ///                         `dict_list` or `dict_type` (DICT36/DICT90), plus
  ///                         optional flags `with_start`, `with_end`,
  ///                         `same_start_end`, `with_padding`, `with_unknown`.
  ///   cfg["ignore_chars"] - optional list of symbolic names ("padding",
  ///                         "end", "unknown") to drop from the decoded text.
  ///
  /// Throws eInvalidArgument when `params` is missing or the dictionary
  /// cannot be resolved.
  explicit AttnConvertor(const Value& cfg) : MMOCR(cfg) {
    auto model = cfg["context"]["model"].get<Model>();
    if (!cfg.contains("params")) {
      MMDEPLOY_ERROR("'params' is required, but it's not in the config");
      throw_exception(eInvalidArgument);
    }
    // BaseConverter: resolve the character dictionary, in priority order
    // dict_file > dict_list > dict_type.
    auto& _cfg = cfg["params"];
    if (_cfg.contains("dict_file")) {
      auto filename = _cfg["dict_file"].get<std::string>();
      auto content = model.ReadFile(filename).value();
      idx2char_ = SplitLines(content);
    } else if (_cfg.contains("dict_list")) {
      from_value(_cfg["dict_list"], idx2char_);
    } else if (_cfg.contains("dict_type")) {
      auto dict_type = _cfg["dict_type"].get<std::string>();
      if (dict_type == "DICT36") {
        idx2char_ = SplitChars(DICT36);
      } else if (dict_type == "DICT90") {
        idx2char_ = SplitChars(DICT90);
      } else {
        MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
        throw_exception(eInvalidArgument);
      }
    } else {
      MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
      throw_exception(eInvalidArgument);
    }

    // Update Dictionary: append special tokens in the same order mmocr's
    // Dictionary uses - <BOS>/<EOS> (possibly shared), then <PAD>, then <UKN> -
    // so the indices line up with the trained model's class axis.
    bool with_start = _cfg.value("with_start", false);
    bool with_end = _cfg.value("with_end", false);
    bool same_start_end = _cfg.value("same_start_end", false);
    bool with_padding = _cfg.value("with_padding", false);
    bool with_unknown = _cfg.value("with_unknown", false);
    if (with_start && with_end && same_start_end) {
      idx2char_.emplace_back("<BOS/EOS>");
      start_idx_ = static_cast<int>(idx2char_.size()) - 1;
      end_idx_ = start_idx_;
    } else {
      if (with_start) {
        idx2char_.emplace_back("<BOS>");
        start_idx_ = static_cast<int>(idx2char_.size()) - 1;
      }
      if (with_end) {
        idx2char_.emplace_back("<EOS>");
        end_idx_ = static_cast<int>(idx2char_.size()) - 1;
      }
    }

    if (with_padding) {
      idx2char_.emplace_back("<PAD>");
      padding_idx_ = static_cast<int>(idx2char_.size()) - 1;
    }
    if (with_unknown) {
      idx2char_.emplace_back("<UKN>");
      unknown_idx_ = static_cast<int>(idx2char_.size()) - 1;
    }

    // NOTE(review): `ignore_chars` is read from the top-level cfg, not from
    // "params" like the flags above - this matches how the deploy config
    // writer places it on the postprocessor dict; confirm against the config
    // generator before moving it.
    vector<string> ignore_chars;
    if (cfg.contains("ignore_chars")) {
      for (size_t i = 0; i < cfg["ignore_chars"].size(); ++i) {
        ignore_chars.emplace_back(cfg["ignore_chars"][i].get<string>());
      }
    } else {
      // Default: drop the padding token from the decoded output.
      ignore_chars.emplace_back("padding");
    }
    // Map symbolic names to the indices resolved above. A name whose token
    // was not enabled maps to -1, which can never match a real class index.
    const std::map<string, int> mapping_table{
        {"padding", padding_idx_}, {"end", end_idx_}, {"unknown", unknown_idx_}};
    for (const auto& name : ignore_chars) {
      auto it = mapping_table.find(name);  // single lookup instead of find + at
      if (it != mapping_table.end()) {
        ignore_indexes_.push_back(it->second);
      }
    }

    model_ = std::move(model);
  }

  /// Decodes one sample. `_prob["output"]` must be a float tensor of shape
  /// [batch, seq_len, num_classes]; `_data["img_metas"]["valid_ratio"]` is
  /// forwarded to Tensor2Idx when present. Returns a TextRecognition value
  /// or eNotSupported for an unexpected tensor layout.
  Result<Value> operator()(const Value& _data, const Value& _prob) {
    auto d_conf = _prob["output"].get<Tensor>();

    if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
      MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
                     (int)d_conf.data_type());
      return Status(eNotSupported);
    }

    // Copy the scores to host memory and wait for the copy to complete
    // before touching the data pointer.
    OUTCOME_TRY(auto h_conf, MakeAvailableOnDevice(d_conf, kHost, stream()));
    OUTCOME_TRY(stream().Wait());

    auto data = h_conf.data<float>();

    auto shape = d_conf.shape();
    auto w = static_cast<int>(shape[1]);  // decode steps (sequence length)
    auto c = static_cast<int>(shape[2]);  // number of classes

    float valid_ratio = 1;
    if (_data["img_metas"].contains("valid_ratio")) {
      valid_ratio = _data["img_metas"]["valid_ratio"].get<float>();
    }
    auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio);

    auto text = Idx2Str(indexes);
    MMDEPLOY_DEBUG("text: {}", text);

    TextRecognition output{text, scores};

    return make_pointer(to_value(output));
  }

  /// Greedy argmax decoding of a [w, c] score matrix: stops at the first
  /// <EOS> and drops indices listed in `ignore_indexes_`. Returns the kept
  /// class indices and their max-probability scores.
  /// NOTE(review): `valid_ratio` is currently unused (attention decoders do
  /// not truncate by image width here); kept for interface parity - confirm
  /// whether width truncation is intended before removing it.
  std::pair<vector<int>, vector<float>> Tensor2Idx(const float* data, int w, int c,
                                                   float valid_ratio) {
    (void)valid_ratio;
    auto decode_len = w;
    vector<int> indexes;
    indexes.reserve(decode_len);
    vector<float> scores;
    scores.reserve(decode_len);
    for (int t = 0; t < decode_len; ++t, data += c) {
      // argmax over the class axis directly on the raw row - no per-step
      // vector copy/allocation.
      auto iter = std::max_element(data, data + c);
      auto index = static_cast<int>(iter - data);
      if (index == end_idx_) break;
      if (std::find(ignore_indexes_.begin(), ignore_indexes_.end(), index) ==
          ignore_indexes_.end()) {
        indexes.push_back(index);
        scores.push_back(*iter);
      }
    }
    return {indexes, scores};
  }

  /// Concatenates the dictionary entries for `indexes` into one string.
  /// Indices must be valid positions in `idx2char_`.
  string Idx2Str(const vector<int>& indexes) {
    // Pre-size the output to avoid repeated reallocation.
    size_t count = 0;
    for (const auto& idx : indexes) {
      count += idx2char_[idx].size();
    }
    std::string text;
    text.reserve(count);
    for (const auto& idx : indexes) {
      text += idx2char_[idx];
    }
    return text;
  }

 protected:
  // Splits `s` into one entry per line. Tolerates CRLF line endings by
  // stripping a trailing '\r', so Windows-edited dict files don't corrupt
  // every character entry.
  static vector<string> SplitLines(const string& s) {
    std::istringstream is(s);
    vector<string> ret;
    string line;
    while (std::getline(is, line)) {
      if (!line.empty() && line.back() == '\r') {
        line.pop_back();
      }
      ret.push_back(std::move(line));
    }
    return ret;
  }

  // Splits `s` into one single-character entry per byte.
  static vector<string> SplitChars(const string& s) {
    vector<string> ret;
    ret.reserve(s.size());
    for (char c : s) {
      ret.push_back({c});
    }
    return ret;
  }

  // Built-in dictionaries matching mmocr's DICT36/DICT90 presets.
  static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
  static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
                                       R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
                                       R"(*+,-./:;<=>?@[\]_`~)";

  // Host device used as the target of MakeAvailableOnDevice.
  static constexpr const auto kHost = Device(0);

  Model model_;

  // Special-token indices; -1 means the token is not present in idx2char_.
  // blank_idx_ is unused by attention decoding (kept for parity with the
  // CTC convertor).
  static constexpr const int blank_idx_{0};
  int padding_idx_{-1};
  int end_idx_{-1};
  int start_idx_{-1};
  int unknown_idx_{-1};

  vector<int> ignore_indexes_;   // class indices dropped from the output
  vector<string> idx2char_;      // index -> character/token table
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor);

}  // namespace mmdeploy::mmocr

mmdeploy/codebase/mmocr/deploy/text_recognition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from copy import deepcopy
23
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
34

45
import mmengine
@@ -260,11 +261,13 @@ def get_postprocess(self,
260261
postprocess = self.model_cfg.model.decoder.postprocessor
261262
if postprocess.type == 'CTCPostProcessor':
262263
postprocess.type = 'CTCConvertor'
264+
if postprocess.type == 'AttentionPostprocessor':
265+
postprocess.type = 'AttnConvertor'
263266
import shutil
264267
shutil.copy(self.model_cfg.dictionary.dict_file,
265268
f'{work_dir}/dict_file.txt')
266-
with_padding = self.model_cfg.dictionary.get('with_padding', False)
267-
params = dict(dict_file='dict_file.txt', with_padding=with_padding)
269+
params = deepcopy(self.model_cfg.dictionary)
270+
params.update(dict(dict_file='dict_file.txt'))
268271
postprocess['params'] = params
269272
return postprocess
270273

0 commit comments

Comments
 (0)