
Commit c4093df

Refactor/refactor generators (#137)

* fix: change cache_dir in read operator to working_dir
* refactor: use xml format prompt in Generators
* feat: change temperature & max_token in vllmwrapper
* Update graphgen/models/generator/vqa_generator.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 27733a4 commit c4093df

12 files changed: +203 / -130 lines changed

graphgen/models/generator/aggregated_generator.py

Lines changed: 21 additions & 16 deletions
@@ -1,4 +1,5 @@
-from typing import Any
+import re
+from typing import Any, Optional
 
 from graphgen.bases import BaseGenerator
 from graphgen.templates import AGGREGATED_GENERATION_PROMPT
@@ -56,19 +57,21 @@ def build_prompt(
         return prompt
 
     @staticmethod
-    def parse_rephrased_text(response: str) -> str:
+    def parse_rephrased_text(response: str) -> Optional[str]:
         """
         Parse the rephrased text from the response.
         :param response:
         :return: rephrased text
         """
-        if "Rephrased Text:" in response:
-            rephrased_text = response.split("Rephrased Text:")[1].strip()
-        elif "重述文本:" in response:
-            rephrased_text = response.split("重述文本:")[1].strip()
+        rephrased_match = re.search(
+            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
+        )
+        if rephrased_match:
+            rephrased_text = rephrased_match.group(1).strip()
         else:
-            rephrased_text = response.strip()
-        return rephrased_text.strip('"')
+            logger.warning("Failed to parse rephrased text from response: %s", response)
+            return None
+        return rephrased_text.strip('"').strip("'")
 
     @staticmethod
     def _build_prompt_for_question_generation(answer: str) -> str:
@@ -85,15 +88,13 @@ def _build_prompt_for_question_generation(answer: str) -> str:
 
     @staticmethod
     def parse_response(response: str) -> dict:
-        if response.startswith("Question:"):
-            question = response[len("Question:") :].strip()
-        elif response.startswith("问题:"):
-            question = response[len("问题:") :].strip()
+        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
+        if question_match:
+            question = question_match.group(1).strip()
         else:
-            question = response.strip()
-        return {
-            "question": question,
-        }
+            logger.warning("Failed to parse question from response: %s", response)
+            return {"question": ""}
+        return {"question": question.strip('"').strip("'")}
 
     async def generate(
         self,
@@ -110,9 +111,13 @@ async def generate(
         rephrasing_prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(rephrasing_prompt)
         context = self.parse_rephrased_text(response)
+        if not context:
+            return result
         question_generation_prompt = self._build_prompt_for_question_generation(context)
         response = await self.llm_client.generate_answer(question_generation_prompt)
         question = self.parse_response(response)["question"]
+        if not question:
+            return result
         logger.debug("Question: %s", question)
         logger.debug("Answer: %s", context)
         qa_pairs = {
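The practical effect of the tag-based parsing is easiest to see in isolation. Below is a minimal, self-contained sketch that mirrors the refactored `parse_rephrased_text` (the sample responses are hypothetical, and the logger call is dropped for brevity):

```python
import re
from typing import Optional

def parse_rephrased_text(response: str) -> Optional[str]:
    # Mirrors the refactored parser: pull the <rephrased_text> block out of
    # the LLM response, or return None so the caller can skip the batch.
    match = re.search(r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL)
    if not match:
        return None
    return match.group(1).strip().strip('"').strip("'")

# Tagged output parses; the old "Rephrased Text:" prefix no longer does.
print(parse_rephrased_text("<rephrased_text>Cells are the basic unit of life.</rephrased_text>"))
print(parse_rephrased_text("Rephrased Text: untagged output"))  # None
```

Returning `None` instead of falling back to `response.strip()` means a malformed response now aborts the batch early in `generate`, rather than propagating raw model output as the answer.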

graphgen/models/generator/atomic_generator.py

Lines changed: 10 additions & 8 deletions
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from graphgen.bases import BaseGenerator
@@ -29,17 +30,18 @@ def parse_response(response: str) -> dict:
         :param response:
         :return:
         """
-        if "Question:" in response and "Answer:" in response:
-            question = response.split("Question:")[1].split("Answer:")[0].strip()
-            answer = response.split("Answer:")[1].strip()
-        elif "问题:" in response and "答案:" in response:
-            question = response.split("问题:")[1].split("答案:")[0].strip()
-            answer = response.split("答案:")[1].strip()
+        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
+
+        if question_match and answer_match:
+            question = question_match.group(1).strip()
+            answer = answer_match.group(1).strip()
         else:
             logger.warning("Failed to parse response: %s", response)
             return {}
-        question = question.strip('"')
-        answer = answer.strip('"')
+
+        question = question.strip('"').strip("'")
+        answer = answer.strip('"').strip("'")
         logger.debug("Question: %s", question)
         logger.debug("Answer: %s", answer)
         return {
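A self-contained sketch of the same pattern for question/answer pairs (hypothetical sample input; both tags must match or the parser signals failure with an empty dict):

```python
import re

def parse_qa(response: str) -> dict:
    # Mirrors the refactored parser: require both tags, otherwise return {}
    # so the caller can log a warning and drop the response.
    q = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
    a = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if not (q and a):
        return {}
    return {
        "question": q.group(1).strip().strip('"').strip("'"),
        "answer": a.group(1).strip().strip('"').strip("'"),
    }

print(parse_qa("<question>What is DNA?</question>\n<answer>Deoxyribonucleic acid.</answer>"))
print(parse_qa("Question: untagged"))  # {} (the warning path in the real code)
```

Note that the regex approach also collapses the two language-specific branches ("Question:" vs. "问题:") into one code path, since the tags are language-neutral.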

graphgen/models/generator/cot_generator.py

Lines changed: 20 additions & 13 deletions
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from graphgen.bases import BaseGenerator
@@ -67,22 +68,26 @@ def build_prompt_for_cot_generation(
 
     @staticmethod
    def parse_response(response: str) -> dict:
-        if "Question:" in response and "Reasoning-Path Design:" in response:
-            question = (
-                response.split("Question:")[1]
-                .split("Reasoning-Path Design:")[0]
-                .strip()
-            )
-            reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
-        elif "问题:" in response and "推理路径设计:" in response:
-            question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
-            reasoning_path = response.split("推理路径设计:")[1].strip()
+        """
+        Parse CoT template from response.
+        :param response:
+        :return: dict with question and reasoning_path
+        """
+        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
+        reasoning_path_match = re.search(
+            r"<reasoning_path>(.*?)</reasoning_path>", response, re.DOTALL
+        )
+
+        if question_match and reasoning_path_match:
+            question = question_match.group(1).strip()
+            reasoning_path = reasoning_path_match.group(1).strip()
         else:
-            logger.warning("Failed to parse CoT template: %s", response)
+            logger.warning("Failed to parse response: %s", response)
             return {}
 
-        question = question.strip('"')
-        reasoning_path = reasoning_path.strip('"')
+        question = question.strip('"').strip("'")
+        reasoning_path = reasoning_path.strip('"').strip("'")
+
         logger.debug("CoT Question: %s", question)
         logger.debug("CoT Reasoning Path: %s", reasoning_path)
         return {
@@ -105,6 +110,8 @@ async def generate(
         prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(prompt)
         response = self.parse_response(response)
+        if not response:
+            return result
         question, reasoning_path = response["question"], response["reasoning_path"]
         prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
         cot_answer = await self.llm_client.generate_answer(prompt)
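The CoT parser follows the same tag-extraction pattern sketched after the atomic generator above, with `<reasoning_path>` in place of `<answer>`. The added early return in `generate` matters here: previously a parse failure returned `{}` and the subsequent `response["question"]` lookup raised a `KeyError`; now the batch is skipped cleanly.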

graphgen/models/generator/multi_hop_generator.py

Lines changed: 10 additions & 8 deletions
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from graphgen.bases import BaseGenerator
@@ -32,17 +33,18 @@ def build_prompt(
 
     @staticmethod
     def parse_response(response: str) -> dict:
-        if "Question:" in response and "Answer:" in response:
-            question = response.split("Question:")[1].split("Answer:")[0].strip()
-            answer = response.split("Answer:")[1].strip()
-        elif "问题:" in response and "答案:" in response:
-            question = response.split("问题:")[1].split("答案:")[0].strip()
-            answer = response.split("答案:")[1].strip()
+        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
+
+        if question_match and answer_match:
+            question = question_match.group(1).strip()
+            answer = answer_match.group(1).strip()
         else:
             logger.warning("Failed to parse response: %s", response)
             return {}
-        question = question.strip('"')
-        answer = answer.strip('"')
+
+        question = question.strip('"').strip("'")
+        answer = answer.strip('"').strip("'")
         logger.debug("Question: %s", question)
         logger.debug("Answer: %s", answer)
         return {
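The multi-hop change is identical in shape to the atomic generator's: the same `<question>`/`<answer>` extraction replaces the bilingual prefix splitting, so the sketch after the atomic diff applies here as well; only the prompt template differs.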

graphgen/models/generator/vqa_generator.py

Lines changed: 16 additions & 19 deletions
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from graphgen.bases import BaseGenerator
@@ -38,25 +39,21 @@ def parse_response(response: str) -> Any:
         :return: QA pairs
         """
         qa_pairs = {}
-        qa_list = response.strip().split("\n\n")
-        for qa in qa_list:
-            if "Question:" in qa and "Answer:" in qa:
-                question = qa.split("Question:")[1].split("Answer:")[0].strip()
-                answer = qa.split("Answer:")[1].strip()
-            elif "问题:" in qa and "答案:" in qa:
-                question = qa.split("问题:")[1].split("答案:")[0].strip()
-                answer = qa.split("答案:")[1].strip()
-            else:
-                logger.error("Failed to parse QA pair: %s", qa)
-                continue
-            question = question.strip('"')
-            answer = answer.strip('"')
-            logger.debug("Question: %s", question)
-            logger.debug("Answer: %s", answer)
-            qa_pairs[compute_content_hash(question)] = {
-                "question": question,
-                "answer": answer,
-            }
+        pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
+        matches = re.findall(pattern, response, re.DOTALL)
+
+        if matches:
+            for question, answer in matches:
+                question = question.strip().strip('"').strip("'")
+                answer = answer.strip().strip('"').strip("'")
+                logger.debug("Question: %s", question)
+                logger.debug("Answer: %s", answer)
+                qa_pairs[compute_content_hash(question)] = {
+                    "question": question,
+                    "answer": answer,
+                }
+        else:
+            logger.warning("Error parsing the response %s", response)
         return qa_pairs
 
     async def generate(
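Because VQA responses can contain several QA pairs, this parser uses `re.findall` with a combined pattern rather than two `re.search` calls. A minimal sketch with a hypothetical two-pair response:

```python
import re

response = """<question>What is shown in the image?</question>
<answer>A plant cell under a microscope.</answer>

<question>Which organelle is highlighted?</question>
<answer>The chloroplast.</answer>"""

# Each <question> is paired with the <answer> that immediately follows it;
# re.DOTALL lets the captures span line breaks.
pattern = r"<question>(.*?)</question>\s*<answer>(.*?)</answer>"
for question, answer in re.findall(pattern, response, re.DOTALL):
    print(question.strip(), "->", answer.strip())
```

Compared with the old `split("\n\n")` loop, pairing is now driven by the tags themselves, so extra blank lines or stray prose between pairs no longer break parsing.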

graphgen/models/llm/local/vllm_wrapper.py

Lines changed: 7 additions & 5 deletions
@@ -16,7 +16,7 @@ def __init__(
         model: str,
         tensor_parallel_size: int = 1,
         gpu_memory_utilization: float = 0.9,
-        temperature: float = 0.0,
+        temperature: float = 0.6,
         top_p: float = 1.0,
         topk: int = 5,
         **kwargs: Any,
@@ -66,7 +66,7 @@ async def generate_answer(
         sp = self.SamplingParams(
             temperature=self.temperature if self.temperature > 0 else 1.0,
             top_p=self.top_p if self.temperature > 0 else 1.0,
-            max_tokens=extra.get("max_new_tokens", 512),
+            max_tokens=extra.get("max_new_tokens", 2048),
         )
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
@@ -82,7 +82,7 @@ async def generate_answer(
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
-    ) -> List[Token]:
+    ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
         request_id = f"graphgen_topk_{uuid.uuid4()}"
 
@@ -110,7 +110,9 @@ async def generate_topk_per_token(
 
         candidate_tokens = []
         for _, logprob_obj in top_logprobs.items():
-            tok_str = logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+            tok_str = (
+                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+            )
             prob = float(math.exp(logprob_obj.logprob))
             candidate_tokens.append(Token(tok_str, prob))
 
@@ -120,7 +122,7 @@ async def generate_topk_per_token(
             main_token = Token(
                 text=candidate_tokens[0].text,
                 prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens
+                top_candidates=candidate_tokens,
            )
             return [main_token]
         return []
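Note the interaction with the existing fallback visible in `generate_answer`: a non-positive temperature is mapped to 1.0 before reaching vLLM, so the old default of 0.0 effectively sampled at temperature 1.0, and the new 0.6 default is what actually takes effect now. A minimal sketch of the resulting sampling configuration (assumes `vllm` is installed; `SamplingParams` is vLLM's sampling-configuration class):

```python
from vllm import SamplingParams

# With the new defaults, generation is mildly stochastic and may run up to
# 2048 tokens unless the caller passes max_new_tokens explicitly.
sp = SamplingParams(temperature=0.6, top_p=1.0, max_tokens=2048)
print(sp)
```

The higher `max_tokens` ceiling matters for the XML-tagged outputs above, which tend to be longer than the old 512-token budget allowed.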

graphgen/operators/generate/generate_service.py

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ def generate(self, items: list[dict]) -> list[dict]:
             unit="batch",
         )
 
+        # Filter out empty results
+        results = [res for res in results if res]
+
         results = self.generator.format_generation_results(
             results, output_data_format=self.data_format
         )
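This guard pairs with the generator changes above: parsers now return empty results for malformed responses, and the service drops those entries before `format_generation_results` runs instead of passing empty payloads downstream.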

graphgen/templates/generation/aggregated_generation.py

Lines changed: 33 additions & 9 deletions
@@ -132,6 +132,8 @@
 - Logical consistency throughout
 - Clear cause-and-effect relationships
 
+**Attention: Please directly provide the rephrased text without any additional content or analysis.**
+
 ################
 -ENTITIES-
 ################
@@ -175,6 +177,8 @@
 - 整体逻辑一致性
 - 清晰的因果关系
 
+**注意: 请你直接给出重述文本,不要输出任何额外的内容,也不要进行任何分析。**
+
 ################
 -实体-
 ################
@@ -191,32 +195,52 @@
 ################
 请在下方直接输出连贯的重述文本,不要输出任何额外的内容。
 
+输出格式:
+<rephrased_text>rephrased_text_here</rephrased_text>
+
 重述文本:
 """
 
 REQUIREMENT_EN = """
 ################
 Please directly output the coherent rephrased text below, without any additional content.
 
+Output format:
+<rephrased_text>rephrased_text_here</rephrased_text>
+
 Rephrased Text:
 """
 
 QUESTION_GENERATION_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer.
 
-################
-Answer:
-{answer}
-################
+The answer for which a question needs to be generated is as follows:
+<answer>{answer}</answer>
+
+Please note the following requirements:
+1. Only output one question text without any additional explanations or analysis.
+2. Do not repeat the content of the answer or any fragments of it.
+3. The question must be independently understandable and fully match the answer.
+
+Output format:
+<question>question_text</question>
+
 Question:
 """
 
 QUESTION_GENERATION_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。
 
-################
-答案:
-{answer}
-################
-问题:
+需要生成问题的答案如下:
+<answer>{answer}</answer>
+
+请注意下列要求:
+1. 仅输出一个问题文本,不得包含任何额外说明或分析
+2. 不得重复答案内容或其中任何片段
+3. 问题必须可独立理解且与答案完全匹配
+
+输出格式:
+<question>question_text</question>
+
+问题:
 """
 
 AGGREGATED_GENERATION_PROMPT = {
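Taken together, the templates now declare the exact tags the parsers expect. A minimal end-to-end sketch (the template here is abbreviated from `QUESTION_GENERATION_EN` above, and the model reply is hypothetical):

```python
import re

# Abbreviated copy of the template; the full string lives in
# graphgen/templates/generation/aggregated_generation.py.
QUESTION_GENERATION_EN = """The answer to a question is provided. Please generate a question that corresponds to the answer.

The answer for which a question needs to be generated is as follows:
<answer>{answer}</answer>

Output format:
<question>question_text</question>

Question:
"""

prompt = QUESTION_GENERATION_EN.format(answer="Paris is the capital of France.")

# A compliant reply can be parsed deterministically, with no language-specific
# "Question:" / "问题:" prefix handling:
reply = "<question>Which city is the capital of France?</question>"
match = re.search(r"<question>(.*?)</question>", reply, re.DOTALL)
print(match.group(1))  # Which city is the capital of France?
```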
