Commit 172389f

Authored by 1649759610, yingyibiao, tianxin
add Doc-VQA applications (#2214)
* initial commit
* modify _calc_img_embeddings to support running without img embedding.
* remove commented code
* delete README
* refine readme.md
* change question
* modify layoutxlm to support traing without image embedding
* modify _calc_img_embeddings in layoutxlm to support training without img_embeddings
* modify _calc_img_embeddings in layoutxlm to support training without img_embeddings
* refine .gitignore
* refine Rerank with pre-commit
* refine Extraction with pre-commit
* refine code and readme details
* refine coding
* refine code style about imports
* refine README
* set CUDA_VISIBLE_DEVICES 0
* refine code style
* refine readme
* refine readme
* delete ocr parsing file
* refine readme

Co-authored-by: yingyibiao <[email protected]>
Co-authored-by: tianxin <[email protected]>
1 parent 1d433b1 commit 172389f

29 files changed: +22591 -1 lines changed

applications/doc_vqa/.gitignore

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+OCR_process/*.json
+*.png
+answers/*
+checkpoints/*
+__pycache__/*
+OCR_process/demo_pics/*
+Rerank/log/*
+Rerank/checkpoints/*
+Rerank/data/*
+Rerank/output/*
+Rerank/__pycache__/*
+Extraction/log/*
+Extraction/checkpoints/*
+Extraction/data/*
+Extraction/output/*
+Extraction/__pycache__/*
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import sys
+import json
+import numpy as np
+
+
+def get_top1_from_ranker(path):
+    with open(path, "r", encoding="utf-8") as f:
+        scores = [float(line.strip()) for line in f.readlines()]
+        top_id = np.argmax(scores)
+
+    return top_id
+
+
+def get_ocr_result_by_id(path, top_id):
+    with open(path, "r", encoding="utf-8") as f:
+        reses = f.readlines()
+        res = reses[top_id]
+    return json.loads(res)
+
+
+def write_to_file(doc, path):
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(doc, f, ensure_ascii=False)
+        f.write("\n")
+
+
+if __name__ == "__main__":
+    question = sys.argv[1]
+    ranker_result_path = "../Rerank/data/demo.score"
+    ocr_result_path = "../OCR_process/demo_ocr_res.json"
+    save_path = "data/demo_test.json"
+    top_id = get_top1_from_ranker(ranker_result_path)
+    doc = get_ocr_result_by_id(ocr_result_path, top_id)
+    doc["question"] = question
+    doc["img_id"] = str(top_id + 1)
+
+    write_to_file(doc, save_path)
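Review note: the script above joins the reranker output to the OCR output by line position. Line i of the score file is assumed to describe line i of the OCR JSON-lines file, and `img_id` is the 1-based index of the winning line. The sketch below exercises that selection step on throwaway files so it can be checked without the real `../Rerank` and `../OCR_process` outputs. It is an illustration under assumptions, not the committed code: the helper `build_extraction_input`, the `demo` function, the sample scores, and the `"text"` field are all hypothetical, and `np.argmax` is swapped for a plain-Python equivalent so the sketch has no third-party dependency.

```python
import json
import os
import tempfile


def get_top1_from_ranker(path):
    """Return the 0-based index of the highest score (one score per line)."""
    with open(path, "r", encoding="utf-8") as f:
        # Blank lines are ignored; this assumes they only occur at the end,
        # otherwise indices would no longer line up with the OCR file.
        scores = [float(s) for s in f.read().splitlines() if s.strip()]
    if not scores:
        raise ValueError(f"no scores found in {path}")
    # Plain-Python stand-in for np.argmax: first index of the maximum.
    return max(range(len(scores)), key=scores.__getitem__)


def get_ocr_result_by_id(path, top_id):
    """Parse the JSON object stored on line `top_id` of a JSON-lines file."""
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return json.loads(lines[top_id])


def build_extraction_input(score_path, ocr_path, question):
    """Hypothetical helper: pick the top-ranked OCR record and annotate it."""
    top_id = get_top1_from_ranker(score_path)
    doc = get_ocr_result_by_id(ocr_path, top_id)
    doc["question"] = question
    doc["img_id"] = str(top_id + 1)  # 1-based, as in the committed script
    return doc


def demo():
    """Run the selection step on made-up data in a temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        score_path = os.path.join(tmp, "demo.score")
        ocr_path = os.path.join(tmp, "demo_ocr_res.json")
        with open(score_path, "w", encoding="utf-8") as f:
            f.write("0.12\n0.87\n0.45\n")
        with open(ocr_path, "w", encoding="utf-8") as f:
            # "text" is an invented field; the real OCR schema is not shown here.
            for text in ["page one", "page two", "page three"]:
                f.write(json.dumps({"text": text}) + "\n")
        return build_extraction_input(score_path, ocr_path, "hypothetical question")


if __name__ == "__main__":
    print(demo())
```

Returning a built-in `int` rather than a NumPy integer keeps `top_id` directly usable as a list index and in `str(top_id + 1)`. The empty-file check turns a silent failure into an explicit error when the reranker produced no scores.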
