Skip to content

Commit 20a623a

Browse files
authored
Merge pull request #1734 from myhloli/dev
refactor(magic_pdf): improve title optimization process
2 parents e0f591e + 54940c6 commit 20a623a

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

magic_pdf/post_proc/llm_aided.py

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from loguru import logger
44
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
55
from openai import OpenAI
6+
import ast
67

78

89
#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
@@ -119,11 +120,12 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
119120
- 在完成初步分级后,仔细检查分级结果的合理性
120121
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
121122
- 确保最终的分级结果符合文档的实际结构和逻辑
123+
- 字典中包含OCR错误识别的标题,你可以通过将其层级标记为 0 来排除它们
122124
123125
IMPORTANT:
124-
请直接返回优化过的由标题层级组成的json,格式如下
125-
{{"0":1,"1":2,"2":2,"3":3}}
126-
返回的json不需要格式化
126+
请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下
127+
{{0:1,1:2,2:2,3:3}}
128+
不需要对字典格式化,不需要返回任何其他信息
127129
128130
Input title list:
129131
{title_dict}
@@ -133,7 +135,7 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
133135

134136
retry_count = 0
135137
max_retries = 3
136-
json_completion = None
138+
dict_completion = None
137139

138140
while retry_count < max_retries:
139141
try:
@@ -143,24 +145,20 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
143145
{'role': 'user', 'content': title_optimize_prompt}],
144146
temperature=0.7,
145147
)
146-
json_completion = json.loads(completion.choices[0].message.content)
148+
# logger.info(f"Title completion: {completion.choices[0].message.content}")
149+
dict_completion = ast.literal_eval(completion.choices[0].message.content)
150+
# logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
147151

148-
# logger.info(f"Title completion: {json_completion}")
149-
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
150-
151-
if len(json_completion) == len(title_dict):
152+
if len(dict_completion) == len(title_dict):
152153
for i, origin_title_block in enumerate(origin_title_list):
153-
origin_title_block["level"] = int(json_completion[str(i)])
154+
origin_title_block["level"] = int(dict_completion[i])
154155
break
155156
else:
156157
logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
157158
retry_count += 1
158159
except Exception as e:
159-
if isinstance(e, json.decoder.JSONDecodeError):
160-
logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
161-
else:
162-
logger.exception(e)
160+
logger.exception(e)
163161
retry_count += 1
164162

165-
if json_completion is None:
166-
logger.error("Failed to decode JSON after maximum retries.")
163+
if dict_completion is None:
164+
logger.error("Failed to decode dict after maximum retries.")

0 commit comments

Comments
 (0)