Skip to content

Commit 22aae9a

Browse files
feat: add vqa_generator
1 parent 6fa1537 commit 22aae9a

File tree

3 files changed

+92
-2
lines changed

3 files changed

+92
-2
lines changed

graphgen/graphgen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,10 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
286286
async def generate(self, partition_config: Dict, generate_config: Dict):
287287
# Step 1: partition the graph
288288
batches = await partition_kg(
289-
self.graph_storage, self.tokenizer_instance, partition_config
289+
self.graph_storage,
290+
self.chunks_storage,
291+
self.tokenizer_instance,
292+
partition_config,
290293
)
291294

292295
# Step 2: generate QA pairs

graphgen/models/generator/vqa_generator.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,79 @@ def parse_response(response: str) -> Any:
6060
"answer": answer,
6161
}
6262
return qa_pairs
63+
64+
async def generate(
65+
self,
66+
batch: tuple[
67+
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
68+
],
69+
) -> dict[str, Any]:
70+
"""
71+
Generate QAs based on a given batch.
72+
:param batch
73+
:return: QA pairs
74+
"""
75+
result = {}
76+
prompt = self.build_prompt(batch)
77+
response = await self.llm_client.generate_answer(prompt)
78+
qa_pairs = self.parse_response(response) # generate one or more QA pairs
79+
nodes, _ = batch
80+
for node in nodes:
81+
node_data = node[1]
82+
if "images" in node_data and node_data["images"]:
83+
img_path = node_data["images"]
84+
for qa in qa_pairs.values():
85+
qa["img_path"] = img_path
86+
result.update(qa_pairs)
87+
return result
88+
89+
@staticmethod
90+
def format_generation_results(
91+
results: list[dict], output_data_format: str
92+
) -> list[dict[str, Any]]:
93+
if output_data_format == "Alpaca":
94+
results = [
95+
{
96+
"instruction": v["question"],
97+
"input": "",
98+
"output": v["answer"],
99+
"image": v.get("img_path", ""),
100+
}
101+
for item in results
102+
for k, v in item.items()
103+
]
104+
elif output_data_format == "Sharegpt":
105+
results = [
106+
{
107+
"conversations": [
108+
{
109+
"from": "human",
110+
"value": [
111+
{"text": v["question"], "image": v.get("img_path", "")}
112+
],
113+
},
114+
{"from": "gpt", "value": v["answer"]},
115+
]
116+
}
117+
for item in results
118+
for k, v in item.items()
119+
]
120+
elif output_data_format == "ChatML":
121+
results = [
122+
{
123+
"messages": [
124+
{
125+
"role": "user",
126+
"content": [
127+
{"text": v["question"], "image": v.get("img_path", "")}
128+
],
129+
},
130+
{"role": "assistant", "content": v["answer"]},
131+
]
132+
}
133+
for item in results
134+
for k, v in item.items()
135+
]
136+
else:
137+
raise ValueError(f"Unknown output data format: {output_data_format}")
138+
return results

graphgen/operators/partition/partition_kg.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from graphgen.bases import BaseGraphStorage, BaseTokenizer
3+
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
44
from graphgen.models import (
55
AnchorBFSPartitioner,
66
BFSPartitioner,
@@ -15,6 +15,7 @@
1515

1616
async def partition_kg(
1717
kg_instance: BaseGraphStorage,
18+
chunk_storage: BaseKVStorage,
1819
tokenizer: Any = BaseTokenizer,
1920
partition_config: dict = None,
2021
) -> list[
@@ -54,4 +55,14 @@ async def partition_kg(
5455
communities = await partitioner.partition(g=kg_instance, **method_params)
5556
logger.info("Partitioned the graph into %d communities.", len(communities))
5657
batches = await partitioner.community2batch(communities, g=kg_instance)
58+
59+
for _, batch in enumerate(batches):
60+
nodes, edges = batch
61+
for node_id, node_data in nodes:
62+
entity_type = node_data.get("entity_type")
63+
if "image" in entity_type.lower():
64+
node_id = node_id.strip('"').lower()
65+
image_data = await chunk_storage.get_by_id(node_id)
66+
if image_data:
67+
node_data["images"] = image_data
5768
return batches

0 commit comments

Comments
 (0)