
Commit 76b53fa
fix: implement generate method
Parent: 9434fd3

7 files changed: 54 additions, 97 deletions

graphgen/generate.py (9 additions, 6 deletions)

@@ -53,20 +53,20 @@ def main():
     mode = config["generate"]["mode"]
     unique_id = int(time.time())

-    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}_{mode}")
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)

     set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(working_dir, f"{unique_id}.log"),
+        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
     )

-    graph_gen = GraphGen(working_dir=working_dir, output_path=output_path)
+    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)

     graph_gen.insert(read_config=config["read"], split_config=config["split"])

@@ -81,8 +81,11 @@ def main():
         logger.warning(
             "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
         )
-        # TODO: make edge sampling random
-        # graph_gen.traverse_strategy.edge_sampling = "random"
+        assert (
+            config["partition"]["method"] == "ece"
+            and "ece_params" in config["partition"]
+        ), "Only ECE partition with edge sampling is supported."
+        config["partition"]["ece_params"]["edge_sampling"] = "random"
     elif mode == "cot":
         logger.info("Generation mode set to 'cot'. Start generation.")
     else:
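For orientation, main() in generate.py drives everything from a single config mapping. The sketch below collects the keys this commit actually reads; it is an assumption about the overall shape, and every concrete value is illustrative rather than taken from the repository.

# Hypothetical shape of the config consumed by main() in this commit.
# Key names come from the diff above; all values here are illustrative.
config = {
    "read": {},        # forwarded to graph_gen.insert(read_config=...)
    "split": {},       # forwarded to graph_gen.insert(split_config=...)
    "partition": {
        "method": "ece",                # asserted when quiz/judge is disabled
        "ece_params": {
            "edge_sampling": "random",  # forced to "random" in that fallback branch
        },
    },
    "generate": {
        "mode": "atomic",               # "atomic", "multi_hop", "aggregated", or "cot"
        "data_format": "...",           # passed to format_generation_results
        "method_params": {},            # only read in "cot" mode, via .get()
    },
}

The quiz-and-judge settings that gate the fallback branch are not shown in this diff and are omitted above.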

graphgen/graphgen.py (21 additions, 33 deletions)

@@ -39,10 +39,8 @@

 @dataclass
 class GraphGen:
+    unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    output_path: str = os.path.join(
-        working_dir, "data", "graphgen", str(int(time.time()))
-    )

     # llm
     tokenizer_instance: Tokenizer = None

@@ -86,7 +84,7 @@ def __post_init__(self):
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            self.working_dir,
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )

@@ -238,59 +236,49 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
         # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
-        pass
-
-    @async_to_sync_method
-    async def traverse(self):
-        output_data_type = self.config["output_data_type"]
-
-        if output_data_type == "atomic":
+        mode = generate_config["mode"]
+        if mode == "atomic":
             results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "multi_hop":
+        elif mode == "multi_hop":
             results = await traverse_graph_for_multi_hop(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "aggregated":
+        elif mode == "aggregated":
             results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
+        elif mode == "cot":
+            method_params = generate_config.get("method_params", {})
+            results = await generate_cot(
+                self.graph_storage,
+                self.synthesizer_llm_client,
+                method_params=method_params,
+            )
         else:
-            raise ValueError(f"Unknown qa_form: {output_data_type}")
-
-        results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
-        )
-
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate_reasoning(self, method_params):
-        results = await generate_cot(
-            self.graph_storage,
-            self.synthesizer_llm_client,
-            method_params=method_params,
-        )
+            raise ValueError(f"Unknown generation mode: {mode}")
+        # Step 2: generate QA pairs
+        # TODO

+        # Step 3: format
         results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
+            results, output_data_format=generate_config["data_format"]
         )

         await self.qa_storage.upsert(results)
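With the old traverse() and generate_reasoning() entry points folded into generate(), calling it from a script would look roughly like the sketch below. This is an assumption based only on the signatures visible in this diff: the import path and the config values are illustrative, and generate() is declared async here, so the sketch drives it with asyncio.run() (if it is exposed through a sync wrapper such as async_to_sync_method, that call would be unnecessary).

import asyncio
import time

from graphgen.graphgen import GraphGen  # assumed import path for the class edited above

# Illustrative configs; see the hypothetical config sketch after generate.py.
partition_config = {
    "method": "ece",
    "ece_params": {"edge_sampling": "random"},  # plus the traversal keys listed after split_kg.py
}
generate_config = {"mode": "atomic", "data_format": "..."}

graph_gen = GraphGen(unique_id=int(time.time()), working_dir="cache")
# In generate.py's main(), insert() (and optionally quiz_and_judge()) run first
# to build the knowledge graph before generation.
asyncio.run(graph_gen.generate(partition_config, generate_config))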

graphgen/models/__init__.py (0 additions, 1 deletion)

@@ -13,5 +13,4 @@
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
-from .strategy.travserse_strategy import TraverseStrategy
 from .tokenizer import Tokenizer

graphgen/models/strategy/__init__.py (whitespace-only changes)

graphgen/models/strategy/travserse_strategy.py (0 additions, 28 deletions): this file, which provided the TraverseStrategy dataclass, was deleted.

graphgen/operators/build_kg/split_kg.py (16 additions, 15 deletions)

@@ -1,9 +1,10 @@
 import random
 from collections import defaultdict
+from typing import Dict

 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.models import NetworkXStorage
 from graphgen.utils import logger


@@ -247,18 +248,18 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
 ):
-    expand_method = traverse_strategy.expand_method
+    expand_method = traverse_strategy["expand_method"]
     if expand_method == "max_width":
         logger.info("Using max width strategy")
     elif expand_method == "max_tokens":
         logger.info("Using max tokens strategy")
     else:
         raise ValueError(f"Invalid expand method: {expand_method}")

-    max_depth = traverse_strategy.max_depth
-    edge_sampling = traverse_strategy.edge_sampling
+    max_depth = traverse_strategy["max_depth"]
+    edge_sampling = traverse_strategy["edge_sampling"]

     # build the adjacency list
     edge_adj_list = defaultdict(list)

@@ -275,16 +276,16 @@ async def get_cached_node_info(node_id: str) -> dict:
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i

-    if traverse_strategy.loss_strategy == "both":
+    if traverse_strategy["loss_strategy"] == "both":
         er_tuples = [
             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
             for edge in edges
         ]
         edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy.loss_strategy == "only_edge":
+    elif traverse_strategy["loss_strategy"] == "only_edge":
         edges = _sort_edges(edges, edge_sampling)
     else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")

     for i, (src, tgt, _) in enumerate(edges):
         edge_adj_list[src].append(i)

@@ -315,10 +316,10 @@ async def get_cached_node_info(node_id: str) -> dict:
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_extra_edges,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_extra_edges"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )
         else:
             level_n_edges = _get_level_n_edges_by_max_tokens(

@@ -328,10 +329,10 @@ async def get_cached_node_info(node_id: str) -> dict:
                 nodes,
                 edge,
                 max_depth,
-                traverse_strategy.bidirectional,
-                traverse_strategy.max_tokens,
+                traverse_strategy["bidirectional"],
+                traverse_strategy["max_tokens"],
                 edge_sampling,
-                traverse_strategy.loss_strategy,
+                traverse_strategy["loss_strategy"],
             )

         for _edge in level_n_edges:

@@ -352,7 +353,7 @@ async def get_cached_node_info(node_id: str) -> dict:
     logger.info("Processing batches: %d", len(processing_batches))

     # isolate nodes
-    isolated_node_strategy = traverse_strategy.isolated_node_strategy
+    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
     if isolated_node_strategy == "add":
         processing_batches = await _add_isolated_nodes(
             nodes, processing_batches, graph_storage
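With the TraverseStrategy dataclass deleted, get_batches_with_strategy now pulls its options from a plain dict (in practice, the partition_config["ece_params"] mapping that GraphGen.generate() forwards). The sketch below lists the keys this diff dereferences; the values are illustrative placeholders, not defaults from the repository.

# Keys read from traverse_strategy in this commit; all values below are illustrative.
traverse_strategy = {
    "expand_method": "max_width",     # or "max_tokens"
    "max_depth": 2,
    "edge_sampling": "random",
    "loss_strategy": "both",          # or "only_edge"
    "bidirectional": True,
    "max_extra_edges": 5,             # used by the "max_width" expansion
    "max_tokens": 256,                # used by the "max_tokens" expansion
    "isolated_node_strategy": "add",  # "add" is the only value handled explicitly here
}

Since these are plain dict lookups, a missing key now surfaces as a KeyError at the corresponding access.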

graphgen/operators/traverse_graph.py (8 additions, 14 deletions)

@@ -1,15 +1,10 @@
 import asyncio
+from typing import Dict

 import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import (
-    JsonKVStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-    TraverseStrategy,
-)
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
 from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,

@@ -164,7 +159,7 @@ async def traverse_graph_for_aggregated(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -240,7 +235,7 @@ async def _process_single_batch(
                     "question": question,
                     "answer": context,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }

@@ -272,7 +267,7 @@ async def _process_single_batch(
                 "question": qa["question"],
                 "answer": qa["answer"],
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         return final_results

@@ -313,7 +308,7 @@ async def traverse_graph_for_atomic(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -331,7 +326,6 @@ async def traverse_graph_for_atomic(
     :return: question and answer
     """

-    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)

     def _parse_qa(qa: str) -> tuple:

@@ -429,7 +423,7 @@ async def traverse_graph_for_multi_hop(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,

@@ -517,7 +511,7 @@ async def _process_single_batch(_process_batch: tuple) -> dict:
                     "question": question,
                     "answer": answer,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }
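The traversal functions now annotate the strategy as a bare Dict, so the expected keys are documented only implicitly by the lookups here and in split_kg.py. Purely as an optional follow-up (not part of this commit), a TypedDict could record that shape without reintroducing the deleted dataclass:

# Not in this commit: a typing-only sketch of the dict the traversal code expects.
from typing import TypedDict


class TraverseStrategyDict(TypedDict, total=False):
    expand_method: str            # "max_width" or "max_tokens"
    max_depth: int
    edge_sampling: str
    loss_strategy: str            # "both" or "only_edge"
    bidirectional: bool
    max_extra_edges: int
    max_tokens: int
    isolated_node_strategy: str   # "add" is handled explicitly in split_kg.py

The dict-typed signatures would keep working unchanged; the annotation would only help static checkers and readers.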
