Skip to content

Commit 3e999db

Browse files
authored
[Fix] Fix text recog task postprocess (#2209)
* update ocr_recog postprocess export
* update ocr_recog sdk postprocess
* fix read unknown_token
1 parent ebd6b75 commit 3e999db

File tree

5 files changed

+216
-222
lines changed

5 files changed

+216
-222
lines changed

csrc/mmdeploy/codebase/mmocr/attention_convertor.cpp

Lines changed: 5 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <algorithm>
44
#include <sstream>
55

6+
#include "base_convertor.h"
67
#include "mmdeploy/core/device.h"
78
#include "mmdeploy/core/model.h"
89
#include "mmdeploy/core/registry.h"
@@ -18,84 +19,9 @@ namespace mmdeploy::mmocr {
1819
using std::string;
1920
using std::vector;
2021

21-
class AttnConvertor : public MMOCR {
22+
class AttnConvertor : public BaseConvertor {
2223
public:
23-
explicit AttnConvertor(const Value& cfg) : MMOCR(cfg) {
24-
auto model = cfg["context"]["model"].get<Model>();
25-
if (!cfg.contains("params")) {
26-
MMDEPLOY_ERROR("'params' is required, but it's not in the config");
27-
throw_exception(eInvalidArgument);
28-
}
29-
// BaseConverter
30-
auto& _cfg = cfg["params"];
31-
if (_cfg.contains("dict_file")) {
32-
auto filename = _cfg["dict_file"].get<std::string>();
33-
auto content = model.ReadFile(filename).value();
34-
idx2char_ = SplitLines(content);
35-
} else if (_cfg.contains("dict_list")) {
36-
from_value(_cfg["dict_list"], idx2char_);
37-
} else if (_cfg.contains("dict_type")) {
38-
auto dict_type = _cfg["dict_type"].get<std::string>();
39-
if (dict_type == "DICT36") {
40-
idx2char_ = SplitChars(DICT36);
41-
} else if (dict_type == "DICT90") {
42-
idx2char_ = SplitChars(DICT90);
43-
} else {
44-
MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
45-
throw_exception(eInvalidArgument);
46-
}
47-
} else {
48-
MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
49-
throw_exception(eInvalidArgument);
50-
}
51-
// Update Dictionary
52-
53-
bool with_start = _cfg.value("with_start", false);
54-
bool with_end = _cfg.value("with_end", false);
55-
bool same_start_end = _cfg.value("same_start_end", false);
56-
bool with_padding = _cfg.value("with_padding", false);
57-
bool with_unknown = _cfg.value("with_unknown", false);
58-
if (with_start && with_end && same_start_end) {
59-
idx2char_.emplace_back("<BOS/EOS>");
60-
start_idx_ = static_cast<int>(idx2char_.size()) - 1;
61-
end_idx_ = start_idx_;
62-
} else {
63-
if (with_start) {
64-
idx2char_.emplace_back("<BOS>");
65-
start_idx_ = static_cast<int>(idx2char_.size()) - 1;
66-
}
67-
if (with_end) {
68-
idx2char_.emplace_back("<EOS>");
69-
end_idx_ = static_cast<int>(idx2char_.size()) - 1;
70-
}
71-
}
72-
73-
if (with_padding) {
74-
idx2char_.emplace_back("<PAD>");
75-
padding_idx_ = static_cast<int>(idx2char_.size()) - 1;
76-
}
77-
if (with_unknown) {
78-
idx2char_.emplace_back("<UKN>");
79-
unknown_idx_ = static_cast<int>(idx2char_.size()) - 1;
80-
}
81-
82-
vector<string> ignore_chars;
83-
if (cfg.contains("ignore_chars")) {
84-
for (int i = 0; i < cfg["ignore_chars"].size(); i++)
85-
ignore_chars.emplace_back(cfg["ignore_chars"][i].get<string>());
86-
} else {
87-
ignore_chars.emplace_back("padding");
88-
}
89-
std::map<string, int> mapping_table = {
90-
{"padding", padding_idx_}, {"end", end_idx_}, {"unknown", unknown_idx_}};
91-
for (int i = 0; i < ignore_chars.size(); i++) {
92-
if (mapping_table.find(ignore_chars[i]) != mapping_table.end()) {
93-
ignore_indexes_.emplace_back(mapping_table.at(ignore_chars[i]));
94-
}
95-
}
96-
97-
model_ = model;
98-
}
24+
explicit AttnConvertor(const Value& cfg) : BaseConvertor(cfg) {}
9925

10026
Result<Value> operator()(const Value& _data, const Value& _prob) {
10127
auto d_conf = _prob["output"].get<Tensor>();
@@ -115,11 +41,7 @@ class AttnConvertor : public MMOCR {
11541
auto w = static_cast<int>(shape[1]);
11642
auto c = static_cast<int>(shape[2]);
11743

118-
float valid_ratio = 1;
119-
if (_data["img_metas"].contains("valid_ratio")) {
120-
valid_ratio = _data["img_metas"]["valid_ratio"].get<float>();
121-
}
122-
auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio);
44+
auto [indexes, scores] = Tensor2Idx(data, w, c);
12345

12446
auto text = Idx2Str(indexes);
12547
MMDEPLOY_DEBUG("text: {}", text);
@@ -129,8 +51,7 @@ class AttnConvertor : public MMOCR {
12951
return make_pointer(to_value(output));
13052
}
13153

132-
std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c,
133-
float valid_ratio) {
54+
std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c) {
13455
auto decode_len = w;
13556
vector<int> indexes;
13657
indexes.reserve(decode_len);
@@ -149,57 +70,6 @@ class AttnConvertor : public MMOCR {
14970
}
15071
return {indexes, scores};
15172
}
152-
153-
string Idx2Str(const vector<int>& indexes) {
154-
size_t count = 0;
155-
for (const auto& idx : indexes) {
156-
count += idx2char_[idx].size();
157-
}
158-
std::string text;
159-
text.reserve(count);
160-
for (const auto& idx : indexes) {
161-
text += idx2char_[idx];
162-
}
163-
return text;
164-
}
165-
166-
protected:
167-
static vector<string> SplitLines(const string& s) {
168-
std::istringstream is(s);
169-
vector<string> ret;
170-
string line;
171-
while (std::getline(is, line)) {
172-
ret.push_back(std::move(line));
173-
}
174-
return ret;
175-
}
176-
177-
static vector<string> SplitChars(const string& s) {
178-
vector<string> ret;
179-
ret.reserve(s.size());
180-
for (char c : s) {
181-
ret.push_back({c});
182-
}
183-
return ret;
184-
}
185-
186-
static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
187-
static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
188-
R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
189-
R"(*+,-./:;<=>?@[\]_`~)";
190-
191-
static constexpr const auto kHost = Device(0);
192-
193-
Model model_;
194-
195-
static constexpr const int blank_idx_{0};
196-
int padding_idx_{-1};
197-
int end_idx_{-1};
198-
int start_idx_{-1};
199-
int unknown_idx_{-1};
200-
201-
vector<int> ignore_indexes_;
202-
vector<string> idx2char_;
20373
};
20474

20575
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor);
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include "base_convertor.h"
4+
5+
namespace mmdeploy::mmocr {
6+
7+
using std::string;
8+
using std::unordered_map;
9+
using std::unordered_set;
10+
using std::vector;
11+
12+
// Builds the character dictionary and the special-token bookkeeping used by
// the text-recognition convertors.
//
// Reads from `cfg`:
//   - cfg["context"]["model"]: model handle; used to read `dict_file` and then
//     stored in `model_`.
//   - cfg["params"]: required. Holds the dictionary source (`dict_file`,
//     `dict_list` or `dict_type`, checked in that order) plus the `with_*`
//     switches and the optional `*_token` spellings.
//
// Calls throw_exception(eInvalidArgument) when `params` is missing, when no
// dictionary source is present, or when `dict_type` is neither "DICT36" nor
// "DICT90".
BaseConvertor::BaseConvertor(const Value& cfg) : MMOCR(cfg) {
  auto model = cfg["context"]["model"].get<Model>();
  if (!cfg.contains("params")) {
    MMDEPLOY_ERROR("'params' is required, but it's not in the config");
    throw_exception(eInvalidArgument);
  }
  const auto& params = cfg["params"];

  // Base dictionary: the first source that is present wins.
  if (params.contains("dict_file")) {
    const auto filename = params["dict_file"].get<std::string>();
    const auto content = model.ReadFile(filename).value();
    idx2char_ = SplitLines(content);
  } else if (params.contains("dict_list")) {
    from_value(params["dict_list"], idx2char_);
  } else if (params.contains("dict_type")) {
    const auto dict_type = params["dict_type"].get<std::string>();
    if (dict_type == "DICT36") {
      idx2char_ = SplitChars(DICT36);
    } else if (dict_type == "DICT90") {
      idx2char_ = SplitChars(DICT90);
    } else {
      MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
      throw_exception(eInvalidArgument);
    }
  } else {
    MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
    throw_exception(eInvalidArgument);
  }
  // The local handle is not needed past this point, so hand it over instead of
  // copying it a second time.
  model_ = std::move(model);

  // Special tokens are appended after the base dictionary. The append order
  // (start/end, padding, unknown) fixes their indexes, so it must not change.
  const bool with_start = params.value("with_start", false);
  const bool with_end = params.value("with_end", false);
  const bool same_start_end = params.value("same_start_end", false);
  const bool with_padding = params.value("with_padding", false);
  const bool with_unknown = params.value("with_unknown", false);

  if (with_start && with_end && same_start_end) {
    // A single shared token serves as both start and end marker.
    start_idx_ = static_cast<int>(idx2char_.size());
    end_idx_ = start_idx_;
    string start_end_token = params.value("start_end_token", string("<BOS/EOS>"));
    idx2char_.emplace_back(std::move(start_end_token));
  } else {
    if (with_start) {
      start_idx_ = static_cast<int>(idx2char_.size());
      string start_token = params.value("start_token", string("<BOS>"));
      idx2char_.emplace_back(std::move(start_token));
    }
    if (with_end) {
      end_idx_ = static_cast<int>(idx2char_.size());
      string end_token = params.value("end_token", string("<EOS>"));
      idx2char_.emplace_back(std::move(end_token));
    }
  }
  if (with_padding) {
    padding_idx_ = static_cast<int>(idx2char_.size());
    string padding_token = params.value("padding_token", string("<PAD>"));
    idx2char_.emplace_back(std::move(padding_token));
  }
  // The unknown token is only added when `unknown_token` is present and
  // non-null; `with_unknown` alone is not enough.
  if (with_unknown && params.contains("unknown_token") && !params["unknown_token"].is_null()) {
    unknown_idx_ = static_cast<int>(idx2char_.size());
    string unknown_token = params.value("unknown_token", string("<UKN>"));
    idx2char_.emplace_back(std::move(unknown_token));
  }

  // Reverse lookup. If an entry is duplicated, the last occurrence wins.
  for (size_t i = 0; i < idx2char_.size(); ++i) {
    char2idx_[idx2char_[i]] = static_cast<int>(i);
  }

  // NOTE(review): `ignore_chars` is looked up on the top-level `cfg`, whereas
  // every other option above comes from `cfg["params"]` - confirm this is where
  // the exporter writes it. When absent, only "padding" is ignored.
  vector<string> ignore_chars;
  if (cfg.contains("ignore_chars")) {
    const auto& listed = cfg["ignore_chars"];
    const size_t listed_count = listed.size();
    ignore_chars.reserve(listed_count);
    for (size_t i = 0; i < listed_count; ++i) {
      ignore_chars.emplace_back(listed[i].get<string>());
    }
  } else {
    ignore_chars.emplace_back("padding");
  }

  // Symbolic names accepted in `ignore_chars`; a value of -1 means the
  // corresponding token was not configured.
  const std::map<string, int> mapping_table = {
      {"padding", padding_idx_}, {"end", end_idx_}, {"unknown", unknown_idx_}};
  for (const auto& ignore_char : ignore_chars) {
    int ignore_idx = -1;

    // Resolution order: symbolic name, then literal dictionary entry, then the
    // unknown index as a last resort.
    if (auto it_default = mapping_table.find(ignore_char); it_default != mapping_table.end()) {
      ignore_idx = it_default->second;
    } else if (auto it_candidate = char2idx_.find(ignore_char); it_candidate != char2idx_.end()) {
      ignore_idx = it_candidate->second;
    } else if (with_unknown) {
      ignore_idx = unknown_idx_;
    }

    // Anything that resolved to nothing, or resolved to the unknown index
    // without being spelled "unknown", is reported and skipped.
    if (ignore_idx == -1 || (ignore_idx == unknown_idx_ && ignore_char != "unknown")) {
      MMDEPLOY_WARN("{} does not exist in the dictionary", ignore_char);
      continue;
    }
    ignore_indexes_.insert(ignore_idx);
  }
}
110+
111+
// Concatenates the dictionary entries addressed by `indexes` into one string.
//
// Indexes that are negative or not smaller than the dictionary size are
// reported through MMDEPLOY_ERROR and skipped; they contribute nothing to the
// result. (Previously such an index was logged and then still used to read
// `idx2char_`, which is an out-of-range access.)
string BaseConvertor::Idx2Str(const vector<int>& indexes) {
  const size_t dict_size = idx2char_.size();
  const auto in_range = [dict_size](int idx) {
    return idx >= 0 && static_cast<size_t>(idx) < dict_size;
  };

  // First pass: validate and measure, so the output is allocated once.
  size_t count = 0;
  for (const int idx : indexes) {
    if (!in_range(idx)) {
      MMDEPLOY_ERROR("idx exceeds array bounds {} {}", idx, idx2char_.size());
      continue;
    }
    count += idx2char_[idx].size();
  }

  // Second pass: append the valid entries in order.
  std::string text;
  text.reserve(count);
  for (const int idx : indexes) {
    if (in_range(idx)) {
      text += idx2char_[idx];
    }
  }
  return text;
}
126+
127+
// Splits `s` into its lines, one vector element per line, in order.
//
// Lines are delimited by '\n'. A single trailing '\r' is removed from each
// line so that content with CRLF line endings yields the same entries as the
// LF version instead of entries that each end in a carriage return.
vector<string> BaseConvertor::SplitLines(const string& s) {
  std::istringstream is(s);
  vector<string> ret;
  string line;
  while (std::getline(is, line)) {
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    // `line` is reset by the next std::getline call, so moving from it is safe.
    ret.push_back(std::move(line));
  }
  return ret;
}
136+
137+
// Breaks `s` into one single-character string per char, preserving order.
vector<string> BaseConvertor::SplitChars(const string& s) {
  vector<string> pieces;
  pieces.reserve(s.size());
  for (const char ch : s) {
    // Construct the one-character string in place.
    pieces.emplace_back(1, ch);
  }
  return pieces;
}
145+
146+
} // namespace mmdeploy::mmocr
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
3+
#include <algorithm>
4+
#include <sstream>
5+
#include <unordered_map>
6+
#include <unordered_set>
7+
8+
#include "mmdeploy/core/device.h"
9+
#include "mmdeploy/core/model.h"
10+
#include "mmdeploy/core/registry.h"
11+
#include "mmdeploy/core/tensor.h"
12+
#include "mmdeploy/core/utils/device_utils.h"
13+
#include "mmdeploy/core/utils/formatter.h"
14+
#include "mmdeploy/core/value.h"
15+
#include "mmdeploy/experimental/module_adapter.h"
16+
#include "mmocr.h"
17+
18+
namespace mmdeploy::mmocr {
19+
20+
using std::string;
21+
using std::unordered_map;
22+
using std::unordered_set;
23+
using std::vector;
24+
25+
class BaseConvertor : public MMOCR {
26+
public:
27+
explicit BaseConvertor(const Value& cfg);
28+
29+
string Idx2Str(const vector<int>& indexes);
30+
31+
virtual Result<Value> operator()(const Value& _data, const Value& _prob) = 0;
32+
33+
protected:
34+
static vector<string> SplitLines(const string& s);
35+
36+
static vector<string> SplitChars(const string& s);
37+
38+
static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
39+
static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
40+
R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
41+
R"(*+,-./:;<=>?@[\]_`~)";
42+
43+
static constexpr const auto kHost = Device(0);
44+
45+
Model model_;
46+
47+
int padding_idx_{-1};
48+
int end_idx_{-1};
49+
int start_idx_{-1};
50+
int unknown_idx_{-1};
51+
52+
unordered_set<int> ignore_indexes_;
53+
unordered_map<string, int> char2idx_;
54+
vector<string> idx2char_;
55+
56+
}; // class BaseConvertor
57+
58+
} // namespace mmdeploy::mmocr

0 commit comments

Comments
 (0)