From 7063fb9e38ec9bcd5e64ab1b592555dfb3fae5bb Mon Sep 17 00:00:00 2001
From: topduke <784990967@qq.com>
Date: Tue, 18 Feb 2025 09:42:28 +0000
Subject: [PATCH 1/2] add openocr igtr method and fix parseq and ppstruct bug
---
configs/rec/rec_svtrnet_igtr.yml | 149 +++
docs/algorithm/overview.en.md | 2 +
docs/algorithm/overview.md | 1 +
.../text_recognition/algorithm_rec_cppd.en.md | 2 +-
.../text_recognition/algorithm_rec_cppd.md | 2 +-
.../text_recognition/algorithm_rec_igtr.en.md | 139 +++
.../text_recognition/algorithm_rec_igtr.md | 144 +++
ppocr/data/__init__.py | 27 +-
ppocr/data/imaug/label_ops.py | 403 +++++++++
ppocr/data/multi_scale_sampler.py | 183 ++++
ppocr/data/ratio_dataset.py | 228 +++++
ppocr/losses/__init__.py | 2 +
ppocr/losses/rec_igtr_loss.py | 12 +
ppocr/modeling/backbones/__init__.py | 2 +
ppocr/modeling/backbones/rec_svtrnet.py | 5 +-
ppocr/modeling/backbones/rec_svtrnet2dpos.py | 663 ++++++++++++++
ppocr/modeling/heads/__init__.py | 2 +
ppocr/modeling/heads/rec_igtr_head.py | 847 ++++++++++++++++++
ppocr/modeling/heads/rec_parseq_head.py | 2 +-
ppocr/postprocess/__init__.py | 2 +
ppocr/postprocess/rec_postprocess.py | 89 ++
ppocr/utils/save_load.py | 4 +
ppstructure/recovery/table_process.py | 14 +-
tools/program.py | 2 +
24 files changed, 2902 insertions(+), 24 deletions(-)
create mode 100644 configs/rec/rec_svtrnet_igtr.yml
create mode 100644 docs/algorithm/text_recognition/algorithm_rec_igtr.en.md
create mode 100644 docs/algorithm/text_recognition/algorithm_rec_igtr.md
create mode 100644 ppocr/data/ratio_dataset.py
create mode 100644 ppocr/losses/rec_igtr_loss.py
create mode 100644 ppocr/modeling/backbones/rec_svtrnet2dpos.py
create mode 100644 ppocr/modeling/heads/rec_igtr_head.py
diff --git a/configs/rec/rec_svtrnet_igtr.yml b/configs/rec/rec_svtrnet_igtr.yml
new file mode 100644
index 00000000000..2d978b8a09d
--- /dev/null
+++ b/configs/rec/rec_svtrnet_igtr.yml
@@ -0,0 +1,149 @@
+Global:
+ use_gpu: True
+ epoch_num: 20
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec/svtr_igtr_base/
+ save_epoch_step: 1
+ # evaluation is run every 2000 iterations after the 0th iteration
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_type: en
+ character_dict_path: &character_dict_path
+ max_text_length: &max_text_length 25
+ infer_mode: False
+ use_space_char: &use_space_char False
+ save_res_path: ./output/rec/predicts_svtr_igtr_base.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.99
+ epsilon: 1.e-8
+ weight_decay: 0.05
+ no_weight_decay_name: norm pos_embed char_node_embed pos_node_embed char_pos_embed vis_pos_embed
+ one_dim_param_no_weight_decay: True
+ lr:
+ name: Cosine
+ learning_rate: 0.0005 # 4gpus 256bs
+ warmup_epoch: 2
+
+Architecture:
+ model_type: rec
+ algorithm: IGTR
+ Transform:
+ Backbone:
+ name: SVTRNet2DPos
+ img_size: [32, -1]
+ out_char_num: 25
+ out_channels: 256
+ patch_merging: 'Conv'
+ embed_dim: [128, 256, 384]
+ depth: [6, 6, 6]
+ num_heads: [4, 8, 12]
+ mixer: ['ConvB','ConvB','ConvB','ConvB','ConvB','ConvB', 'ConvB','ConvB', 'Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+ local_mixer: [[5, 5], [5, 5], [5, 5]]
+ last_stage: False
+ prenorm: True
+ use_first_sub: False
+ Head:
+ name: IGTRHead
+ dim: 384
+ num_layer: 1
+ ar: False
+ refine_iter: 0
+ next_pred: False
+ pos2d: True
+ ds: True
+
+Loss:
+ name: IGTRLoss
+
+PostProcess:
+ name: IGTRLabelDecode
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ max_ratio: 4
+ data_dir_list: ['./train_data/data_lmdb_release/training/data_name1',
+ './train_data/data_lmdb_release/training/data_name2']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - IGTRLabelEncode: # Class handling label
+ k: 8
+ prompt_error: False
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'prompt_pos_idx_list',
+ 'prompt_char_idx_list', 'ques_pos_idx_list', 'ques1_answer_list',
+ 'ques2_char_idx_list', 'ques2_answer_list', 'ques3_answer', 'ques4_char_num_list',
+ 'ques_len_list', 'ques2_len_list', 'prompt_len_list', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: ensure that the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+    is_training: True
+ loader:
+ shuffle: True
+ batch_size_per_card: 256
+ drop_last: True
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: RatioDataSet
+ ds_width: True
+ padding: False
+ max_ratio: 4
+ data_dir_list: ['./train_data/data_lmdb_release/evaluation/CUTE80',
+ './train_data/data_lmdb_release/evaluation/IC13_857',
+ './train_data/data_lmdb_release/evaluation/IC15_1811',
+ './train_data/data_lmdb_release/evaluation/IIIT5k',
+ './train_data/data_lmdb_release/evaluation/SVT',
+ './train_data/data_lmdb_release/evaluation/SVTP']
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - ARLabelEncode: # Class handling label
+ character_dict_path: *character_dict_path
+ use_space_char: *use_space_char
+ max_text_length: *max_text_length
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+ sampler:
+ name: RatioSampler
+ scales: [[128, 32]] # w, h
+    # divided_factor: ensure that the width and height dimensions can be divided by the downsampling multiple
+ first_bs: 256
+ fix_bs: false
+ divided_factor: [4, 16] # w, h
+ is_training: False
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 4
diff --git a/docs/algorithm/overview.en.md b/docs/algorithm/overview.en.md
index e7551214ce5..06c01461db5 100755
--- a/docs/algorithm/overview.en.md
+++ b/docs/algorithm/overview.en.md
@@ -78,6 +78,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [CPPD](./text_recognition/algorithm_rec_cppd.en.md)
- [x] [SATRN](./text_recognition/algorithm_rec_satrn.en.md)
+
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|Model|Backbone|Avg Accuracy|Module combination|Download link|
@@ -104,6 +105,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
|CPPD|SVTR-Base| 93.8% | rec_svtrnet_cppd_base_en | [trained model](https://paddleocr.bj.bcebos.com/CCPD/rec_svtr_cppd_base_en_train.tar) |
|SATRN|ShallowCNN| 88.05% | rec_satrn | [trained model](https://pan.baidu.com/s/10J-Bsd881bimKaclKszlaQ?pwd=lk8a) |
+|IGTR|SVTR-Base| 94.78% | rec_svtr_igtr | [trained model](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) |
### 1.3 Text Super-Resolution Algorithms
diff --git a/docs/algorithm/overview.md b/docs/algorithm/overview.md
index 3c6cff9b137..47352d89781 100755
--- a/docs/algorithm/overview.md
+++ b/docs/algorithm/overview.md
@@ -105,6 +105,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
|CPPD|SVTR-Base| 93.8% | rec_svtrnet_cppd_base_en | [训练模型](https://paddleocr.bj.bcebos.com/CCPD/rec_svtr_cppd_base_en_train.tar) |
|SATRN|ShallowCNN| 88.05% | rec_satrn | [训练模型](https://pan.baidu.com/s/10J-Bsd881bimKaclKszlaQ?pwd=lk8a) |
+|IGTR|SVTR-Base| 94.78% | rec_svtr_igtr | [训练模型](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) |
### 1.3 文本超分辨率算法
diff --git a/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md b/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md
index 04b19c9b106..6e6712fc41d 100644
--- a/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md
+++ b/docs/algorithm/text_recognition/algorithm_rec_cppd.en.md
@@ -70,7 +70,7 @@ Specifically, after the data preparation is completed, the training can be start
python3 tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
# Multi GPU training, specify the gpu number through the --gpus parameter
-python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
```
### Evaluation
diff --git a/docs/algorithm/text_recognition/algorithm_rec_cppd.md b/docs/algorithm/text_recognition/algorithm_rec_cppd.md
index aa1c077c6b8..71a4794e0c3 100644
--- a/docs/algorithm/text_recognition/algorithm_rec_cppd.md
+++ b/docs/algorithm/text_recognition/algorithm_rec_cppd.md
@@ -86,7 +86,7 @@ CPPD在场景文本识别公开数据集上的精度(%)和模型文件如下:
python3 tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
# 多卡训练,通过--gpus参数指定卡号
-python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_cppd_base_en.yml
```
### 3.2 评估
diff --git a/docs/algorithm/text_recognition/algorithm_rec_igtr.en.md b/docs/algorithm/text_recognition/algorithm_rec_igtr.en.md
new file mode 100644
index 00000000000..6cbd6cbb4a6
--- /dev/null
+++ b/docs/algorithm/text_recognition/algorithm_rec_igtr.en.md
@@ -0,0 +1,139 @@
+---
+comments: true
+---
+
+# IGTR
+
+## 1. Introduction
+
+Paper:
+> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851),
+> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang,
+> TPAMI 2025,
+> Source Repository: [OpenOCR](https://github.com/Topdu/OpenOCR)
+
+Multi-modal models have shown appealing performance in visual recognition tasks, as free-form text-guided training evokes the ability to understand fine-grained visual content. However, current models cannot be trivially applied to scene text recognition (STR) due to the compositional difference between natural and text images. We propose a novel instruction-guided scene text recognition (IGTR) paradigm that formulates STR as an instruction learning problem and understands text images by predicting character attributes, e.g., character frequency, position, etc. IGTR first devises $\left \langle condition,question,answer \right \rangle$ instruction triplets, providing rich and diverse descriptions of character attributes. To effectively learn these attributes through question-answering, IGTR develops a lightweight instruction encoder, a cross-modal feature fusion module and a multi-task answer head, which guides nuanced text image understanding. Furthermore, IGTR realizes different recognition pipelines simply by using different instructions, enabling a character-understanding-based text reasoning paradigm that differs from current methods considerably. Experiments on English and Chinese benchmarks show that IGTR outperforms existing models by significant margins, while maintaining a small model size and fast inference speed. Moreover, by adjusting the sampling of instructions, IGTR offers an elegant way to tackle the recognition of rarely appearing and morphologically similar characters, which were previous challenges.
+
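+The instruction triplets can be pictured with a toy sketch (illustrative only, not the exact implementation): the label encoder added in this patch represents each character occurrence as `[position, char_id, prior_occurrence_count]`, and splitting these entries into a known part (condition) and an asked part (question) yields one instruction.
+
+```python linenums="1"
+# Hypothetical sketch of one condition/question/answer instruction for "coffee".
+label = "coffee"
+chars = {c: i for i, c in enumerate(sorted(set(label)))}  # toy character dict
+# one entry per character: [position, char_id, how often it appeared before]
+qa = [[pos + 1, chars[c], label[:pos].count(c)] for pos, c in enumerate(label)]
+condition = qa[:2]                           # characters already revealed
+question = [pos for pos, _, _ in qa[2:]]     # which characters sit at these positions?
+answer = [char_id for _, char_id, _ in qa[2:]]
+print(condition, question, answer)
+```
+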
+The accuracy (%) and model files of IGTR on public scene text recognition datasets are as follows:
+
+- Trained on synthetic datasets (MJ+ST) and tested on the common benchmarks; the training and test datasets are both from [PARSeq](https://github.com/baudm/parseq).
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :-----: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.6 | 95.2 | 97.6 | 88.4 | 91.6 | 95.5 | 94.30 | TODO |
+| IGTR-AR | 98.6 | 95.7 | 98.2 | 88.4 | 92.4 | 95.5 | 94.78 | as above |
+
+- Tested on Union14M-Benchmark, which is from [Union14M](https://github.com/Mountchicken/Union14M/).
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :-----: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 76.9 | 30.6 | 59.1 | 63.3 | 77.8 | 62.5 | 66.7 | 62.40 | Same as the above table |
+| IGTR-AR | 78.4 | 31.9 | 61.3 | 66.5 | 80.2 | 69.3 | 67.9 | 65.07 | as above |
+
+- Trained on the Union14M-L-LMDB-Filtered training dataset.
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.7 | 97.7 | 98.3 | 89.8 | 93.7 | 97.9 | 95.86 | [PaddleOCR Model](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) |
+| IGTR-AR | 98.1 | 98.4 | 98.7 | 90.5 | 94.9 | 98.3 | 96.48 | as above |
+| IGTR-PD-60ep | 97.9 | 98.3 | 99.2 | 90.8 | 93.7 | 97.6 | 96.24 | TODO|
+| IGTR-AR-60ep | 98.4 | 98.1 | 99.3 | 91.5 | 94.3 | 97.6 | 96.54 | as above |
+| IGTR-PD-PT | 98.6 | 98.0 | 99.1 | 91.7 | 96.8 | 99.0 | 97.20 | TODO |
+| IGTR-AR-PT | 98.8 | 98.3 | 99.2 | 92.0 | 96.8 | 99.0 | 97.34 | as above |
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :----------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 88.1 | 89.9 | 74.2 | 80.3 | 82.8 | 79.2 | 83.0 | 82.51 | Same as the above table |
+| IGTR-AR | 90.4 | 91.2 | 77.0 | 82.4 | 84.7 | 84.0 | 84.4 | 84.86 | as above |
+| IGTR-PD-60ep | 90.0 | 92.1 | 77.5 | 82.8 | 86.0 | 83.0 | 84.8 | 85.18 | Same as the above table |
+| IGTR-AR-60ep | 91.0 | 93.0 | 78.7 | 84.6 | 87.3 | 84.8 | 85.6 | 86.43 | as above |
+| IGTR-PD-PT | 92.4 | 92.1 | 80.7 | 83.6 | 87.7 | 86.9 | 85.0 | 86.92 | Same as the above table |
+| IGTR-AR-PT | 93.0 | 92.9 | 81.3 | 83.4 | 88.6 | 88.7 | 85.6 | 87.65 | as above |
+
+- Trained and tested on the Chinese datasets from the [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition).
+
+| Model | Scene | Web | Document | Handwriting | Avg | Config&Model&Log |
+| :---------: | :---: | :--: | :------: | :---------: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 73.1 | 74.8 | 98.6 | 52.5 | 74.75 | |
+| IGTR-AR | 75.1 | 76.4 | 98.7 | 55.3 | 76.37 | |
+| IGTR-PD-TS | 73.5 | 75.9 | 98.7 | 54.5 | 75.65 | TODO |
+| IGTR-AR-TS | 75.6 | 77.0 | 98.8 | 57.3 | 77.17 | as above |
+| IGTR-PD-Aug | 79.5 | 80.0 | 99.4 | 58.9 | 79.45 | TODO |
+| IGTR-AR-Aug | 82.0 | 81.7 | 99.5 | 63.8 | 81.74 | as above |
+
+Download all configs, models, and logs from [OpenOCR](https://github.com/Topdu/OpenOCR/blob/main/configs/rec/igtr/readme.md), and then convert them to PaddleOCR model files; a rough sketch of the conversion idea is given below.
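+
+The conversion step itself is not shipped in this patch. The snippet below is only a sketch of the idea, assuming a PyTorch checkpoint that holds a flat name-to-tensor state dict whose parameter names already match the Paddle model; the file names and this name-matching assumption are hypothetical, and a real converter must decide per layer which weights to transpose.
+
+```python linenums="1"
+# Hypothetical OpenOCR (PyTorch) -> PaddleOCR parameter conversion sketch.
+import torch
+import paddle
+
+# assumed file name; assumes a flat {parameter_name: tensor} state dict
+torch_state = torch.load("openocr_igtr.pth", map_location="cpu")
+paddle_state = {}
+for name, tensor in torch_state.items():
+    value = tensor.cpu().numpy()
+    # torch Linear weights are [out, in]; paddle Linear expects [in, out].
+    # Embeddings and conv kernels keep their layout, so deciding by name and
+    # shape alone is an approximation, not a complete rule.
+    if name.endswith(".weight") and value.ndim == 2 and "embed" not in name:
+        value = value.transpose()
+    paddle_state[name] = paddle.to_tensor(value)
+paddle.save(paddle_state, "rec_svtr_igtr_train/best_model.pdparams")
+```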
+
+## 2. Environment
+
+Please refer to ["Environment Preparation"](../../ppocr/environment.en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](../../ppocr/blog/clone.en.md) to clone the project code.
+
+### Dataset Preparation
+
+- [English dataset download](https://github.com/baudm/parseq)
+
+- [Union14M-L-LMDB-Filtered download](https://github.com/Topdu/OpenOCR/blob/main/docs/svtrv2.md#downloading-datasets)
+
+- [Chinese dataset download](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](../../ppocr/model_train/recognition.en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+### Training
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```bash linenums="1"
+# Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_svtrnet_igtr.yml
+
+# Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_igtr.yml
+```
+
+### Evaluation
+
+You can download the model and configuration files provided by `IGTR` from the [download link](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) and evaluate them with the following commands:
+
+```bash linenums="1"
+# Download the tar archive containing the model files and configuration files of IGTR-B and extract it
+wget https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar && tar xf rec_svtr_igtr_train.tar
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_svtrnet_igtr.yml -o Global.pretrained_model=./rec_svtr_igtr_train/best_model
+```
+
+### Prediction
+
+```bash linenums="1"
+python3 tools/infer_rec.py -c configs/rec/rec_svtrnet_igtr.yml -o Global.infer_img='./doc/imgs_words/word_10.png' Global.pretrained_model=./rec_svtr_igtr_train/best_model
+```
+
+## 4. Inference and Deployment
+
+### 4.1 Python Inference
+
+Coming soon.
+
+### 4.2 C++ Inference
+
+Not supported
+
+### 4.3 Serving
+
+Not supported
+
+### 4.4 More
+
+Not supported
+
+## Citation
+
+```bibtex
+@article{Du2025IGTR,
+ title = {Instruction-Guided Scene Text Recognition},
+ author = {Du, Yongkun and Chen, Zhineng and Su, Yuchen and Jia, Caiyan and Jiang, Yu-Gang},
+ journal = {IEEE Trans. Pattern Anal. Mach. Intell.},
+ year = {2025},
+ url = {https://arxiv.org/abs/2401.17851}
+}
+```
diff --git a/docs/algorithm/text_recognition/algorithm_rec_igtr.md b/docs/algorithm/text_recognition/algorithm_rec_igtr.md
new file mode 100644
index 00000000000..efa50d38968
--- /dev/null
+++ b/docs/algorithm/text_recognition/algorithm_rec_igtr.md
@@ -0,0 +1,144 @@
+---
+comments: true
+---
+
+# 场景文本识别算法-IGTR
+
+## 1. 算法简介
+
+论文信息:
+> [Instruction-Guided Scene Text Recognition](https://arxiv.org/abs/2401.17851),
+> Yongkun Du, Zhineng Chen, Yuchen Su, Caiyan Jia, Yu-Gang Jiang,
+> TPAMI 2025,
+> 源仓库: [OpenOCR](https://github.com/Topdu/OpenOCR)
+
+### IGTR算法简介
+
+IGTR是由复旦大学[FVL实验室](https://fvl.fudan.edu.cn/) [OCR团队](https://github.com/Topdu/OpenOCR)提出的基于指令学习的场景文本识别(STR)方法。IGTR将STR视为一个跨模态指令学习问题,通过预测字符属性(如频率、位置等)来理解文本图像。具体而言,IGTR提出了以⟨条件,问题,答案⟩三元组格式的指令,提供丰富的字符属性描述;开发了轻量级指令编码器、跨模态特征融合模块和多任务答案头,增强文本图像理解能力。IGTR在英文和中文基准测试中均显著优于现有模型,同时保持较小的模型尺寸和快速推理速度。此外,通过调整指令采样规则,IGTR能够优雅地解决罕见字符和形态相似字符的识别问题。IGTR开创了基于指令学习的STR新范式,为多模态模型在特定任务中的应用提供了重要参考。
+
+IGTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
+- 合成数据集(MJ+ST)训练,在Common Benchmarks测试, 训练集和测试集来自于 [PARSeq](https://github.com/baudm/parseq).
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :-----: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.6 | 95.2 | 97.6 | 88.4 | 91.6 | 95.5 | 94.30 | TODO |
+| IGTR-AR | 98.6 | 95.7 | 98.2 | 88.4 | 92.4 | 95.5 | 94.78 | as above |
+
+- 合成数据集(MJ+ST)训练,在Union14M-Benchmark测试, 测试集来自于 [Union14M](https://github.com/Mountchicken/Union14M/).
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :-----: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 76.9 | 30.6 | 59.1 | 63.3 | 77.8 | 62.5 | 66.7 | 62.40 | Same as the above table |
+| IGTR-AR | 78.4 | 31.9 | 61.3 | 66.5 | 80.2 | 69.3 | 67.9 | 65.07 | as above |
+
+- 在大规模真实数据集Union14M-L-LMDB-Filtered上训练的结果.
+
+| Model | IC13<br/>857 | SVT | IIIT5k<br/>3000 | IC15<br/>1811 | SVTP | CUTE80 | Avg | Config&Model&Log |
+| :----------: | :----------: | :--: | :-------------: | :-----------: | :--: | :----: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 97.7 | 97.7 | 98.3 | 89.8 | 93.7 | 97.9 | 95.86 | [PaddleOCR Model](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) |
+| IGTR-AR | 98.1 | 98.4 | 98.7 | 90.5 | 94.9 | 98.3 | 96.48 | as above |
+| IGTR-PD-60ep | 97.9 | 98.3 | 99.2 | 90.8 | 93.7 | 97.6 | 96.24 | TODO|
+| IGTR-AR-60ep | 98.4 | 98.1 | 99.3 | 91.5 | 94.3 | 97.6 | 96.54 | as above |
+| IGTR-PD-PT | 98.6 | 98.0 | 99.1 | 91.7 | 96.8 | 99.0 | 97.20 | TODO |
+| IGTR-AR-PT | 98.8 | 98.3 | 99.2 | 92.0 | 96.8 | 99.0 | 97.34 | as above |
+
+| Model | Curve | Multi-<br/>Oriented | Artistic | Contextless | Salient | Multi-<br/>word | General | Avg | Config&Model&Log |
+| :----------: | :---: | :-----------------: | :------: | :---------: | :-----: | :-------------: | :-----: | :---: | :---------------------: |
+| IGTR-PD | 88.1 | 89.9 | 74.2 | 80.3 | 82.8 | 79.2 | 83.0 | 82.51 | Same as the above table |
+| IGTR-AR | 90.4 | 91.2 | 77.0 | 82.4 | 84.7 | 84.0 | 84.4 | 84.86 | as above |
+| IGTR-PD-60ep | 90.0 | 92.1 | 77.5 | 82.8 | 86.0 | 83.0 | 84.8 | 85.18 | Same as the above table |
+| IGTR-AR-60ep | 91.0 | 93.0 | 78.7 | 84.6 | 87.3 | 84.8 | 85.6 | 86.43 | as above |
+| IGTR-PD-PT | 92.4 | 92.1 | 80.7 | 83.6 | 87.7 | 86.9 | 85.0 | 86.92 | Same as the above table |
+| IGTR-AR-PT | 93.0 | 92.9 | 81.3 | 83.4 | 88.6 | 88.7 | 85.6 | 87.65 | as above |
+
+- 中文文本识别的结果, 训练集和测试集来自于 [Chinese Benchmark](https://github.com/FudanVI/benchmarking-chinese-text-recognition).
+
+| Model | Scene | Web | Document | Handwriting | Avg | Config&Model&Log |
+| :---------: | :---: | :--: | :------: | :---------: | :---: | :---------------------------------------------------------------------------------------------: |
+| IGTR-PD | 73.1 | 74.8 | 98.6 | 52.5 | 74.75 | |
+| IGTR-AR | 75.1 | 76.4 | 98.7 | 55.3 | 76.37 | |
+| IGTR-PD-TS | 73.5 | 75.9 | 98.7 | 54.5 | 75.65 | TODO |
+| IGTR-AR-TS | 75.6 | 77.0 | 98.8 | 57.3 | 77.17 | as above |
+| IGTR-PD-Aug | 79.5 | 80.0 | 99.4 | 58.9 | 79.45 | TODO |
+| IGTR-AR-Aug | 82.0 | 81.7 | 99.5 | 63.8 | 81.74 | as above |
+
+从[OpenOCR](https://github.com/Topdu/OpenOCR/blob/main/configs/rec/igtr/readme.md)可以下载所有的配置文件、模型文件和训练日志, 将模型文件转换为符合PaddleOCR模型参数要求的格式后, 即可在PaddleOCR中使用.
+
+## 2. 环境配置
+
+请先参考[《运行环境准备》](../../ppocr/environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](../../ppocr/blog/clone.md)克隆项目代码。
+
+## 3. 模型训练、评估、预测
+
+### 3.1 模型训练
+
+#### 数据集准备
+
+[英文数据集下载](https://github.com/baudm/parseq)
+
+[Union14M-L-LMDB-Filtered](https://github.com/Mountchicken/Union14M)
+
+[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
+
+#### 启动训练
+
+请参考[文本识别训练教程](../../ppocr/model_train/recognition.md)。在完成数据准备后,便可以启动训练,训练命令如下:
+
+```bash linenums="1"
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_svtrnet_igtr.yml
+
+# 多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_svtrnet_igtr.yml
+```
+
+### 3.2 评估
+
+可下载`IGTR`提供的模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar) ,使用如下命令进行评估:
+
+```bash linenums="1"
+# 下载包含IGTR的模型文件和配置文件的tar压缩包并解压
+wget https://paddleocr.bj.bcebos.com/igtr/rec_svtr_igtr_train.tar && tar xf rec_svtr_igtr_train.tar
+# 注意将pretrained_model的路径设置为本地路径。
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_svtrnet_igtr.yml -o Global.pretrained_model=./rec_svtr_igtr_train/best_model
+```
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+
+```bash linenums="1"
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c configs/rec/rec_svtrnet_igtr.yml -o Global.infer_img='./doc/imgs_words/word_10.png' Global.pretrained_model=./rec_svtr_igtr_train/best_model
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+## 4. 推理部署
+
+### 4.1 Python推理
+
+即将实现
+
+### 4.2 C++推理部署
+
+暂不支持
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+### 4.4 更多推理部署
+
+暂不支持
+
+## 引用
+
+```bibtex
+@article{Du2025IGTR,
+ title = {Instruction-Guided Scene Text Recognition},
+ author = {Du, Yongkun and Chen, Zhineng and Su, Yuchen and Jia, Caiyan and Jiang, Yu-Gang},
+ journal = {IEEE Trans. Pattern Anal. Mach. Intell.},
+ year = {2025},
+ url = {https://arxiv.org/abs/2401.17851}
+}
+```
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 5678aebec14..73b600a4938 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -37,8 +37,9 @@
from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
-from ppocr.data.multi_scale_sampler import MultiScaleSampler
+from ppocr.data.multi_scale_sampler import MultiScaleSampler, RatioSampler
from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
+from ppocr.data.ratio_dataset import RatioDataSet
# for PaddleX dataset_type
TextDetDataset = SimpleDataSet
@@ -97,6 +98,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
"PubTabTableRecDataset",
"KieDataset",
"LaTeXOCRDataSet",
+ "RatioDataSet",
]
module_name = config[mode]["dataset"]["name"]
assert module_name in support_dict, Exception(
@@ -115,19 +117,18 @@ def build_dataloader(config, mode, device, logger, seed=None):
else:
use_shared_memory = True
- if mode == "Train":
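+    # A custom batch sampler configured under config[mode]["sampler"] (e.g. RatioSampler)
+    # now takes precedence over the default batch samplers in any mode.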
+ if "sampler" in config[mode]:
+ config_sampler = config[mode]["sampler"]
+ sampler_name = config_sampler.pop("name")
+ batch_sampler = eval(sampler_name)(dataset, **config_sampler)
+ elif mode == "Train":
# Distribute data to multiple cards
- if "sampler" in config[mode]:
- config_sampler = config[mode]["sampler"]
- sampler_name = config_sampler.pop("name")
- batch_sampler = eval(sampler_name)(dataset, **config_sampler)
- else:
- batch_sampler = DistributedBatchSampler(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- drop_last=drop_last,
- )
+ batch_sampler = DistributedBatchSampler(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ )
else:
# Distribute data to single card
batch_sampler = BatchSampler(
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index df282df4d6a..ffbd3a6d73c 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -2191,3 +2191,406 @@ def __call__(self, data):
data["label"] = np.array(topk["input_ids"]).astype(np.int64)[0]
data["attention_mask"] = np.array(topk["attention_mask"]).astype(np.int64)[0]
return data
+
+
+class ARLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+    BOS = "<s>"
+    EOS = "</s>"
+    PAD = "<pad>"
+
+ def __init__(
+ self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
+ ):
+ super(ARLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char
+ )
+
+ def __call__(self, data):
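+        # Encode the label, frame it as [BOS] + chars + [EOS], and pad it with
+        # PAD tokens up to the fixed length max_text_len + 2.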
+ text = data["label"]
+ text = self.encode(text)
+ if text is None:
+ return None
+ data["length"] = np.array(len(text))
+ text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
+ text = text + [self.dict[self.PAD]] * (self.max_text_len + 2 - len(text))
+ data["label"] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
+ return dict_character
+
+
+class IGTRLabelEncode(BaseRecLabelEncode):
+ """Convert between text-label and text-index."""
+
+ def __init__(
+ self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ k=1,
+ ch=False,
+ prompt_error=False,
+ **kwargs,
+ ):
+ super(IGTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char
+ )
+        self.ignore_index = self.dict["<pad>"]
+ self.k = k
+ self.prompt_error = prompt_error
+ self.ch = ch
+ rare_file = kwargs.get("rare_file", None)
+ siml_file = kwargs.get("siml_file", None)
+ siml_char_dict = {}
+ siml_char_list = [0 for _ in range(self.num_character)]
+ if siml_file is not None:
+ with open(siml_file, "r") as f:
+ for lin in f.readlines():
+ lin_s = lin.strip().split("\t")
+ char_siml = lin_s[0]
+ if char_siml in self.dict:
+ siml_list = []
+ siml_prob = []
+ for i in range(1, len(lin_s), 2):
+ c = lin_s[i]
+ prob = int(lin_s[i + 1])
+ if c in self.dict and prob >= 1:
+ siml_list.append(self.dict[c])
+ siml_prob.append(prob)
+ siml_prob = np.array(siml_prob, dtype=np.float32) / sum(
+ siml_prob
+ )
+ siml_char_dict[self.dict[char_siml]] = [
+ siml_list,
+ siml_prob.tolist(),
+ ]
+ siml_char_list[self.dict[char_siml]] = 1
+ self.siml_char_dict = siml_char_dict
+ self.siml_char_list = siml_char_list
+
+ rare_char_list = [0 for _ in range(self.num_character)]
+ if rare_file is not None:
+ with open(rare_file, "r") as f:
+ for lin in f.readlines():
+ lin_s = lin.strip().split("\t")
+ # print(lin_s)
+ char_rare = lin_s[0]
+ num_appear = int(lin_s[1])
+ if char_rare in self.dict and num_appear < 1000:
+ rare_char_list[self.dict[char_rare]] = 1
+
+ self.rare_char_list = (
+ rare_char_list # [self.dict[char] for char in rare_char_list]
+ )
+
+ def __call__(self, data):
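+        # Turn one label into IGTR instruction-learning targets: sampled
+        # condition (prompt) / question pairs plus their answers (next
+        # characters, yes/no for character-at-position queries, and
+        # per-character occurrence counts).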
+ text = data["label"] # coffee
+
+ encoder_result = self.encode(text)
+ if encoder_result is None:
+ return None
+
+ text, text_char_num, ques_list_s, prompt_list_s = encoder_result
+
+ if len(text) > self.max_text_len:
+ return None
+ data["length"] = np.array(len(text))
+
+        text = [self.dict["<s>"]] + text + [self.dict["</s>"]]
+        text = text + [self.dict["<pad>"]] * (self.max_text_len + 2 - len(text))
+ data["label"] = np.array(text) # 6
+
+ ques_len_list = []
+ ques2_len_list = []
+ prompt_len_list = []
+
+ prompt_pos_idx_list = []
+ prompt_char_idx_list = []
+ ques_pos_idx_list = []
+ ques1_answer_list = []
+ ques2_char_idx_list = []
+ ques2_answer_list = []
+ ques4_char_num_list = []
+ train_step = 0
+ for prompt_list, ques_list in zip(prompt_list_s, ques_list_s):
+
+ prompt_len = len(prompt_list) + 1
+ prompt_len_list.append(prompt_len)
+ prompt_list = np.array(
+                [[0, self.dict["<s>"], 0]]
+                + prompt_list
+                + [[self.max_text_len + 2, self.dict["<pad>"], 0]]
+ * (self.max_text_len - len(prompt_list))
+ )
+ prompt_pos_idx_list.append(prompt_list[:, 0])
+ prompt_char_idx_list.append(prompt_list[:, 1])
+
+ ques_len = len(ques_list)
+ ques_len_list.append(ques_len)
+
+ ques_list = np.array(
+ ques_list
+                + [[self.max_text_len + 2, self.dict["<pad>"], 0]]
+ * (self.max_text_len + 1 - ques_len)
+ )
+ ques_pos_idx_list.append(ques_list[:, 0])
+ # what is the first and third char?
+ # Is the first character 't'? and Is the third character 'f'?
+ # How many 'c', 's' and 'f' are there in the text image?
+ ques1_answer_list.append(ques_list[:, 1])
+ ques2_char_idx = copy.deepcopy(ques_list[:ques_len, :2])
+ new_ques2_char_idx = []
+ ques2_answer = []
+ for q_2, ques2_idx in enumerate(ques2_char_idx.tolist()):
+
+ if (train_step == 2 or train_step == 3) and q_2 == ques_len - 1:
+ new_ques2_char_idx.append(ques2_idx)
+ ques2_answer.append(1)
+ continue
+                if ques2_idx[1] != self.dict["</s>"] and random.random() > 0.5:
+ select_idx = random.randint(0, self.num_character - 3)
+ new_ques2_char_idx.append([ques2_idx[0], select_idx])
+ if select_idx == ques2_idx[1]:
+ ques2_answer.append(1)
+ else:
+ ques2_answer.append(0)
+
+ if self.siml_char_list[ques2_idx[1]] == 1 and random.random() > 0.5:
+ select_idx_sim_list = random.sample(
+ self.siml_char_dict[ques2_idx[1]][0],
+ min(3, len(self.siml_char_dict[ques2_idx[1]][0])),
+ )
+ for select_idx in select_idx_sim_list:
+ new_ques2_char_idx.append([ques2_idx[0], select_idx])
+ if select_idx == ques2_idx[1]:
+ ques2_answer.append(1)
+ else:
+ ques2_answer.append(0)
+ else:
+ new_ques2_char_idx.append(ques2_idx)
+ ques2_answer.append(1)
+ ques2_len_list.append(len(new_ques2_char_idx))
+ ques2_char_idx_new = np.array(
+ new_ques2_char_idx
+                + [[self.max_text_len + 2, self.dict["<pad>"]]]
+ * (self.max_text_len * 4 + 1 - len(new_ques2_char_idx))
+ )
+ ques2_answer = np.array(
+ ques2_answer + [0] * (self.max_text_len * 4 + 1 - len(ques2_answer))
+ )
+ ques2_char_idx_list.append(ques2_char_idx_new)
+ ques2_answer_list.append(ques2_answer)
+
+ ques4_char_num_list.append(ques_list[:, 2])
+ train_step += 1
+
+ data["ques_len_list"] = np.array(ques_len_list, dtype=np.int64)
+ data["ques2_len_list"] = np.array(ques2_len_list, dtype=np.int64)
+ data["prompt_len_list"] = np.array(prompt_len_list, dtype=np.int64)
+
+ data["prompt_pos_idx_list"] = np.array(prompt_pos_idx_list, dtype=np.int64)
+ data["prompt_char_idx_list"] = np.array(prompt_char_idx_list, dtype=np.int64)
+ data["ques_pos_idx_list"] = np.array(ques_pos_idx_list, dtype=np.int64)
+ data["ques1_answer_list"] = np.array(ques1_answer_list, dtype=np.int64)
+ data["ques2_char_idx_list"] = np.array(ques2_char_idx_list, dtype=np.int64)
+ data["ques2_answer_list"] = np.array(ques2_answer_list, dtype=np.float32)
+
+ data["ques3_answer"] = np.array(
+ text_char_num, dtype=np.int64
+ ) # np.array([1, 0, 2]) # answer 1, 0, 2
+ data["ques4_char_num_list"] = np.array(ques4_char_num_list)
+
+ return data
+
+ def add_special_char(self, dict_character):
+        dict_character = ["</s>"] + dict_character + ["<s>"] + ["<pad>"]
+ self.num_character = len(dict_character)
+
+ return dict_character
+
+ def encode(self, text):
+        """Convert a text label into character indices plus IGTR instruction material.
+        input:
+            text: text label of one image.
+
+        output:
+            text_list: character indices of the label.
+            char_num: occurrence count of every character in the dictionary.
+            ques_list: question entries, each encoded as [position, char_id, occurrence_count].
+            prompt_list: condition (prompt) entries paired with each question group.
+        """
+ if len(text) == 0:
+ return None
+ if self.lower:
+ text = text.lower()
+ char_num = [0 for _ in range(self.num_character - 2)]
+ char_num[0] = 1
+ text_list = []
+ qa_text = []
+ pos_i = 0
+ rare_char_qa = []
+ unrare_char_qa = []
+ for char in text:
+ if char not in self.dict:
+ continue
+
+ char_id = self.dict[char]
+ text_list.append(char_id)
+ qa_text.append([pos_i + 1, char_id, char_num[char_id]])
+ if self.rare_char_list[char_id] == 1:
+ rare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
+ else:
+ unrare_char_qa.append([pos_i + 1, char_id, char_num[char_id]])
+ char_num[char_id] += 1
+ pos_i += 1
+
+ if self.ch:
+ char_num_ch = []
+ char_num_ch_none = []
+ rare_char_num_ch_none = []
+ for i, num in enumerate(char_num):
+ if self.rare_char_list[i] == 1:
+ rare_char_num_ch_none.append([i, num])
+ if num > 0:
+ char_num_ch.append([i, num])
+ else:
+ char_num_ch_none.append([i, 0])
+ none_char_index = random.sample(
+ char_num_ch_none, min(37 - len(char_num_ch), len(char_num_ch_none))
+ )
+ if len(rare_char_num_ch_none) > 0:
+ none_rare_char_index = random.sample(
+ rare_char_num_ch_none,
+ min(
+ 40 - len(char_num_ch) - len(none_char_index),
+ len(rare_char_num_ch_none),
+ ),
+ )
+ char_num_ch = char_num_ch + none_char_index + none_rare_char_index
+ else:
+ char_num_ch = char_num_ch + none_char_index
+ char_num_ch.sort(key=lambda x: x[0])
+ char_num = char_num_ch
+
+ len_ = len(text_list)
+ if len_ == 0:
+ return None
+ ques_list = [
+            qa_text + [[pos_i + 1, self.dict["</s>"], 0]],
+            [[pos_i + 1, self.dict["</s>"], 0]],
+ ]
+ prompt_list = [qa_text[len_:], qa_text]
+ if len_ == 1:
+            ques_list.append([[self.max_text_len + 1, self.dict["</s>"], 0]])
+            prompt_list.append(
+                [[self.max_text_len + 2, self.dict["<pad>"], 0]] * 4 + qa_text
+ )
+ for _ in range(1, self.k):
+                ques_list.append([[self.max_text_len + 2, self.dict["<pad>"], 0]])
+ prompt_list.append(qa_text[1:])
+ else:
+
+ next_id = random.sample(range(1, len_ + 1), 2)
+ for slice_id in next_id:
+ b_i = slice_id - 5 if slice_id - 5 > 0 else 0
+ if slice_id == len_:
+                    ques_list.append([[self.max_text_len + 1, self.dict["</s>"], 0]])
+ else:
+ ques_list.append(
+ qa_text[slice_id:]
+ + [[self.max_text_len + 1, qa_text[slice_id][1], 0]]
+ )
+ prompt_list.append(
+                    [[self.max_text_len + 2, self.dict["<pad>"], 0]]
+ * (5 - slice_id + b_i)
+ + qa_text[b_i:slice_id]
+ )
+
+ shuffle_id1 = random.sample(range(1, len_), 2) if len_ > 2 else [1, 0]
+ for slice_id in shuffle_id1:
+ if slice_id == 0:
+                    ques_list.append([[self.max_text_len + 2, self.dict["<pad>"], 0]])
+ prompt_list.append(qa_text[:0])
+ else:
+ ques_list.append(
+                        qa_text[slice_id:] + [[pos_i + 1, self.dict["</s>"], 0]]
+ )
+ prompt_list.append(qa_text[:slice_id])
+
+ if len_ > 2:
+ shuffle_id2 = random.sample(
+ range(1, len_), self.k - 4 if len_ - 1 > self.k - 4 else len_ - 1
+ )
+ if self.k - 4 != len(shuffle_id2):
+ shuffle_id2 += random.sample(
+ range(1, len_), self.k - 4 - len(shuffle_id2)
+ )
+ rare_slice_id = len(rare_char_qa)
+ unrare_slice_id = len(unrare_char_qa)
+ for slice_id in shuffle_id2:
+ random.shuffle(qa_text)
+ if len(rare_char_qa) > 0 and random.random() < 0.5:
+ ques_list.append(
+ rare_char_qa[:rare_slice_id]
+ + unrare_char_qa[unrare_slice_id:]
+                            + [[pos_i + 1, self.dict["</s>"], 0]]
+ )
+ if len(unrare_char_qa[:unrare_slice_id]) > 0:
+ prompt_list1 = random.sample(
+ unrare_char_qa[:unrare_slice_id],
+ (
+ random.randint(
+ 1, len(unrare_char_qa[:unrare_slice_id])
+ )
+ if len(unrare_char_qa[:unrare_slice_id]) > 1
+ else 1
+ ),
+ )
+ else:
+ prompt_list1 = []
+ if len(rare_char_qa[rare_slice_id:]) > 0:
+ prompt_list2 = random.sample(
+ rare_char_qa[rare_slice_id:],
+ random.randint(
+ 1,
+ (
+ len(rare_char_qa[rare_slice_id:])
+ if len(rare_char_qa[rare_slice_id:]) > 1
+ else 1
+ ),
+ ),
+ )
+ else:
+ prompt_list2 = []
+ prompt_list.append(prompt_list1 + prompt_list2)
+ random.shuffle(rare_char_qa)
+ random.shuffle(unrare_char_qa)
+ rare_slice_id = (
+ random.randint(1, len(rare_char_qa))
+ if len(rare_char_qa) > 1
+ else 1
+ )
+ unrare_slice_id = (
+ random.randint(1, len(unrare_char_qa))
+ if len(unrare_char_qa) > 1
+ else 1
+ )
+ else:
+ ques_list.append(
+                            qa_text[slice_id:] + [[pos_i + 1, self.dict["</s>"], 0]]
+ )
+ prompt_list.append(qa_text[:slice_id])
+ else:
+            ques_list.append(qa_text[1:] + [[pos_i + 1, self.dict["</s>"], 0]])
+            prompt_list.append(qa_text[:1])
+            ques_list.append(qa_text[:1] + [[pos_i + 1, self.dict["</s>"], 0]])
+            prompt_list.append(qa_text[1:])
+            ques_list += [[[self.max_text_len + 2, self.dict["<pad>"], 0]]] * (
+ self.k - 6
+ )
+ prompt_list += [qa_text[:0]] * (self.k - 6)
+
+ return text_list, char_num, ques_list, prompt_list
diff --git a/ppocr/data/multi_scale_sampler.py b/ppocr/data/multi_scale_sampler.py
index 4ab38fc4e65..04a03cc9629 100644
--- a/ppocr/data/multi_scale_sampler.py
+++ b/ppocr/data/multi_scale_sampler.py
@@ -169,3 +169,186 @@ def set_epoch(self, epoch: int):
def __len__(self):
return self.length
+
+
+class RatioSampler(Sampler):
+
+ def __init__(
+ self,
+ data_source,
+ scales,
+ first_bs=512,
+ fix_bs=True,
+ divided_factor=[8, 16],
+ is_training=True,
+ max_ratio=10,
+ max_bs=1024,
+ seed=None,
+ ):
+ """
+        Multi-scale sampler that batches samples of similar aspect ratio together.
+        Args:
+            data_source(dataset): dataset that exposes per-sample width/height ratios
+            scales(list): several scales for image resolution
+            first_bs(int): batch size for the first scale in scales
+            divided_factor(list[w, h]): make sure width and height are multiples of divided_factor (models down-sample images by a fixed factor)
+            is_training(boolean): mode
+ """
+ # min. and max. spatial dimensions
+ self.data_source = data_source
+ # self.data_idx_order_list = np.array(data_source.data_idx_order_list)
+ self.ds_width = data_source.ds_width
+ if self.ds_width:
+ self.wh_ratio = data_source.wh_ratio
+ self.wh_ratio_sort = data_source.wh_ratio_sort
+ self.n_data_samples = len(self.data_source)
+ self.max_ratio = max_ratio
+ self.max_bs = max_bs
+
+ if isinstance(scales[0], list):
+ width_dims = [i[0] for i in scales]
+ height_dims = [i[1] for i in scales]
+ elif isinstance(scales[0], int):
+ width_dims = scales
+ height_dims = scales
+ base_im_w = width_dims[0]
+ base_im_h = height_dims[0]
+ base_batch_size = first_bs
+ base_elements = base_im_w * base_im_h * base_batch_size
+ self.base_elements = base_elements
+ self.base_batch_size = base_batch_size
+ self.base_im_h = base_im_h
+ self.base_im_w = base_im_w
+
+ # Get the GPU and node related information
+ num_replicas = dist.get_world_size()
+ rank = dist.get_rank()
+ # self.rank = rank
+ # adjust the total samples to avoid batch dropping
+ num_samples_per_replica = int(
+ math.ceil(self.n_data_samples * 1.0 / num_replicas)
+ )
+
+ img_indices = [idx for idx in range(self.n_data_samples)]
+ self.shuffle = False
+ if is_training:
+ # compute the spatial dimensions and corresponding batch size
+            # Models down-sample images by a fixed factor, so ensure that the
+            # width and height dimensions are multiples of divided_factor.
+ width_dims = [
+ int((w // divided_factor[0]) * divided_factor[0]) for w in width_dims
+ ]
+ height_dims = [
+ int((h // divided_factor[1]) * divided_factor[1]) for h in height_dims
+ ]
+
+ img_batch_pairs = list()
+ for h, w in zip(height_dims, width_dims):
+ if fix_bs:
+ batch_size = base_batch_size
+ else:
+ batch_size = int(max(1, (base_elements / (h * w))))
+ img_batch_pairs.append((w, h, batch_size))
+ self.img_batch_pairs = img_batch_pairs
+ self.shuffle = True
+ else:
+ self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
+
+ self.img_indices = img_indices
+ self.n_samples_per_replica = num_samples_per_replica
+ self.epoch = 0
+ self.rank = rank
+ self.num_replicas = num_replicas
+
+ # self.batch_list = []
+ self.current = 0
+ self.is_training = is_training
+ if is_training:
+ indices_rank_i = self.img_indices[
+ self.rank : len(self.img_indices) : self.num_replicas
+ ]
+ else:
+ indices_rank_i = self.img_indices
+ self.indices_rank_i_ori = np.array(self.wh_ratio_sort[indices_rank_i])
+ self.indices_rank_i_ratio = self.wh_ratio[self.indices_rank_i_ori]
+ indices_rank_i_ratio_unique = np.unique(self.indices_rank_i_ratio)
+ self.indices_rank_i_ratio_unique = indices_rank_i_ratio_unique.tolist()
+ self.batch_list = self.create_batch()
+ self.length = len(self.batch_list)
+ self.batchs_in_one_epoch_id = [i for i in range(self.length)]
+
+ def create_batch(self):
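+        """Group sample indices by aspect ratio and build batches.
+
+        For every distinct width/height ratio, indices are collected and split
+        into batches whose size shrinks as the ratio (and thus the image width)
+        grows; each batch element is a [target_w, target_h, sample_idx, ratio]
+        tuple consumed by RatioDataSet.__getitem__.
+        """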
+ batch_list = []
+ for ratio in self.indices_rank_i_ratio_unique:
+ ratio_ids = np.where(self.indices_rank_i_ratio == ratio)[0]
+ ratio_ids = self.indices_rank_i_ori[ratio_ids]
+ if self.shuffle:
+ random.shuffle(ratio_ids)
+ num_ratio = ratio_ids.shape[0]
+ if ratio < 5:
+ batch_size_ratio = self.base_batch_size
+ else:
+ batch_size_ratio = min(
+ self.max_bs,
+ int(
+ max(
+ 1,
+ (
+ self.base_elements
+ / (self.base_im_h * ratio * self.base_im_h)
+ ),
+ )
+ ),
+ )
+ if num_ratio > batch_size_ratio:
+ batch_num_ratio = num_ratio // batch_size_ratio
+ ratio_ids_full = ratio_ids[
+ : batch_num_ratio * batch_size_ratio
+ ].reshape(batch_num_ratio, batch_size_ratio, 1)
+ w = np.full_like(ratio_ids_full, ratio * self.base_im_h)
+ h = np.full_like(ratio_ids_full, self.base_im_h)
+ ra_wh = np.full_like(ratio_ids_full, ratio)
+ ratio_ids_full = np.concatenate([w, h, ratio_ids_full, ra_wh], axis=-1)
+ batch_ratio = ratio_ids_full.tolist()
+
+ if batch_num_ratio * batch_size_ratio < num_ratio:
+ drop = ratio_ids[batch_num_ratio * batch_size_ratio :]
+ if self.is_training:
+ drop_full = ratio_ids[
+ : batch_size_ratio
+ - (num_ratio - batch_num_ratio * batch_size_ratio)
+ ]
+ drop = np.append(drop_full, drop)
+ drop = drop.reshape(-1, 1)
+ w = np.full_like(drop, ratio * self.base_im_h)
+ h = np.full_like(drop, self.base_im_h)
+ ra_wh = np.full_like(drop, ratio)
+
+ drop = np.concatenate([w, h, drop, ra_wh], axis=-1)
+
+ batch_ratio.append(drop.tolist())
+ batch_list += batch_ratio
+ else:
+ ratio_ids = ratio_ids.reshape(-1, 1)
+ w = np.full_like(ratio_ids, ratio * self.base_im_h)
+ h = np.full_like(ratio_ids, self.base_im_h)
+ ra_wh = np.full_like(ratio_ids, ratio)
+
+ ratio_ids = np.concatenate([w, h, ratio_ids, ra_wh], axis=-1)
+ batch_list.append(ratio_ids.tolist())
+ return batch_list
+
+ def __iter__(self):
+ if self.shuffle or self.is_training:
+ random.seed(self.epoch)
+ self.epoch += 1
+ self.batch_list = self.create_batch()
+ random.shuffle(self.batchs_in_one_epoch_id)
+ for batch_tuple_id in self.batchs_in_one_epoch_id:
+ yield self.batch_list[batch_tuple_id]
+
+ def set_epoch(self, epoch: int):
+ self.epoch = epoch
+
+ def __len__(self):
+ return self.length
diff --git a/ppocr/data/ratio_dataset.py b/ppocr/data/ratio_dataset.py
new file mode 100644
index 00000000000..53906dff5a8
--- /dev/null
+++ b/ppocr/data/ratio_dataset.py
@@ -0,0 +1,228 @@
+# copyright (c) 2025 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import math
+import random
+import cv2
+import lmdb
+import numpy as np
+from PIL import Image
+from paddle.io import Dataset
+
+from .imaug import create_operators, transform
+
+
+class RatioDataSet(Dataset):
+
+ def __init__(self, config, mode, logger, seed=None):
+ super(RatioDataSet, self).__init__()
+ self.ds_width = config[mode]["dataset"].get("ds_width", True)
+ global_config = config["Global"]
+ dataset_config = config[mode]["dataset"]
+ loader_config = config[mode]["loader"]
+
+ data_dir_list = dataset_config["data_dir_list"]
+ self.padding = dataset_config.get("padding", True)
+ self.padding_rand = dataset_config.get("padding_rand", False)
+ self.padding_doub = dataset_config.get("padding_doub", False)
+ max_ratio = dataset_config.get("max_ratio", 12)
+ min_ratio = dataset_config.get("min_ratio", 1)
+ self.do_shuffle = loader_config["shuffle"]
+ data_source_num = len(data_dir_list)
+ ratio_list = dataset_config.get("ratio_list", 1.0)
+ if isinstance(ratio_list, (float, int)):
+ ratio_list = [float(ratio_list)] * int(data_source_num)
+ assert (
+ len(ratio_list) == data_source_num
+ ), "The length of ratio_list should be the same as the file_list."
+ self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir_list, ratio_list)
+ for data_dir in data_dir_list:
+            logger.info("Initialize indices of datasets: %s" % data_dir)
+ self.logger = logger
+ self.data_idx_order_list = self.dataset_traversal()
+ wh_ratio = np.around(np.array(self.get_wh_ratio()))
+ self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
+ self.wh_ratio_sort = np.argsort(self.wh_ratio)
+ self.ops = create_operators(dataset_config["transforms"], global_config)
+
+ self.need_reset = True in [x < 1 for x in ratio_list]
+ self.error = 0
+ self.base_shape = dataset_config.get(
+ "base_shape", [[64, 64], [96, 48], [112, 40], [128, 32]]
+ )
+ self.base_h = 32
+
+ def get_wh_ratio(self):
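+        # Collect the width/height ratio of every sample, preferring the cached
+        # "wh-%09d" LMDB entries and falling back to decoding the image with PIL.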
+ wh_ratio = []
+ for idx in range(self.data_idx_order_list.shape[0]):
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ wh_key = "wh-%09d".encode() % file_idx
+ wh = self.lmdb_sets[lmdb_idx]["txn"].get(wh_key)
+ if wh is None:
+ img_key = f"image-{file_idx:09d}".encode()
+ img = self.lmdb_sets[lmdb_idx]["txn"].get(img_key)
+ buf = io.BytesIO(img)
+ w, h = Image.open(buf).size
+ else:
+ wh = wh.decode("utf-8")
+ w, h = wh.split("_")
+ wh_ratio.append(float(w) / float(h))
+ return wh_ratio
+
+ def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
+ lmdb_sets = {}
+ dataset_idx = 0
+ for dirpath, ratio in zip(data_dir_list, ratio_list):
+ env = lmdb.open(
+ dirpath,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ txn = env.begin(write=False)
+ num_samples = int(txn.get("num-samples".encode()))
+ lmdb_sets[dataset_idx] = {
+ "dirpath": dirpath,
+ "env": env,
+ "txn": txn,
+ "num_samples": num_samples,
+ "ratio_num_samples": int(ratio * num_samples),
+ }
+ dataset_idx += 1
+ return lmdb_sets
+
+ def dataset_traversal(self):
+ lmdb_num = len(self.lmdb_sets)
+ total_sample_num = 0
+ for lno in range(lmdb_num):
+ total_sample_num += self.lmdb_sets[lno]["ratio_num_samples"]
+ data_idx_order_list = np.zeros((total_sample_num, 2))
+ beg_idx = 0
+ for lno in range(lmdb_num):
+ tmp_sample_num = self.lmdb_sets[lno]["ratio_num_samples"]
+ end_idx = beg_idx + tmp_sample_num
+ data_idx_order_list[beg_idx:end_idx, 0] = lno
+ data_idx_order_list[beg_idx:end_idx, 1] = list(
+ random.sample(
+ range(1, self.lmdb_sets[lno]["num_samples"] + 1),
+ self.lmdb_sets[lno]["ratio_num_samples"],
+ )
+ )
+ beg_idx = beg_idx + tmp_sample_num
+ return data_idx_order_list
+
+ def get_img_data(self, value):
+ """get_img_data."""
+ if not value:
+ return None
+ imgdata = np.frombuffer(value, dtype="uint8")
+ if imgdata is None:
+ return None
+ imgori = cv2.imdecode(imgdata, 1)
+ if imgori is None:
+ return None
+ return imgori
+
+ def resize_norm_img(self, data, gen_ratio, padding=True):
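+        """Resize, normalize to [-1, 1], and pad the image to the bucket shape.
+
+        gen_ratio selects the target (w, h) bucket; samples whose own ratio is
+        far smaller than the bucket ratio are rejected (return None) so the
+        caller can redraw another sample of the same ratio.
+        """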
+ img = data["image"]
+ h = img.shape[0]
+ w = img.shape[1]
+ if self.padding_rand and random.random() < 0.5:
+ padding = not padding
+ imgW, imgH = (
+ self.base_shape[gen_ratio - 1]
+ if gen_ratio <= 4
+ else [self.base_h * gen_ratio, self.base_h]
+ )
+ use_ratio = imgW // imgH
+ if use_ratio >= (w // h) + 2:
+ self.error += 1
+ return None
+ if not padding:
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR
+ )
+ resized_w = imgW
+ else:
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio * (random.random() + 0.5)))
+ resized_w = min(imgW, resized_w)
+
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ resized_image = resized_image.astype("float32")
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
+ if self.padding_doub and random.random() < 0.5:
+ padding_im[:, :, -resized_w:] = resized_image
+ else:
+ padding_im[:, :, :resized_w] = resized_image
+ valid_ratio = min(1.0, float(resized_w / imgW))
+ data["image"] = padding_im
+ data["valid_ratio"] = valid_ratio
+ data["real_ratio"] = round(w / h)
+ return data
+
+ def get_lmdb_sample_info(self, txn, index):
+ label_key = "label-%09d".encode() % index
+ label = txn.get(label_key)
+ if label is None:
+ return None
+ label = label.decode("utf-8")
+ img_key = "image-%09d".encode() % index
+ imgbuf = txn.get(img_key)
+ return imgbuf, label
+
+ def __getitem__(self, properties):
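+        # `properties` comes from RatioSampler as [target_w, target_h, idx, ratio];
+        # corrupted or rejected samples are replaced by another sample of the same ratio.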
+ img_width = properties[0]
+ img_height = properties[1]
+ idx = properties[2]
+ ratio = properties[3]
+ lmdb_idx, file_idx = self.data_idx_order_list[idx]
+ lmdb_idx = int(lmdb_idx)
+ file_idx = int(file_idx)
+ sample_info = self.get_lmdb_sample_info(
+ self.lmdb_sets[lmdb_idx]["txn"], file_idx
+ )
+ if sample_info is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ img, label = sample_info
+ data = {"image": img, "label": label}
+ outs = transform(data, self.ops[:-1])
+ if outs is not None:
+ outs = self.resize_norm_img(outs, ratio, padding=self.padding)
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ outs = transform(outs, self.ops[-1:])
+ if outs is None:
+ ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
+ ids = random.sample(ratio_ids, 1)
+ return self.__getitem__([img_width, img_height, ids[0], ratio])
+ return outs
+
+ def __len__(self):
+ return self.data_idx_order_list.shape[0]
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 59b7ecfaadd..63dc78cfbe7 100644
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -48,6 +48,7 @@
from .rec_latexocr_loss import LaTeXOCRLoss
from .rec_unimernet_loss import UniMERNetLoss
from .rec_ppformulanet_loss import PPFormulaNet_S_Loss, PPFormulaNet_L_Loss
+from .rec_igtr_loss import IGTRLoss
# cls loss
from .cls_loss import ClsLoss
@@ -114,6 +115,7 @@ def build_loss(config):
"UniMERNetLoss",
"PPFormulaNet_S_Loss",
"PPFormulaNet_L_Loss",
+ "IGTRLoss",
]
config = copy.deepcopy(config)
module_name = config.pop("name")
diff --git a/ppocr/losses/rec_igtr_loss.py b/ppocr/losses/rec_igtr_loss.py
new file mode 100644
index 00000000000..45d02d05ee4
--- /dev/null
+++ b/ppocr/losses/rec_igtr_loss.py
@@ -0,0 +1,12 @@
+from paddle import nn
+
+
+class IGTRLoss(nn.Layer):
+
+ def __init__(self, **kwargs):
+ super(IGTRLoss, self).__init__()
+
+ def forward(self, predicts, batch):
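+        # The training losses are computed inside the IGTR head; this wrapper only
+        # passes them through (taking the first element when a list is returned).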
+ if isinstance(predicts, list):
+ predicts = predicts[0]
+ return predicts
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index efb20169534..cdf9059a2c6 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -72,6 +72,7 @@ def build_backbone(config, model_type):
from .rec_svtrv2 import SVTRv2
from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
from .rec_pphgnetv2 import PPHGNetV2_B4
+ from .rec_svtrnet2dpos import SVTRNet2DPos
support_dict = [
"MobileNetV1Enhance",
@@ -102,6 +103,7 @@ def build_backbone(config, model_type):
"Vary_VIT_B",
"PPHGNetV2_B4",
"Vary_VIT_B_Formula",
+ "SVTRNet2DPos",
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py
index 427c87b324a..f9c172e240b 100644
--- a/ppocr/modeling/backbones/rec_svtrnet.py
+++ b/ppocr/modeling/backbones/rec_svtrnet.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from paddle import ParamAttr
-from paddle.nn.initializer import KaimingNormal
import numpy as np
import paddle
import paddle.nn as nn
-from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+from paddle import ParamAttr
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
trunc_normal_ = TruncatedNormal(std=0.02)
normal_ = Normal
diff --git a/ppocr/modeling/backbones/rec_svtrnet2dpos.py b/ppocr/modeling/backbones/rec_svtrnet2dpos.py
new file mode 100644
index 00000000000..5745ca5aa92
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_svtrnet2dpos.py
@@ -0,0 +1,663 @@
+# copyright (c) 2025 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, Mlp
+from paddle import ParamAttr
+from paddle.nn.initializer import (
+ TruncatedNormal,
+ Constant,
+ Normal,
+ KaimingNormal,
+ KaimingUniform,
+)
+
+trunc_normal_ = TruncatedNormal(std=0.02)
+normal_ = Normal
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
+
+
+def dim2perm(ndim, dim0, dim1):
+ perm = list(range(ndim))
+ perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
+ return perm
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ groups=1,
+ act=nn.GELU,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ weight_attr=ParamAttr(initializer=KaimingUniform()),
+ bias_attr=bias,
+ )
+ self.norm = nn.BatchNorm2D(num_features=out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class ConvMixer(nn.Layer):
+ def __init__(self, dim, num_heads=8, HW=[8, 25], local_k=[3, 3]):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2D(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=local_k,
+ stride=1,
+ padding=[local_k[0] // 2, local_k[1] // 2],
+ groups=num_heads,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+
+ def forward(self, x, w):
+ x = x.transpose(perm=dim2perm(x.ndim, 1, 2)).reshape(
+ [tuple(x.shape)[0], self.dim, -1, w]
+ )
+ x = self.local_mixer(x)
+ x = x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ return x
+
+
+class ConvMlp(nn.Layer):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ groups=1,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2D(
+ in_channels=in_features,
+ out_channels=hidden_features,
+ kernel_size=1,
+ groups=groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+ self.act = act_layer()
+ self.fc2 = nn.Conv2D(
+ in_channels=hidden_features, out_channels=out_features, kernel_size=1
+ )
+ self.drop = nn.Dropout(p=drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvBlock(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer="Global",
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ eps=1e-06,
+ prenorm=True,
+ ):
+ super().__init__()
+ self.norm1 = nn.BatchNorm2D(num_features=dim)
+ self.local_mixer = nn.Conv2D(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=[5, 5],
+ stride=1,
+ padding=[2, 2],
+ groups=num_heads,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = nn.BatchNorm2D(num_features=dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ConvMlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+ self.prenorm = prenorm
+
+ def forward(self, x):
+ x = self.norm1(x + self.drop_path(self.local_mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x
+
+
+class Attention(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ mixer="Global",
+ HW=None,
+ local_k=[7, 11],
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ self.head_dim = dim // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ self.N = H * W
+ self.C = dim
+ if mixer == "Local" and HW is not None:
+ hk = local_k[0]
+ wk = local_k[1]
+ mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype="float32")
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h : h + hk, w : w + wk] = 0.0
+ mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(
+ 1
+ )
+ mask_inf = paddle.full([H * W, H * W], "-inf", dtype="float32")
+ mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
+ self.register_buffer("mask", mask.unsqueeze([0, 1]))
+ self.mixer = mixer
+
+ def forward(self, x):
+ qkv = (
+ self.qkv(x)
+ .reshape((0, -1, 3, self.num_heads, self.head_dim))
+ .transpose((2, 0, 3, 1, 4))
+ )
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+ attn = q.matmul(k.transpose((0, 1, 3, 2)))
+ if self.mixer == "Local":
+ attn += self.mask
+ attn = nn.functional.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer="Global",
+ local_mixer=[7, 11],
+ HW=None,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ eps=1e-6,
+ prenorm=True,
+ ):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, epsilon=eps)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == "Global" or mixer == "Local":
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ elif mixer == "Conv":
+ self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+ else:
+ raise TypeError("The mixer must be one of [Global, Local, Conv]")
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, epsilon=eps)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+ self.prenorm = prenorm
+
+ def forward(self, x, w):
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ return x, w
+
+
+class PatchEmbed(nn.Layer):
+ """Image to Patch Embedding."""
+
+ def __init__(
+ self,
+ img_size=[32, 100],
+ in_channels=3,
+ embed_dim=768,
+ sub_num=2,
+ patch_size=[4, 4],
+ mode="pope",
+ ):
+ super().__init__()
+ num_patches = img_size[1] // 2**sub_num * (img_size[0] // 2**sub_num)
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if mode == "pope":
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=False,
+ ),
+ )
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias=False,
+ ),
+ )
+ elif mode == "linear":
+ self.proj = nn.Conv2D(
+ in_channels=1,
+ out_channels=embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ self.num_patches = (
+ img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
+ )
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x
+
+
+class SubSample(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ types="Pool",
+ stride=[2, 1],
+ sub_norm="nn.LayerNorm",
+ act=None,
+ ):
+ super().__init__()
+ self.types = types
+ if types == "Pool":
+ self.avgpool = nn.AvgPool2D(
+ kernel_size=[3, 5], stride=stride, padding=[1, 2], exclusive=False
+ )
+ self.maxpool = nn.MaxPool2D(
+ kernel_size=[3, 5], stride=stride, padding=[1, 2]
+ )
+ self.proj = nn.Linear(in_features=in_channels, out_features=out_channels)
+ else:
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ )
+ self.dim = in_channels
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x, w):
+ if self.types == "Pool":
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(
+ x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ )
+ else:
+ x = x.transpose(perm=dim2perm(x.ndim, 1, 2)).reshape(
+ [tuple(x.shape)[0], self.dim, -1, w]
+ )
+ x = self.conv(x)
+ out = x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+ return out, w
+
+
+class FlattenTranspose(nn.Layer):
+ def forward(self, x):
+ return x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+
+
+class DownSConv(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=[2, 1],
+ padding=1,
+ )
+ self.norm = nn.LayerNorm(normalized_shape=out_channels)
+
+ def forward(self, x, w):
+ B, N, C = tuple(x.shape)
+ x = x.transpose(perm=dim2perm(x.ndim, 1, 2)).reshape([B, C, -1, w])
+ x = self.conv(x)
+ w = tuple(x.shape)[-1]
+ x = self.norm(
+ x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ )
+ return x, w
+
+
+class SVTRNet2DPos(nn.Layer):
+ def __init__(
+ self,
+ img_size=[32, -1],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=["Local"] * 6 + ["Global"] * 6,
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging="Conv",
+ pool_size=[2, 1],
+ max_size=[16, 32],
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer="nn.LayerNorm",
+ eps=1e-06,
+ act="nn.GELU",
+ last_stage=True,
+ sub_num=2,
+ use_first_sub=True,
+ flatten=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.flatten = flatten
+ patch_merging = (
+ None
+ if patch_merging != "Conv" and patch_merging != "Pool"
+ else patch_merging
+ )
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim[0],
+ sub_num=sub_num,
+ )
+ if img_size[1] == -1:
+ self.HW = [img_size[0] // 2**sub_num, -1]
+ else:
+ self.HW = [img_size[0] // 2**sub_num, img_size[1] // 2**sub_num]
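+ # The learnable 2-D positional embedding is stored at its maximum size (1, C, max_H, max_W)
+ # and cropped to the actual feature-map size in forward(), so variable image widths are supported.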
+ pos_embed = paddle.zeros(
+ shape=[1, max_size[0] * max_size[1], embed_dim[0]], dtype="float32"
+ )
+ init_TruncatedNormal = nn.initializer.TruncatedNormal(mean=0, std=0.02)
+ init_TruncatedNormal(pos_embed)
+ self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=pos_embed.transpose(perm=dim2perm(pos_embed.ndim, 1, 2)).reshape(
+ [1, embed_dim[0], max_size[0], max_size[1]]
+ ),
+ trainable=True,
+ )
+ self.pos_drop = nn.Dropout(p=drop_rate)
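+ # Mixer entries marked "ConvB" are instantiated as ConvBlock and run on the 2-D feature
+ # map (conv_blocks1); the remaining entries become attention Blocks that operate on the
+ # flattened token sequence (trans_blocks).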
+ conv_block_num = sum(
+ [(1 if mixer_type == "ConvB" else 0) for mixer_type in mixer]
+ )
+ Block_unit = [ConvBlock for _ in range(conv_block_num)] + [
+ Block for _ in range(len(mixer) - conv_block_num)
+ ]
+ HW = self.HW
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.conv_blocks1 = nn.LayerList(
+ sublayers=[
+ Block_unit[0 : depth[0]][i](
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0 : depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0 : depth[0]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ )
+ for i in range(depth[0])
+ ]
+ )
+ if patch_merging is not None:
+ if use_first_sub:
+ stride = [2, 1]
+ HW = [self.HW[0] // 2, self.HW[1]]
+ else:
+ stride = [1, 1]
+ HW = self.HW
+ sub_sample1 = nn.Sequential(
+ nn.Conv2D(
+ in_channels=embed_dim[0],
+ out_channels=embed_dim[1],
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ ),
+ nn.BatchNorm2D(num_features=embed_dim[1]),
+ )
+ self.conv_blocks1.append(sub_sample1)
+ self.patch_merging = patch_merging
+ self.trans_blocks = nn.LayerList()
+ for i in range(depth[1]):
+ block = Block_unit[depth[0] : depth[0] + depth[1]][i](
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0] : depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ )
+ if i + depth[0] < conv_block_num:
+ self.conv_blocks1.append(block)
+ else:
+ self.trans_blocks.append(block)
+ if patch_merging is not None:
+ self.trans_blocks.append(DownSConv(embed_dim[1], embed_dim[2]))
+ HW = [HW[0] // 2, -1]
+ for i in range(depth[2]):
+ self.trans_blocks.append(
+ Block_unit[depth[0] + depth[1] :][i](
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1] :][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1] :][i],
+ norm_layer=norm_layer,
+ eps=eps,
+ )
+ )
+ self.last_stage = last_stage
+ self.out_channels = embed_dim[-1]
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward(self, x):
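+ # Convolutional blocks run on the NCHW feature map; the map is then flattened to a token
+ # sequence for the attention blocks, and reshaped back to (B, C, H, W) at the end unless
+ # self.flatten is set.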
+ x = self.patch_embed(x)
+ w = tuple(x.shape)[-1]
+ x = x + self.pos_embed[:, :, : tuple(x.shape)[-2], :w]
+ for blk in self.conv_blocks1:
+ x = blk(x)
+ x = x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ for blk in self.trans_blocks:
+ x, w = blk(x, w)
+ B, N, C = tuple(x.shape)
+ if not self.flatten:
+ x = x.transpose(perm=dim2perm(x.ndim, 1, 2)).reshape([B, C, -1, w])
+ return x
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index c410ccc9cbb..2375b63bd90 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -46,6 +46,7 @@ def build_head(config):
from .rec_cppd_head import CPPDHead
from .rec_unimernet_head import UniMERNetHead
from .rec_ppformulanet_head import PPFormulaNet_Head
+ from .rec_igtr_head import IGTRHead
# cls head
from .cls_head import ClsHead
@@ -91,6 +92,7 @@ def build_head(config):
"CPPDHead",
"UniMERNetHead",
"PPFormulaNet_Head",
+ "IGTRHead",
]
if config["name"] == "DRRGHead":
diff --git a/ppocr/modeling/heads/rec_igtr_head.py b/ppocr/modeling/heads/rec_igtr_head.py
new file mode 100644
index 00000000000..3ee21e0341f
--- /dev/null
+++ b/ppocr/modeling/heads/rec_igtr_head.py
@@ -0,0 +1,847 @@
+# copyright (c) 2025 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, Mlp
+from ppocr.modeling.heads.rec_nrtr_head import Embeddings
+
+
+def dim2perm(ndim, dim0, dim1):
+ perm = list(range(ndim))
+ perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
+ return perm
+
+
+init_TruncatedNormal = nn.initializer.TruncatedNormal(std=0.02)
+
+
+class CrossAttention(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+ self.q = nn.Linear(in_features=dim, out_features=dim, bias_attr=qkv_bias)
+ self.kv = nn.Linear(in_features=dim, out_features=dim * 2, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(p=attn_drop)
+ self.proj = nn.Linear(in_features=dim, out_features=dim)
+ self.proj_drop = nn.Dropout(p=proj_drop)
+
+ def forward(self, q, kv, key_mask=None):
+ N, C = tuple(kv.shape)[1:]
+ QN = tuple(q.shape)[1]
+ q = (
+ self.q(q)
+ .reshape([-1, QN, self.num_heads, C // self.num_heads])
+ .transpose([0, 2, 1, 3])
+ )
+ q = q * self.scale
+ k, v = (
+ self.kv(kv)
+ .reshape([-1, N, 2, self.num_heads, C // self.num_heads])
+ .transpose(perm=[2, 0, 3, 1, 4])
+ )
+ attn = q.matmul(y=k.transpose(perm=dim2perm(k.ndim, 2, 3)))
+ if key_mask is not None:
+ attn = attn + key_mask.unsqueeze(axis=1)
+ attn = nn.functional.softmax(x=attn, axis=-1)
+ if not self.training:
+ self.attn_map = attn
+ attn = self.attn_drop(attn)
+ x = attn.matmul(y=v)
+ x = x.transpose(perm=dim2perm(x.ndim, 1, 2)).reshape((-1, QN, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class DecoderLayer(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-06,
+ ):
+ super().__init__()
+ self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+ self.normkv = eval(norm_layer)(dim, epsilon=epsilon)
+ self.mixer = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, q, kv, key_mask=None):
+ x1 = q + self.drop_path(self.mixer(self.norm1(q), self.normkv(kv), key_mask))
+ x = x1 + self.drop_path(self.mlp(self.norm2(x1)))
+ return x
+
+
+class CMFFLayer(nn.Layer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ epsilon=1e-06,
+ ):
+ super().__init__()
+ self.normq1 = nn.LayerNorm(normalized_shape=dim, epsilon=epsilon)
+ self.normkv1 = nn.LayerNorm(normalized_shape=dim, epsilon=epsilon)
+ self.images_to_question_cross_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.normq2 = nn.LayerNorm(normalized_shape=dim, epsilon=epsilon)
+ self.normkv2 = nn.LayerNorm(normalized_shape=dim, epsilon=epsilon)
+ self.question_to_images_cross_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ self.normmlp = nn.LayerNorm(normalized_shape=dim, epsilon=epsilon)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, question_f, prompt_f, visual_f, mask=None):
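+ # Question, prompt and visual tokens are concatenated into a single query, which first
+ # cross-attends to the prompt tokens and then to its own visual slice, before being
+ # split back into the three updated streams.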
+ query_add = paddle.concat(x=[question_f, prompt_f, visual_f], axis=1)
+ query_add = query_add + self.drop_path(
+ self.images_to_question_cross_attn(
+ self.normq1(query_add), self.normkv1(prompt_f), mask
+ )
+ )
+ query_add = query_add + self.drop_path(
+ self.question_to_images_cross_attn(
+ self.normq2(query_add),
+ self.normkv2(query_add[:, -tuple(visual_f.shape)[1] :, :]),
+ )
+ )
+ query_updated = query_add + self.drop_path(self.mlp(self.normmlp(query_add)))
+ question_f_updated = query_updated[:, : tuple(question_f.shape)[1], :]
+ prompt_f_updated = query_updated[
+ :, tuple(question_f.shape)[1] : -tuple(visual_f.shape)[1], :
+ ]
+ visual_f_updated = query_updated[:, -tuple(visual_f.shape)[1] :, :]
+ return question_f_updated, prompt_f_updated, visual_f_updated
+
+
+class IGTRHead(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ dim,
+ out_channels,
+ num_layer=2,
+ drop_path_rate=0.1,
+ max_len=25,
+ vis_seq=50,
+ ch=False,
+ ar=False,
+ refine_iter=0,
+ quesall=True,
+ next_pred=False,
+ ds=False,
+ pos2d=False,
+ check_search=False,
+ max_size=[8, 32],
+ **kwargs,
+ ):
+ super(IGTRHead, self).__init__()
+ self.out_channels = out_channels
+ self.dim = dim
+ self.max_len = max_len + 3
+ self.ch = ch
+ self.char_embed = Embeddings(
+ d_model=dim, vocab=self.out_channels, scale_embedding=True
+ )
+ self.ignore_index = out_channels - 1
+ self.ar = ar
+ self.refine_iter = refine_iter
+ self.bos = self.out_channels - 2
+ self.eos = 0
+ self.next_pred = next_pred
+ self.quesall = quesall
+ self.check_search = check_search
+ dpr = np.linspace(0, drop_path_rate, num_layer + 2)
+ self.cmff_decoder = nn.LayerList(
+ sublayers=[
+ CMFFLayer(
+ dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[i],
+ )
+ for i in range(num_layer)
+ ]
+ )
+ self.answer_to_question_layer = DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[-2],
+ )
+ self.answer_to_image_layer = DecoderLayer(
+ dim=dim,
+ num_heads=dim // 32,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=dpr[-1],
+ )
+ self.char_pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.zeros(shape=[self.max_len, dim], dtype="float32"),
+ trainable=True,
+ )
+ self.appear_num_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.zeros(shape=[self.max_len, dim], dtype="float32"),
+ trainable=True,
+ )
+ self.ds = ds
+ self.pos2d = pos2d
+ if not ds:
+ self.vis_pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.zeros(shape=[1, vis_seq, dim], dtype="float32"),
+ trainable=True,
+ )
+ init_TruncatedNormal(self.vis_pos_embed)
+ elif pos2d:
+ pos_embed = paddle.zeros(
+ shape=[1, max_size[0] * max_size[1], dim], dtype="float32"
+ )
+ init_TruncatedNormal(pos_embed)
+ self.vis_pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=pos_embed.transpose(perm=dim2perm(pos_embed.ndim, 1, 2)).reshape(
+ [1, dim, max_size[0], max_size[1]]
+ ),
+ trainable=True,
+ )
+ self.prompt_pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.zeros(shape=[1, 6, dim], dtype="float32"), trainable=True
+ )
+ self.answer_query = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.zeros(shape=[1, 1, dim], dtype="float32"), trainable=True
+ )
+ self.norm_pred = nn.LayerNorm(normalized_shape=dim, epsilon=1e-06)
+ self.ques1_head = nn.Linear(in_features=dim, out_features=self.out_channels - 2)
+ self.ques2_head = nn.Linear(
+ in_features=dim, out_features=self.max_len, bias_attr=False
+ )
+ self.ques3_head = nn.Linear(in_features=dim, out_features=self.max_len - 1)
+ self.ques4_head = nn.Linear(in_features=dim, out_features=self.max_len - 1)
+
+ init_TruncatedNormal(self.char_pos_embed)
+ init_TruncatedNormal(self.appear_num_embed)
+ init_TruncatedNormal(self.answer_query)
+ init_TruncatedNormal(self.prompt_pos_embed)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ init_TruncatedNormal(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ init_Constant = nn.initializer.Constant(value=0.0)
+ init_Constant(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ init_Constant = nn.initializer.Constant(value=0.0)
+ init_Constant(m.bias)
+ init_Constant = nn.initializer.Constant(value=1.0)
+ init_Constant(m.weight)
+
+ def question_encoder(self, targets, train_i):
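+ # Builds the prompt embeddings (character position + character) and the question
+ # embeddings for the sampled sub-tasks: ques1 asks which character occupies a position,
+ # ques2 verifies a (position, character) pair, and ques4 asks for the position of a
+ # character given its occurrence count; the returned answers supervise the matching heads.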
+ (
+ prompt_pos_idx,
+ prompt_char_idx,
+ ques_pos_idx,
+ ques1_answer,
+ ques2_char_idx,
+ ques2_answer,
+ ques4_char_num,
+ ques_len,
+ ques2_len,
+ prompt_len,
+ ) = targets
+ max_ques_len = paddle.max(x=ques_len)
+ max_ques2_len = paddle.max(x=ques2_len)
+ max_prompt_len = paddle.max(x=prompt_len)
+ if self.next_pred and (train_i == 2 or train_i == 3):
+ prompt_pos = self.prompt_pos_embed
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
+ else:
+ prompt_pos = nn.functional.embedding(
+ x=prompt_pos_idx[:, :max_prompt_len], weight=self.char_pos_embed
+ )
+ prompt_char_idx = prompt_char_idx[:, :max_prompt_len]
+ prompt_char = self.char_embed(prompt_char_idx)
+ prompt = prompt_pos + prompt_char
+ mask_1234 = paddle.where(
+ condition=prompt_char_idx == self.ignore_index,
+ x=float("-inf"),
+ y=0.0,
+ )
+ mask_1234 = paddle.cast(mask_1234.unsqueeze(axis=1), paddle.float32)
+ ques1 = nn.functional.embedding(
+ x=ques_pos_idx[:, :max_ques_len], weight=self.char_pos_embed
+ )
+
+ ques1_answer = ques1_answer[:, :max_ques_len]
+ if self.quesall or train_i == 0:
+ ques2_char = self.char_embed(ques2_char_idx[:, :max_ques2_len, (1)])
+ ques2 = ques2_char + nn.functional.embedding(
+ x=ques2_char_idx[:, :max_ques2_len, (0)], weight=self.char_pos_embed
+ )
+ ques2_answer = ques2_answer[:, :max_ques2_len]
+ ques2_head = nn.functional.embedding(
+ x=ques2_char_idx[:, :max_ques2_len, (0)],
+ weight=self.ques2_head.weight.transpose([1, 0]),
+ )
+ ques4_char = self.char_embed(ques1_answer)
+ ques4_ap_num = nn.functional.embedding(
+ x=ques4_char_num[:, :max_ques_len], weight=self.appear_num_embed
+ )
+ ques4 = ques4_char + ques4_ap_num
+ ques4_answer = ques_pos_idx[:, :max_ques_len]
+ return (
+ prompt,
+ ques1,
+ ques2,
+ ques2_head,
+ ques4,
+ ques1_answer,
+ ques2_answer,
+ ques4_answer,
+ mask_1234,
+ )
+ else:
+ return prompt, ques1, ques1_answer, mask_1234
+
+ def forward(self, x, targets=None):
+ if self.training:
+ return self.forward_train(x, targets)
+ else:
+ return self.forward_test(x)
+
+ def forward_test(self, x):
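+ # Inference supports parallel full-sequence prediction (default), next-character
+ # prediction with a sliding prompt window (next_pred), and character-by-character
+ # autoregressive decoding (ar); refine_iter > 0 adds extra verification passes that
+ # re-predict low-confidence characters.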
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ x = x + self.vis_pos_embed[:, :, : tuple(x.shape)[2], : tuple(x.shape)[3]]
+ visual_f = x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ else:
+ visual_f = x
+ bs = tuple(x.shape)[0]
+ prompt_bos = self.char_embed(
+ paddle.full(shape=[bs, 1], fill_value=self.bos, dtype="int64")
+ ) + self.char_pos_embed[:1, :].unsqueeze(axis=0)
+ ques_all = paddle.tile(
+ x=self.char_pos_embed.unsqueeze(axis=0), repeat_times=(bs, 1, 1)
+ )
+ if not self.ar:
+ if self.check_search:
+ tgt_in = paddle.full(
+ shape=(bs, self.max_len),
+ fill_value=self.ignore_index,
+ dtype="int64",
+ )
+ tgt_in[:, (0)] = self.bos
+ logits = []
+ for j in range(1, self.max_len):
+ visual_f_check = visual_f
+ ques_check_i = ques_all[:, j : j + 1, :] + self.char_embed(
+ paddle.arange(end=self.out_channels - 2)
+ ).unsqueeze(axis=0)
+ prompt_check = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
+ mask = paddle.where(
+ condition=(tgt_in[:, :j] == self.eos)
+ .astype(dtype="int32")
+ .cumsum(axis=-1)
+ > 0,
+ x=float("-inf"),
+ y=0,
+ )
+ for layer in self.cmff_decoder:
+ ques_check_i, prompt_check, visual_f_check = layer(
+ ques_check_i,
+ prompt_check,
+ visual_f_check,
+ mask.unsqueeze(axis=1),
+ )
+ answer_query_i = self.answer_to_question_layer(
+ ques_check_i, prompt_check, mask.unsqueeze(axis=1)
+ )
+ answer_pred_i = self.norm_pred(
+ self.answer_to_image_layer(answer_query_i, visual_f_check)
+ )
+ fc_2 = self.ques2_head.weight[j : j + 1].unsqueeze(axis=0)
+ fc_2 = fc_2.tile(repeat_times=[bs, 1, 1])
+ p_i = fc_2 @ answer_pred_i.transpose(
+ perm=dim2perm(answer_pred_i.ndim, 1, 2)
+ )
+ logits.append(p_i)
+ if j < self.max_len - 1:
+ tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
+ if (
+ (tgt_in == self.eos)
+ .astype("bool")
+ .any(axis=-1)
+ .astype("bool")
+ .all()
+ ):
+ break
+ logits = paddle.concat(x=logits, axis=1)
+ else:
+ ques_pd = ques_all[:, 1:, :]
+ prompt_pd = prompt_bos
+ visual_f_pd = visual_f
+ for layer in self.cmff_decoder:
+ ques_pd, prompt_pd, visual_f_pd = layer(
+ ques_pd, prompt_pd, visual_f_pd
+ )
+ answer_query_pd = self.answer_to_question_layer(ques_pd, prompt_pd)
+ answer_feats_pd = self.norm_pred(
+ self.answer_to_image_layer(answer_query_pd, visual_f_pd)
+ )
+ logits = self.ques1_head(answer_feats_pd)
+ elif self.next_pred:
+ ques_pd_1 = ques_all[:, 1:2, :]
+ prompt_pd = prompt_bos
+ visual_f_pd = visual_f
+ for layer in self.cmff_decoder:
+ ques_pd_1, prompt_pd, visual_f_pd = layer(
+ ques_pd_1, prompt_pd, visual_f_pd
+ )
+ answer_query_pd = self.answer_to_question_layer(ques_pd_1, prompt_pd)
+ answer_feats_pd = self.norm_pred(
+ self.answer_to_image_layer(answer_query_pd, visual_f_pd)
+ )
+ logits_pd_1 = self.ques1_head(answer_feats_pd)
+ ques_next = (
+ self.char_pos_embed[-2:-1, :]
+ .unsqueeze(axis=0)
+ .tile(repeat_times=[bs, 1, 1])
+ )
+ prompt_next_bos = (
+ self.char_embed(
+ paddle.full(shape=[bs, 1], fill_value=self.bos, dtype="int64")
+ )
+ + self.prompt_pos_embed[:, :1, :]
+ )
+ # paddle's Tensor.max returns values only (unlike torch), so indices come from argmax
+ probs_pd_1 = nn.functional.softmax(x=logits_pd_1, axis=-1)
+ pred_prob, pred_id = probs_pd_1.max(-1), probs_pd_1.argmax(-1)
+ pred_prob_list = [pred_prob]
+ pred_id_list = [pred_id]
+ for j in range(1, 70):
+ prompt_next_1 = (
+ self.char_embed(pred_id)
+ + self.prompt_pos_embed[:, -1 * tuple(pred_id.shape)[1] :, :]
+ )
+ prompt_next = paddle.concat(x=[prompt_next_bos, prompt_next_1], axis=1)
+ ques_next_i = ques_next
+ visual_f_i = visual_f
+ for layer in self.cmff_decoder:
+ ques_next_i, prompt_next, visual_f_pd = layer(
+ ques_next_i, prompt_next, visual_f_i
+ )
+ answer_query_next_i = self.answer_to_question_layer(
+ ques_next_i, prompt_next
+ )
+ answer_feats_next_i = self.norm_pred(
+ self.answer_to_image_layer(answer_query_next_i, visual_f_i)
+ )
+ logits_next_i = self.ques1_head(answer_feats_next_i)
+ probs_next_i = nn.functional.softmax(x=logits_next_i, axis=-1)
+ pred_prob_i, pred_id_i = probs_next_i.max(-1), probs_next_i.argmax(-1)
+ pred_prob_list.append(pred_prob_i)
+ pred_id_list.append(pred_id_i)
+ if (
+ (paddle.concat(x=pred_id_list, axis=1) == self.eos)
+ .astype("bool")
+ .any(axis=-1)
+ .astype("bool")
+ .all()
+ ):
+ break
+ if tuple(pred_id.shape)[1] >= 5:
+ pred_id = paddle.concat(x=[pred_id[:, 1:], pred_id_i], axis=1)
+ else:
+ pred_id = paddle.concat(x=[pred_id, pred_id_i], axis=1)
+ return [
+ paddle.concat(x=pred_id_list, axis=1),
+ paddle.concat(x=pred_prob_list, axis=1),
+ ]
+ else:
+ tgt_in = paddle.full(
+ shape=(bs, self.max_len), fill_value=self.ignore_index, dtype="int64"
+ )
+ tgt_in[:, (0)] = self.bos
+ logits = []
+ for j in range(1, self.max_len):
+ visual_f_ar = visual_f
+ ques_i = ques_all[:, j : j + 1, :]
+ prompt_ar = ques_all[:, :j] + self.char_embed(tgt_in[:, :j])
+ mask = paddle.where(
+ condition=(tgt_in[:, :j] == self.eos)
+ .astype(dtype="int32")
+ .cumsum(axis=-1)
+ > 0,
+ x=float("-inf"),
+ y=0,
+ )
+ for layer in self.cmff_decoder:
+ ques_i, prompt_ar, visual_f_ar = layer(
+ ques_i, prompt_ar, visual_f_ar, mask.unsqueeze(axis=1)
+ )
+ answer_query_i = self.answer_to_question_layer(
+ ques_i, prompt_ar, mask.unsqueeze(axis=1)
+ )
+ answer_pred_i = self.norm_pred(
+ self.answer_to_image_layer(answer_query_i, visual_f_ar)
+ )
+ p_i = self.ques1_head(answer_pred_i)
+ logits.append(p_i)
+ if j < self.max_len - 1:
+ tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
+ if (
+ (tgt_in == self.eos)
+ .astype("bool")
+ .any(axis=-1)
+ .astype("bool")
+ .all()
+ ):
+ break
+ logits = paddle.concat(x=logits, axis=1)
+ if self.refine_iter > 0:
+ probs_all = nn.functional.softmax(x=logits, axis=-1)
+ pred_probs, pred_idxs = probs_all.max(-1), probs_all.argmax(-1)
+ for i in range(self.refine_iter):
+ mask_check = (pred_idxs == self.eos).astype(dtype="int32").cumsum(
+ axis=-1
+ ) <= 1
+ ques_check_all = (
+ self.char_embed(pred_idxs)
+ + ques_all[:, 1 : tuple(pred_idxs.shape)[1] + 1, :]
+ )
+ prompt_check = prompt_bos
+ visual_f_check = visual_f
+ ques_check = ques_check_all
+ for layer in self.cmff_decoder:
+ ques_check, prompt_check, visual_f_check = layer(
+ ques_check, prompt_check, visual_f_check
+ )
+ answer_query_check = self.answer_to_question_layer(
+ ques_check, prompt_check
+ )
+ answer_pred_check = self.norm_pred(
+ self.answer_to_image_layer(answer_query_check, visual_f_check)
+ )
+ ques2_head = self.ques2_head.weight[
+ 1 : tuple(pred_idxs.shape)[1] + 1, :
+ ]
+ ques2_head = paddle.tile(
+ x=ques2_head.unsqueeze(axis=0), repeat_times=[bs, 1, 1]
+ )
+ answer2_pred = answer_pred_check.matmul(
+ y=ques2_head.transpose(perm=dim2perm(ques2_head.ndim, 1, 2))
+ )
+ diag_mask = (
+ paddle.eye(num_rows=tuple(answer2_pred.shape)[1])
+ .unsqueeze(axis=0)
+ .tile(repeat_times=[bs, 1, 1])
+ )
+ answer2_pred = (
+ nn.functional.sigmoid(x=(answer2_pred * diag_mask).sum(axis=-1))
+ * mask_check
+ )
+ check_result = answer2_pred < 0.9
+ prompt_refine = paddle.concat(x=[prompt_bos, ques_check_all], axis=1)
+ mask_refine = paddle.where(
+ condition=check_result, x=float("-inf"), y=0
+ ) + paddle.where(
+ condition=(pred_idxs == self.eos)
+ .astype(dtype="int32")
+ .cumsum(axis=-1)
+ < 1,
+ x=0,
+ y=float("-inf"),
+ )
+ mask_refine = paddle.concat(
+ x=[paddle.zeros(shape=[bs, 1]), mask_refine], axis=1
+ ).unsqueeze(axis=1)
+ ques_refine = ques_all[:, 1 : tuple(pred_idxs.shape)[1] + 1, :]
+ visual_f_refine = visual_f
+ for layer in self.cmff_decoder:
+ ques_refine, prompt_refine, visual_f_refine = layer(
+ ques_refine, prompt_refine, visual_f_refine, mask_refine
+ )
+ answer_query_refine = self.answer_to_question_layer(
+ ques_refine, prompt_refine, mask_refine
+ )
+ answer_pred_refine = self.norm_pred(
+ self.answer_to_image_layer(answer_query_refine, visual_f_refine)
+ )
+ answer_refine = self.ques1_head(answer_pred_refine)
+ refine_all = nn.functional.softmax(x=answer_refine, axis=-1)
+ refine_probs, refine_idxs = refine_all.max(-1), refine_all.argmax(-1)
+ pred_idxs_refine = paddle.where(
+ condition=check_result, x=refine_idxs, y=pred_idxs
+ )
+ pred_idxs = paddle.where(
+ condition=mask_check, x=pred_idxs_refine, y=pred_idxs
+ )
+ pred_probs_refine = paddle.where(
+ condition=check_result, x=refine_probs, y=pred_probs
+ )
+ pred_probs = paddle.where(
+ condition=mask_check, x=pred_probs_refine, y=pred_probs
+ )
+ return [pred_idxs, pred_probs]
+ return nn.functional.softmax(x=logits, axis=-1)
+
+ def forward_train(self, x, targets=None):
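+ # Each label tensor packs several sampled question groups; the loop iterates over the
+ # groups, and each head's loss is scaled by its number of valid samples and normalized
+ # by the accumulated totals at the end.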
+ bs = tuple(x.shape)[0]
+ answer_token = paddle.tile(x=self.answer_query, repeat_times=(bs, 1, 1))
+ if self.ch:
+ ques3 = self.char_embed(targets[7][:, :, (0)]) + answer_token
+ ques3_answer = targets[7][:, :, (1)]
+ else:
+ ques3 = (
+ self.char_embed(paddle.arange(end=self.out_channels - 2)).unsqueeze(
+ axis=0
+ )
+ + answer_token
+ )
+ ques3_answer = targets[7]
+ loss1_list = []
+ loss2_list = []
+ loss3_list = []
+ loss4_list = []
+ sampler1_num = 0
+ sampler2_num = 0
+ sampler3_num = 0
+ sampler4_num = 0
+ if not self.ds:
+ visual_f = x + self.vis_pos_embed
+ elif self.pos2d:
+ x = x + self.vis_pos_embed[:, :, : tuple(x.shape)[2], : tuple(x.shape)[3]]
+ visual_f = x.flatten(start_axis=2).transpose(
+ perm=dim2perm(x.flatten(start_axis=2).ndim, 1, 2)
+ )
+ else:
+ visual_f = x
+ train_i = 0
+ for target_ in zip(
+ targets[1].transpose(perm=dim2perm(targets[1].ndim, 0, 1)),
+ targets[2].transpose(perm=dim2perm(targets[2].ndim, 0, 1)),
+ targets[3].transpose(perm=dim2perm(targets[3].ndim, 0, 1)),
+ targets[4].transpose(perm=dim2perm(targets[4].ndim, 0, 1)),
+ targets[5].transpose(perm=dim2perm(targets[5].ndim, 0, 1)),
+ targets[6].transpose(perm=dim2perm(targets[6].ndim, 0, 1)),
+ targets[8].transpose(perm=dim2perm(targets[8].ndim, 0, 1)),
+ targets[9].transpose(perm=dim2perm(targets[9].ndim, 0, 1)),
+ targets[10].transpose(perm=dim2perm(targets[10].ndim, 0, 1)),
+ targets[11].transpose(perm=dim2perm(targets[11].ndim, 0, 1)),
+ ):
+ visual_f_1234 = visual_f
+ if self.quesall or train_i == 0:
+ (
+ prompt,
+ ques1,
+ ques2,
+ ques2_head,
+ ques4,
+ ques1_answer,
+ ques2_answer,
+ ques4_answer,
+ mask_1234,
+ ) = self.question_encoder(target_, train_i)
+ prompt_1234 = prompt
+ ques_1234 = paddle.concat(x=[ques1, ques2, ques3, ques4], axis=1)
+ for layer in self.cmff_decoder:
+ ques_1234, prompt_1234, visual_f_1234 = layer(
+ ques_1234, prompt_1234, visual_f_1234, mask_1234
+ )
+ answer_query_1234 = self.answer_to_question_layer(
+ ques_1234, prompt_1234, mask_1234
+ )
+ answer_feats_1234 = self.norm_pred(
+ self.answer_to_image_layer(answer_query_1234, visual_f_1234)
+ )
+ answer_feats_1 = answer_feats_1234[:, : tuple(ques1.shape)[1], :]
+ answer_feats_2 = answer_feats_1234[
+ :,
+ tuple(ques1.shape)[1] : tuple(ques1.shape)[1]
+ + tuple(ques2.shape)[1],
+ :,
+ ]
+ answer_feats_3 = answer_feats_1234[
+ :,
+ tuple(ques1.shape)[1]
+ + tuple(ques2.shape)[1] : -tuple(ques4.shape)[1],
+ :,
+ ]
+ answer_feats_4 = answer_feats_1234[:, -tuple(ques4.shape)[1] :, :]
+ answer1_pred = self.ques1_head(answer_feats_1)
+ if train_i == 0:
+ logits = answer1_pred
+ n = (ques1_answer != self.ignore_index).sum().item()
+ loss1 = n * nn.functional.cross_entropy(
+ input=answer1_pred.flatten(start_axis=0, stop_axis=1),
+ label=ques1_answer.flatten(start_axis=0, stop_axis=1),
+ ignore_index=self.ignore_index,
+ reduction="mean",
+ )
+ sampler1_num += n
+ loss1_list.append(loss1)
+ answer2_pred = answer_feats_2.matmul(
+ y=ques2_head.transpose(perm=dim2perm(ques2_head.ndim, 1, 2))
+ )
+ diag_mask = (
+ paddle.eye(num_rows=tuple(answer2_pred.shape)[1])
+ .unsqueeze(axis=0)
+ .tile(repeat_times=[bs, 1, 1])
+ )
+ answer2_pred = (answer2_pred * diag_mask).sum(axis=-1)
+ ques2_answer = ques2_answer.flatten(start_axis=0, stop_axis=1)
+ non_pad_mask = paddle.not_equal(
+ x=ques2_answer,
+ y=paddle.to_tensor(self.ignore_index, dtype=paddle.float32),
+ )
+ n = non_pad_mask.sum().item()
+ ques2_answer = paddle.where(
+ condition=ques2_answer == self.ignore_index,
+ x=paddle.to_tensor(0.0, dtype=paddle.float32),
+ y=ques2_answer,
+ )
+ loss2_none = nn.functional.binary_cross_entropy_with_logits(
+ logit=answer2_pred.flatten(start_axis=0, stop_axis=1),
+ label=ques2_answer,
+ reduction="none",
+ )
+ loss2 = n * loss2_none.masked_select(mask=non_pad_mask).mean()
+ sampler2_num += n
+ loss2_list.append(loss2)
+ answer3_pred = self.ques3_head(answer_feats_3)
+ n = (ques3_answer != self.ignore_index).sum().item()
+ loss3 = n * nn.functional.cross_entropy(
+ input=answer3_pred.flatten(start_axis=0, stop_axis=1),
+ label=ques3_answer.flatten(start_axis=0, stop_axis=1),
+ reduction="mean",
+ )
+ sampler3_num += n
+ loss3_list.append(loss3)
+ answer4_pred = self.ques4_head(answer_feats_4)
+ n = (ques4_answer != self.max_len - 1).sum().item()
+ loss4 = n * nn.functional.cross_entropy(
+ input=answer4_pred.flatten(start_axis=0, stop_axis=1),
+ label=ques4_answer.flatten(start_axis=0, stop_axis=1),
+ ignore_index=self.max_len - 1,
+ reduction="mean",
+ )
+ sampler4_num += n
+ loss4_list.append(loss4)
+ else:
+ prompt, ques1, ques1_answer, mask_1234 = self.question_encoder(
+ target_, train_i
+ )
+ prompt_1234 = prompt
+ for layer in self.cmff_decoder:
+ ques1, prompt_1234, visual_f_1234 = layer(
+ ques1, prompt_1234, visual_f_1234, mask_1234
+ )
+ answer_query_1 = self.answer_to_question_layer(
+ ques1, prompt_1234, mask_1234
+ )
+ answer_feats_1 = self.norm_pred(
+ self.answer_to_image_layer(answer_query_1, visual_f_1234)
+ )
+ answer1_pred = self.ques1_head(answer_feats_1)
+ n = (ques1_answer != self.ignore_index).sum().item()
+ loss1 = n * nn.functional.cross_entropy(
+ input=answer1_pred.flatten(start_axis=0, stop_axis=1),
+ label=ques1_answer.flatten(start_axis=0, stop_axis=1),
+ ignore_index=self.ignore_index,
+ reduction="mean",
+ )
+ sampler1_num += n
+ loss1_list.append(loss1)
+ train_i += 1
+ loss_list = [
+ sum(loss1_list) / sampler1_num,
+ sum(loss2_list) / sampler2_num,
+ sum(loss3_list) / sampler3_num,
+ sum(loss4_list) / sampler4_num,
+ ]
+ loss = {
+ "loss": sum(loss_list),
+ "loss1": loss_list[0],
+ "loss2": loss_list[1],
+ "loss3": loss_list[2],
+ "loss4": loss_list[3],
+ }
+ return [loss, logits]
diff --git a/ppocr/modeling/heads/rec_parseq_head.py b/ppocr/modeling/heads/rec_parseq_head.py
index 6fc2dc3863f..8f7869a307d 100644
--- a/ppocr/modeling/heads/rec_parseq_head.py
+++ b/ppocr/modeling/heads/rec_parseq_head.py
@@ -495,7 +495,7 @@ def forward(self, feat, targets=None):
if self.training:
label = targets[0] # label
label_len = targets[1]
- max_step = paddle.max(label_len).cpu().numpy()[0] + 2
+ max_step = paddle.max(label_len).cpu().item() + 2
crop_label = label[:, :max_step]
final_out = self.forward_train(feat, crop_label)
else:
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index e427ee9cd60..fc7897c6202 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -45,6 +45,7 @@
CPPDLabelDecode,
LaTeXOCRDecode,
UniMERNetDecode,
+ IGTRLabelDecode,
)
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
@@ -101,6 +102,7 @@ def build_post_process(config, global_config=None):
"CPPDLabelDecode",
"LaTeXOCRDecode",
"UniMERNetDecode",
+ "IGTRLabelDecode",
]
if config["name"] == "PSEPostProcess":
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 9d3a958272f..0a3518286fe 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -1482,3 +1482,92 @@ def __call__(self, preds, label=None, mode="eval", *args, **kwargs):
return text
label = self.token2str(np.array(label))
return text, label
+
+
+class IGTRLabelDecode(BaseRecLabelDecode):
+ """Convert between text-label and text-index."""
+
+ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
+ super(IGTRLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+
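+ # preds may be [loss_dict, logits] during training, [pred_ids, pred_probs] from the
+ # autoregressive / next-pred branches, or a plain probability tensor from parallel decoding.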
+ if isinstance(preds, list):
+ if isinstance(preds[0], dict):
+ preds = preds[-1].numpy()
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ elif isinstance(preds, dict):
+ preds = preds["align"][-1].numpy()
+ else:
+ preds = preds
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ else:
+ preds_idx = preds[0].numpy()
+ preds_prob = preds[1].numpy()
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ else:
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ elif isinstance(preds, dict):
+ preds = preds["align"][-1].numpy()
+ else:
+ preds = preds
+ preds_idx = preds.argmax(axis=2)
+ preds_idx_top5 = preds.argsort(axis=2)[:, :, -5:]
+ preds_prob = preds.max(axis=2)
+ text = self.decode(
+ preds_idx,
+ preds_prob,
+ is_remove_duplicate=False,
+ idx_top5=preds_idx_top5,
+ )
+ if batch is None:
+ return text
+ label = batch
+ label = self.decode(label)
+ return text, label
+
+ def add_special_char(self, dict_character):
+ dict_character = [""] + dict_character + ["", ""]
+ return dict_character
+
+ def decode(
+ self, text_index, text_prob=None, is_remove_duplicate=False, idx_top5=None
+ ):
+ """convert text-index into text-label."""
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ char_list_top5 = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ char_idx_top5 = []
+ try:
+ char_idx = self.character[int(text_index[batch_idx][idx])]
+ if idx_top5 is not None:
+ for top5_i in idx_top5[batch_idx][idx]:
+ char_idx_top5.append(self.character[top5_i])
+ except:
+ continue
+ if char_idx == "": # end
+ break
+ if char_idx == "" or char_idx == "":
+ continue
+ char_list.append(char_idx)
+ char_list_top5.append(char_idx_top5)
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = "".join(char_list)
+ if idx_top5 is not None:
+ result_list.append(
+ (text, [np.mean(conf_list).tolist(), char_list_top5])
+ )
+ else:
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 4d4b7ba03bc..7ed7108818b 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -192,6 +192,10 @@ def load_pretrained_params(model, path):
)
)
+ for k1 in state_dict.keys():
+ if k1 not in params.keys():
+ logger.warning("The model params {} not in pretrained file".format(k1))
+
model.set_state_dict(new_state_dict)
if is_float16:
logger.info(
diff --git a/ppstructure/recovery/table_process.py b/ppstructure/recovery/table_process.py
index f5e01e64813..391cd39129f 100644
--- a/ppstructure/recovery/table_process.py
+++ b/ppstructure/recovery/table_process.py
@@ -273,11 +273,15 @@ def handle_table(self, html, doc):
cell_col += 1
docx_cell = table.cell(cell_row, cell_col)
- cell_to_merge = table.cell(
- cell_row + rowspan - 1, cell_col + colspan - 1
- )
- if docx_cell != cell_to_merge:
- docx_cell.merge(cell_to_merge)
+ target_row = cell_row + rowspan - 1
+ target_col = cell_col + colspan - 1
+ if target_row < num_rows and target_col < num_cols:
+ cell_to_merge = table.cell(target_row, target_col)
+ if docx_cell != cell_to_merge:
+ docx_cell.merge(cell_to_merge)
+ else:
+ # Skip merges whose target cell would fall outside the table bounds
+ continue
child_parser = HtmlToDocx()
child_parser.copy_settings_from(self)
diff --git a/tools/program.py b/tools/program.py
index 3eb43b7f0c1..f2d3e564403 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -264,6 +264,7 @@ def train(
"SVTR_HGNet",
"ParseQ",
"CPPD",
+ "IGTR",
]
extra_input = False
if config["Architecture"]["algorithm"] == "Distillation":
@@ -871,6 +872,7 @@ def preprocess(is_train=False):
"SLANeXt",
"PP-FormulaNet-S",
"PP-FormulaNet-L",
+ "IGTR",
]
if use_xpu:
From 4b765749465491d78431e3c786b58fff8863be43 Mon Sep 17 00:00:00 2001
From: topduke <784990967@qq.com>
Date: Tue, 18 Feb 2025 11:49:09 +0000
Subject: [PATCH 2/2] add igtr in overview
---
docs/algorithm/overview.en.md | 1 +
docs/algorithm/overview.md | 1 +
2 files changed, 2 insertions(+)
diff --git a/docs/algorithm/overview.en.md b/docs/algorithm/overview.en.md
index 06c01461db5..569e76fa3fd 100755
--- a/docs/algorithm/overview.en.md
+++ b/docs/algorithm/overview.en.md
@@ -77,6 +77,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [ParseQ](./text_recognition/algorithm_rec_parseq.md)
- [x] [CPPD](./text_recognition/algorithm_rec_cppd.en.md)
- [x] [SATRN](./text_recognition/algorithm_rec_satrn.en.md)
+- [x] [IGTR](./text_recognition/algorithm_rec_igtr.en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
diff --git a/docs/algorithm/overview.md b/docs/algorithm/overview.md
index 47352d89781..b03720fd409 100755
--- a/docs/algorithm/overview.md
+++ b/docs/algorithm/overview.md
@@ -78,6 +78,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
- [x] [ParseQ](./text_recognition/algorithm_rec_parseq.md)
- [x] [CPPD](./text_recognition/algorithm_rec_cppd.md)
- [x] [SATRN](./text_recognition/algorithm_rec_satrn.md)
+- [x] [IGTR](./text_recognition/algorithm_rec_igtr.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: