Commit c4ea24f

feat(graphgen): support multiple rephrasing strategies
1 parent 7d42c56 commit c4ea24f

File tree: 8 files changed (+278, -67 lines)

generate.py: 2 additions, 2 deletions

@@ -12,7 +12,7 @@
 sys_path = os.path.abspath(os.path.dirname(__file__))
 unique_id = int(time.time())
 set_logger(os.path.join(sys_path, "cache", "logs", f"graphgen_{unique_id}.log"), if_stream=False)
-config_path = os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml")
+config_path = os.path.join(sys_path, "cache", "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml")

 load_dotenv()

@@ -71,7 +71,7 @@ def save_config(global_config):

     graph_gen.insert(data, config['data_type'])

-    graph_gen.quiz(max_samples=2)
+    graph_gen.quiz(max_samples=config['quiz_samples'])

     graph_gen.judge(re_judge=False)
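Note: the quiz sample count is no longer hard-coded to 2; generate.py now reads it from the YAML config, whose path also moves under cache/data/graphgen/<unique_id>/. A minimal sketch of a config carrying the two keys this diff actually references (data_type and quiz_samples); any real GraphGen config likely carries more fields:

    # Hypothetical config sketch; only 'data_type' and 'quiz_samples' appear in this diff.
    import yaml

    example_config = "data_type: raw\nquiz_samples: 2\n"

    config = yaml.safe_load(example_config)
    assert config["quiz_samples"] == 2   # the value forwarded to graph_gen.quiz(max_samples=...)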

graphgen/graphgen.py: 12 additions, 12 deletions

@@ -36,16 +36,16 @@ class GraphGen:
         working_dir, namespace="rephrase"
     )
     qa_storage: JsonKVStorage = JsonKVStorage(
-        os.path.join(working_dir, "data", "graphgen"), namespace=f"qa-{unique_id}"
+        os.path.join(working_dir, "data", "graphgen", str(unique_id)), namespace=f"qa-{unique_id}"
     )

     # text chunking
     chunk_size: int = 1024
     chunk_overlap_size: int = 100

     # llm
-    teacher_llm_client: OpenAIModel = None
-    student_llm_client: OpenAIModel = None
+    synthesizer_llm_client: OpenAIModel = None
+    training_llm_client: OpenAIModel = None
     tokenizer_instance: Tokenizer = None

     # web search
@@ -73,7 +73,7 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ
         if len(new_docs) == 0:
             logger.warning("All docs are already in the storage")
             return {}
-        logger.info(f"[New Docs] inserting {len(new_docs)} docs")
+        logger.info("[New Docs] inserting %d docs", len(new_docs))
         for doc_key, doc in tqdm_async(
             new_docs.items(), desc="Chunking documents", unit="doc"
         ):
@@ -127,14 +127,14 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str

         inserting_chunks = await self.async_split_chunks(data, data_type)

-        if not len(inserting_chunks):
+        if len(inserting_chunks) == 0:
             logger.warning("All chunks are already in the storage")
             return
-        logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))

         logger.info("[Entity and Relation Extraction]...")
         _add_entities_and_relations = await extract_kg(
-            llm_client=self.teacher_llm_client,
+            llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,
             tokenizer_instance=self.tokenizer_instance,
             chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()]
@@ -147,7 +147,7 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str
         if self.if_web_search:
             logger.info("[Wiki Search]...")
             _add_wiki_data = await search_wikipedia(
-                llm_client= self.teacher_llm_client,
+                llm_client= self.synthesizer_llm_client,
                 wiki_search_client=self.wiki_client,
                 knowledge_graph_instance=_add_entities_and_relations
             )
@@ -169,15 +169,15 @@ def quiz(self, max_samples=1):
         loop.run_until_complete(self.async_quiz(max_samples))

     async def async_quiz(self, max_samples=1):
-        await quiz(self.teacher_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
+        await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
         await self.rephrase_storage.index_done_callback()

     def judge(self, re_judge=False):
         loop = create_event_loop()
         loop.run_until_complete(self.async_judge(re_judge))

     async def async_judge(self, re_judge=False):
-        _update_relations = await judge_statement(self.student_llm_client, self.graph_storage,
+        _update_relations = await judge_statement(self.training_llm_client, self.graph_storage,
                                                   self.rephrase_storage, re_judge)
         await _update_relations.index_done_callback()

@@ -186,7 +186,7 @@ def traverse(self):
         loop.run_until_complete(self.async_traverse())

     async def async_traverse(self):
-        results = await traverse_graph_by_edge(self.teacher_llm_client, self.tokenizer_instance,
-                                               self.graph_storage, self.traverse_strategy)
+        results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
+                                               self.graph_storage, self.traverse_strategy, self.text_chunks_storage)
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
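Note: the teacher/student naming becomes synthesizer/training throughout. The synthesizer model does the generation-side work (insert, quiz, traverse) while the training model, i.e. the student being trained, only scores statements in judge. A wiring sketch; the OpenAIModel constructor arguments and the data shape are assumptions, only the GraphGen field and method names come from this diff:

    # Hypothetical sketch; constructor arguments are assumed, field names are from the diff.
    synthesizer = OpenAIModel(model_name="big-synthesizer-model")   # builds KG, rephrases, writes QA
    trainee = OpenAIModel(model_name="small-student-model")         # model under training; scores quiz statements

    graph_gen = GraphGen(
        synthesizer_llm_client=synthesizer,
        training_llm_client=trainee,
    )
    data = [{"content": "..."}]     # shape assumed from Union[List[list], List[dict]]
    graph_gen.insert(data, "raw")   # synthesizer_llm_client
    graph_gen.quiz(max_samples=2)   # synthesizer_llm_client
    graph_gen.judge()               # training_llm_client
    graph_gen.traverse()            # synthesizer_llm_client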

graphgen/operators/traverse_graph.py: 119 additions, 37 deletions

@@ -1,7 +1,7 @@
 import asyncio
 from tqdm.asyncio import tqdm as tqdm_async

-from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer
+from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
 from templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT
 from utils import detect_main_language, compute_content_hash, logger
 from graphgen.operators.split_graph import get_batches_with_strategy
@@ -49,6 +49,46 @@ async def handle_node(node: dict) -> dict:
     await graph_storage.index_done_callback()
     return new_edges, new_nodes

+async def _construct_rephrasing_prompt(_process_nodes: list,
+                                       _process_edges: list,
+                                       _difficulty: str,
+                                       text_chunks_storage: JsonKVStorage,
+                                       add_context: bool = False
+                                       ) -> str:
+    entities = [
+        f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
+    ]
+    relations = [
+        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
+        for _process_edge in _process_edges
+    ]
+
+    entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
+    relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
+    language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"

+    if add_context:
+        original_ids = ([node['source_id'].split('<SEP>')[0] for node in _process_nodes] +
+                        [edge[2]['source_id'].split('<SEP>')[0] for edge in _process_edges])

+        original_ids = list(set(original_ids))
+        original_text = await text_chunks_storage.get_by_ids(original_ids)
+        original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])

+        prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['CONTEXT_TEMPLATE'].format(
+            language=language,
+            original_text=original_text,
+            entities=entities_str,
+            relationships=relations_str
+        )
+        return prompt

+    prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
+        language=language,
+        entities=entities_str,
+        relationships=relations_str
+    )
+    return prompt

 def get_loss_tercile(losses: list) -> (float, float):
     losses = sorted(losses)
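Note: _construct_rephrasing_prompt lifts the prompt assembly out of _process_nodes_and_edges (see the later hunk) and adds an add_context mode: the first <SEP>-separated source_id of every node and edge is looked up in text_chunks_storage, and the CONTEXT_TEMPLATE variant is used so the rephrasing prompt quotes original passages instead of only KG descriptions. A standalone sketch of the numbered entity/relation strings the helper builds, with toy data in place of real KG output:

    # Toy data standing in for NetworkXStorage nodes and edges.
    nodes = [{"node_id": "Alan Turing", "description": "British mathematician and logician."}]
    edges = [("Alan Turing", "Enigma", {"description": "Turing helped break the Enigma cipher."})]

    entities_str = "\n".join(f"{i + 1}. {n['node_id']}: {n['description']}" for i, n in enumerate(nodes))
    relations_str = "\n".join(f"{i + 1}. {s} -- {t}: {d['description']}" for i, (s, t, d) in enumerate(edges))

    print(entities_str)   # 1. Alan Turing: British mathematician and logician.
    print(relations_str)  # 1. Alan Turing -- Enigma: Turing helped break the Enigma cipher.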
@@ -61,11 +101,39 @@ def get_average_loss(batch: tuple) -> float:
     return sum(edge[2]['loss'] for edge in batch[1]) + sum(node['loss'] for node in batch[0]) / \
            (len(batch[0]) + len(batch[1]))

+def _post_process_synthetic_data(data):
+    block = data.split("\n\n")
+    qas = []
+    for line in block:
+        if "Question:" in line and "Answer:" in line:
+            question = line.split("Question:")[1].split("Answer:")[0].strip()
+            answer = line.split("Answer:")[1].strip()
+            qas.append({
+                "question": question,
+                "answer": answer
+            })
+        elif "问题:" in line and "答案:" in line:
+            question = line.split("问题:")[1].split("答案:")[0].strip()
+            answer = line.split("答案:")[1].strip()
+            qas.append({
+                "question": question,
+                "answer": answer
+            })
+        elif "问题:" in line and "回答:" in line:
+            question = line.split("问题:")[1].split("回答:")[0].strip()
+            answer = line.split("回答:")[1].strip()
+            qas.append({
+                "question": question,
+                "answer": answer
+            })
+    return qas

 async def traverse_graph_by_edge(
     llm_client: OpenAIModel,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
     traverse_strategy: TraverseStrategy,
+    text_chunks_storage: JsonKVStorage,
     max_concurrent: int = 1000
 ) -> dict:
     """
@@ -75,6 +143,7 @@ async def traverse_graph_by_edge(
     :param tokenizer
     :param graph_storage
     :param traverse_strategy
+    :param text_chunks_storage
     :param max_concurrent
     :return: question and answer
     """
@@ -84,26 +153,15 @@
     async def _process_nodes_and_edges(
         _process_nodes: list,
         _process_edges: list,
-        _difficulty: str
+        _difficulty: str,
     ) -> str:
-        entities = [
-            f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
-        ]
-        relations = [
-            f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
-            for _process_edge in _process_edges
-        ]
-
-        entities_str = "\n".join([f"{index + 1}. {entity}" for index, entity in enumerate(entities)])
-        relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
-
-        language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"
-        prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
-            language=language,
-            entities=entities_str,
-            relationships=relations_str
+        prompt = await _construct_rephrasing_prompt(
+            _process_nodes,
+            _process_edges,
+            _difficulty,
+            text_chunks_storage,
+            add_context = False
         )
-
         context = await llm_client.generate_answer(prompt)

         # post-process the context
@@ -115,7 +173,8 @@ async def _process_nodes_and_edges(
         return context

     async def _process_single_batch(
-        _process_batch: tuple
+        _process_batch: tuple,
+        question_type: str = "single"
     ) -> dict:
         async with semaphore:
             context = await _process_nodes_and_edges(
@@ -125,32 +184,55 @@
             )

             language = "Chinese" if detect_main_language(context) == "zh" else "English"
-            question = await llm_client.generate_answer(
-                QUESTION_GENERATION_PROMPT[language]['TEMPLATE'].format(
-                    answer=context
-                )
-            )
-
-            if question.startswith("Question:"):
-                question = question[len("Question:"):].strip()
-            elif question.startswith("问题:"):
-                question = question[len("问题:"):].strip()
-
             pre_length = sum(node['length'] for node in _process_batch[0]) \
                          + sum(edge[2]['length'] for edge in _process_batch[1])

             logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
             logger.info("Pre-length: %s", pre_length)
-            logger.info("Question: %s Answer: %s", question, context)

-            return {
-                compute_content_hash(context): {
-                    "question": question,
-                    "answer": context,
+            if question_type == "single":
+                question = await llm_client.generate_answer(
+                    QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
+                        answer=context
+                    )
+                )
+                if question.startswith("Question:"):
+                    question = question[len("Question:"):].strip()
+                elif question.startswith("问题:"):
+                    question = question[len("问题:"):].strip()

+                return {
+                    compute_content_hash(context): {
+                        "question": question,
+                        "answer": context,
+                        "loss": get_average_loss(_process_batch),
+                        "difficulty": _process_batch[2],
+                    }
+                }

+            content = await llm_client.generate_answer(
+                QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
+                    doc=context
+                )
+            )
+            qas = _post_process_synthetic_data(content)

+            if len(qas) == 0:
+                print(content)
+                logger.error("Error occurred while processing batch, question or answer is None")
+                return {}

+            final_results = {}
+            for qa in qas:
+                logger.info("Question: %s", qa['question'])
+                logger.info("Answer: %s", qa['answer'])
+                final_results[compute_content_hash(qa['question'])] = {
+                    "question": qa['question'],
+                    "answer": qa['answer'],
                     "loss": get_average_loss(_process_batch),
                     "difficulty": _process_batch[2],
                 }
-            }
+            return final_results

     results = {}
     edges = list(await graph_storage.get_all_edges())
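Note: _process_single_batch now branches on question_type. The "single" path keeps the old flow, one question per rephrased answer via SINGLE_TEMPLATE, keyed by the hash of the answer; the multi path asks MULTI_TEMPLATE for several QA pairs per context and keys each by the hash of its question. The extra traverse_graph_by_edge parameter ripples up to GraphGen.async_traverse; a sketch of the updated call, with argument order taken from the new signature:

    # Mirrors GraphGen.async_traverse after this commit; gg stands in for a
    # configured GraphGen instance (an assumption about the call site).
    async def run_traverse(gg):
        results = await traverse_graph_by_edge(
            gg.synthesizer_llm_client,
            gg.tokenizer_instance,
            gg.graph_storage,
            gg.traverse_strategy,
            gg.text_chunks_storage,   # new argument introduced by this commit
        )
        await gg.qa_storage.upsert(results)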

models/storage/json_storage.py: 1 addition, 1 deletion

@@ -23,7 +23,7 @@ async def index_done_callback(self):
    async def get_by_id(self, id):
        return self._data.get(id, None)

-   async def get_by_ids(self, ids, fields=None):
+   async def get_by_ids(self, ids, fields=None) -> list:
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
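Note: only the return annotation changes here, but the contract matters for the new add_context path: get_by_ids keeps a None placeholder for every id it cannot find, and _construct_rephrasing_prompt then formats text['content'] for each element, so an unresolved chunk id would raise a TypeError there. A minimal illustration of the placeholder behavior with a toy storage dict:

    # Toy stand-in for JsonKVStorage._data.
    _data = {"chunk-1": {"content": "original passage"}}
    ids = ["chunk-1", "chunk-missing"]
    texts = [_data.get(i, None) for i in ids]   # what get_by_ids returns when fields is None
    # texts[1] is None, so formatting text['content'] over all elements would fail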

scripts/generate.sh: 1 addition, 3 deletions

@@ -1,3 +1 @@
-python3 generate.py --input_file resources/examples/raw_demo.jsonl \
-                    --data_type raw \
-# --web_search
+python3 generate.py --config_file configs/graphgen_config.yaml

scripts/judge.sh: 1 addition, 0 deletions

@@ -0,0 +1 @@
+python3 evaluate.py --output cache/output/new_graph.graphml \
