|
| 1 | +// Copyright (c) OpenMMLab. All rights reserved. |
| 2 | + |
| 3 | +#include <algorithm> |
| 4 | +#include <sstream> |
| 5 | + |
| 6 | +#include "mmdeploy/core/device.h" |
| 7 | +#include "mmdeploy/core/model.h" |
| 8 | +#include "mmdeploy/core/registry.h" |
| 9 | +#include "mmdeploy/core/tensor.h" |
| 10 | +#include "mmdeploy/core/utils/device_utils.h" |
| 11 | +#include "mmdeploy/core/utils/formatter.h" |
| 12 | +#include "mmdeploy/core/value.h" |
| 13 | +#include "mmdeploy/experimental/module_adapter.h" |
| 14 | +#include "mmocr.h" |
| 15 | + |
| 16 | +namespace mmdeploy::mmocr { |
| 17 | + |
| 18 | +using std::string; |
| 19 | +using std::vector; |
| 20 | + |
| 21 | +class AttnConvertor : public MMOCR { |
| 22 | + public: |
| 23 | + explicit AttnConvertor(const Value& cfg) : MMOCR(cfg) { |
| 24 | + auto model = cfg["context"]["model"].get<Model>(); |
| 25 | + if (!cfg.contains("params")) { |
| 26 | + MMDEPLOY_ERROR("'params' is required, but it's not in the config"); |
| 27 | + throw_exception(eInvalidArgument); |
| 28 | + } |
| 29 | + // BaseConverter |
| 30 | + auto& _cfg = cfg["params"]; |
| 31 | + if (_cfg.contains("dict_file")) { |
| 32 | + auto filename = _cfg["dict_file"].get<std::string>(); |
| 33 | + auto content = model.ReadFile(filename).value(); |
| 34 | + idx2char_ = SplitLines(content); |
| 35 | + } else if (_cfg.contains("dict_list")) { |
| 36 | + from_value(_cfg["dict_list"], idx2char_); |
| 37 | + } else if (_cfg.contains("dict_type")) { |
| 38 | + auto dict_type = _cfg["dict_type"].get<std::string>(); |
| 39 | + if (dict_type == "DICT36") { |
| 40 | + idx2char_ = SplitChars(DICT36); |
| 41 | + } else if (dict_type == "DICT90") { |
| 42 | + idx2char_ = SplitChars(DICT90); |
| 43 | + } else { |
| 44 | + MMDEPLOY_ERROR("unknown dict_type: {}", dict_type); |
| 45 | + throw_exception(eInvalidArgument); |
| 46 | + } |
| 47 | + } else { |
| 48 | + MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified"); |
| 49 | + throw_exception(eInvalidArgument); |
| 50 | + } |
| 51 | + // Update Dictionary |
| 52 | + |
| 53 | + bool with_start = _cfg.value("with_start", false); |
| 54 | + bool with_end = _cfg.value("with_end", false); |
| 55 | + bool same_start_end = _cfg.value("same_start_end", false); |
| 56 | + bool with_padding = _cfg.value("with_padding", false); |
| 57 | + bool with_unknown = _cfg.value("with_unknown", false); |
| 58 | + if (with_start && with_end && same_start_end) { |
| 59 | + idx2char_.emplace_back("<BOS/EOS>"); |
| 60 | + start_idx_ = static_cast<int>(idx2char_.size()) - 1; |
| 61 | + end_idx_ = start_idx_; |
| 62 | + } else { |
| 63 | + if (with_start) { |
| 64 | + idx2char_.emplace_back("<BOS>"); |
| 65 | + start_idx_ = static_cast<int>(idx2char_.size()) - 1; |
| 66 | + } |
| 67 | + if (with_end) { |
| 68 | + idx2char_.emplace_back("<EOS>"); |
| 69 | + end_idx_ = static_cast<int>(idx2char_.size()) - 1; |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + if (with_padding) { |
| 74 | + idx2char_.emplace_back("<PAD>"); |
| 75 | + padding_idx_ = static_cast<int>(idx2char_.size()) - 1; |
| 76 | + } |
| 77 | + if (with_unknown) { |
| 78 | + idx2char_.emplace_back("<UKN>"); |
| 79 | + unknown_idx_ = static_cast<int>(idx2char_.size()) - 1; |
| 80 | + } |
| 81 | + |
| 82 | + vector<string> ignore_chars; |
| 83 | + if (cfg.contains("ignore_chars")) { |
| 84 | + for (int i = 0; i < cfg["ignore_chars"].size(); i++) |
| 85 | + ignore_chars.emplace_back(cfg["ignore_chars"][i].get<string>()); |
| 86 | + } else { |
| 87 | + ignore_chars.emplace_back("padding"); |
| 88 | + } |
| 89 | + std::map<string, int> mapping_table = { |
| 90 | + {"padding", padding_idx_}, {"end", end_idx_}, {"unknown", unknown_idx_}}; |
| 91 | + for (int i = 0; i < ignore_chars.size(); i++) { |
| 92 | + if (mapping_table.find(ignore_chars[i]) != mapping_table.end()) { |
| 93 | + ignore_indexes_.emplace_back(mapping_table.at(ignore_chars[i])); |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + model_ = model; |
| 98 | + } |
| 99 | + |
| 100 | + Result<Value> operator()(const Value& _data, const Value& _prob) { |
| 101 | + auto d_conf = _prob["output"].get<Tensor>(); |
| 102 | + |
| 103 | + if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) { |
| 104 | + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(), |
| 105 | + (int)d_conf.data_type()); |
| 106 | + return Status(eNotSupported); |
| 107 | + } |
| 108 | + |
| 109 | + OUTCOME_TRY(auto h_conf, MakeAvailableOnDevice(d_conf, Device{0}, stream())); |
| 110 | + OUTCOME_TRY(stream().Wait()); |
| 111 | + |
| 112 | + auto data = h_conf.data<float>(); |
| 113 | + |
| 114 | + auto shape = d_conf.shape(); |
| 115 | + auto w = static_cast<int>(shape[1]); |
| 116 | + auto c = static_cast<int>(shape[2]); |
| 117 | + |
| 118 | + float valid_ratio = 1; |
| 119 | + if (_data["img_metas"].contains("valid_ratio")) { |
| 120 | + valid_ratio = _data["img_metas"]["valid_ratio"].get<float>(); |
| 121 | + } |
| 122 | + auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio); |
| 123 | + |
| 124 | + auto text = Idx2Str(indexes); |
| 125 | + MMDEPLOY_DEBUG("text: {}", text); |
| 126 | + |
| 127 | + TextRecognition output{text, scores}; |
| 128 | + |
| 129 | + return make_pointer(to_value(output)); |
| 130 | + } |
| 131 | + |
| 132 | + std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c, |
| 133 | + float valid_ratio) { |
| 134 | + auto decode_len = w; |
| 135 | + vector<int> indexes; |
| 136 | + indexes.reserve(decode_len); |
| 137 | + vector<float> scores; |
| 138 | + scores.reserve(decode_len); |
| 139 | + for (int t = 0; t < decode_len; ++t, data += c) { |
| 140 | + vector<float> prob(data, data + c); |
| 141 | + auto iter = max_element(begin(prob), end(prob)); |
| 142 | + auto index = static_cast<int>(iter - begin(prob)); |
| 143 | + if (index == end_idx_) break; |
| 144 | + if (std::find(ignore_indexes_.begin(), ignore_indexes_.end(), index) == |
| 145 | + ignore_indexes_.end()) { |
| 146 | + indexes.push_back(index); |
| 147 | + scores.push_back(*iter); |
| 148 | + } |
| 149 | + } |
| 150 | + return {indexes, scores}; |
| 151 | + } |
| 152 | + |
| 153 | + string Idx2Str(const vector<int>& indexes) { |
| 154 | + size_t count = 0; |
| 155 | + for (const auto& idx : indexes) { |
| 156 | + count += idx2char_[idx].size(); |
| 157 | + } |
| 158 | + std::string text; |
| 159 | + text.reserve(count); |
| 160 | + for (const auto& idx : indexes) { |
| 161 | + text += idx2char_[idx]; |
| 162 | + } |
| 163 | + return text; |
| 164 | + } |
| 165 | + |
| 166 | + protected: |
| 167 | + static vector<string> SplitLines(const string& s) { |
| 168 | + std::istringstream is(s); |
| 169 | + vector<string> ret; |
| 170 | + string line; |
| 171 | + while (std::getline(is, line)) { |
| 172 | + ret.push_back(std::move(line)); |
| 173 | + } |
| 174 | + return ret; |
| 175 | + } |
| 176 | + |
| 177 | + static vector<string> SplitChars(const string& s) { |
| 178 | + vector<string> ret; |
| 179 | + ret.reserve(s.size()); |
| 180 | + for (char c : s) { |
| 181 | + ret.push_back({c}); |
| 182 | + } |
| 183 | + return ret; |
| 184 | + } |
| 185 | + |
| 186 | + static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"; |
| 187 | + static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)" |
| 188 | + R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())" |
| 189 | + R"(*+,-./:;<=>?@[\]_`~)"; |
| 190 | + |
| 191 | + static constexpr const auto kHost = Device(0); |
| 192 | + |
| 193 | + Model model_; |
| 194 | + |
| 195 | + static constexpr const int blank_idx_{0}; |
| 196 | + int padding_idx_{-1}; |
| 197 | + int end_idx_{-1}; |
| 198 | + int start_idx_{-1}; |
| 199 | + int unknown_idx_{-1}; |
| 200 | + |
| 201 | + vector<int> ignore_indexes_; |
| 202 | + vector<string> idx2char_; |
| 203 | +}; |
| 204 | + |
| 205 | +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor); |
| 206 | + |
| 207 | +} // namespace mmdeploy::mmocr |
0 commit comments