Commit 3e2cf1e

Merge pull request #78 from open-sciencelab/refactor/insert
refactor: refactor graphgen.insert for scalability
2 parents 45cea6f + 794cc75 commit 3e2cf1e

File tree

6 files changed: +99 additions, -106 deletions

graphgen/bases/datatypes.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@ def from_dict(key: str, data: dict) -> "Chunk":
         return Chunk(
             id=key,
             content=data.get("content", ""),
-            type=data.get("type", "unknown"),
+            type=data.get("type", "text"),
             metadata={k: v for k, v in data.items() if k != "content"},
         )
 
```
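With the new default, a chunk whose source dict carries no explicit `type` key is treated as text rather than falling into an "unknown" bucket, so it gets routed to the text extractor downstream. A minimal sketch of the effect (assuming `Chunk` is a dataclass with the four fields used in the constructor above):

```python
from graphgen.bases.datatypes import Chunk

# A payload without an explicit "type" key, e.g. a plain-text document.
data = {"content": "GraphGen builds knowledge graphs from corpora.", "source": "demo.txt"}

chunk = Chunk.from_dict("chunk-0", data)
assert chunk.type == "text"                       # was "unknown" before this commit
assert chunk.metadata == {"source": "demo.txt"}   # every key except "content"
```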

graphgen/graphgen.py

Lines changed: 36 additions & 101 deletions

```diff
@@ -16,8 +16,7 @@
     Tokenizer,
 )
 from graphgen.operators import (
-    build_mm_kg,
-    build_text_kg,
+    build_kg,
     chunk_documents,
     generate_qas,
     init_llm,
@@ -96,109 +95,45 @@ async def insert(self, read_config: Dict, split_config: Dict):
         new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-        new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
-        new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}
-
-        await self.full_docs_storage.upsert(new_docs)
-
-        async def _insert_text_docs(text_docs):
-            if len(text_docs) == 0:
-                logger.warning("All text docs are already in the storage")
-                return
-            logger.info("[New Docs] inserting %d text docs", len(text_docs))
-            # Step 2.1: Split chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                text_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
 
-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-
-            if len(inserting_chunks) == 0:
-                logger.warning("All text chunks are already in the storage")
-                return
-
-            logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
-            await self.chunks_storage.upsert(inserting_chunks)
-
-            # Step 2.2: Extract entities and relations from text chunks
-            logger.info("[Text Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_text_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[
-                    Chunk(id=k, content=v["content"], type="text")
-                    for k, v in inserting_chunks.items()
-                ],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning("No entities or relations extracted from text chunks")
-                return
-
-            await self._insert_done()
-            return _add_entities_and_relations
-
-        async def _insert_multi_modal_docs(mm_docs):
-            if len(mm_docs) == 0:
-                logger.warning("No multi-modal documents to insert")
-                return
-
-            logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))
-
-            # Step 3.1: Transform multi-modal documents into chunks and filter existing ones
-            inserting_chunks = await chunk_documents(
-                mm_docs,
-                split_config["chunk_size"],
-                split_config["chunk_overlap"],
-                self.tokenizer_instance,
-                self.progress_bar,
-            )
+        if len(new_docs) == 0:
+            logger.warning("All documents are already in the storage")
+            return
 
-            _add_chunk_keys = await self.chunks_storage.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
+        inserting_chunks = await chunk_documents(
+            new_docs,
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
+            self.tokenizer_instance,
+            self.progress_bar,
+        )
 
-            if len(inserting_chunks) == 0:
-                logger.warning("All multi-modal chunks are already in the storage")
-                return
+        _add_chunk_keys = await self.chunks_storage.filter_keys(
+            list(inserting_chunks.keys())
+        )
+        inserting_chunks = {
+            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+        }
 
-            logger.info(
-                "[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
-            )
-            await self.chunks_storage.upsert(inserting_chunks)
-
-            # Step 3.2: Extract multi-modal entities and relations from chunks
-            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
-            _add_entities_and_relations = await build_mm_kg(
-                llm_client=self.synthesizer_llm_client,
-                kg_instance=self.graph_storage,
-                chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
-                progress_bar=self.progress_bar,
-            )
-            if not _add_entities_and_relations:
-                logger.warning(
-                    "No entities or relations extracted from multi-modal chunks"
-                )
-                return
-            await self._insert_done()
-            return _add_entities_and_relations
-
-        # Step 2: Insert text documents
-        await _insert_text_docs(new_text_docs)
-        # Step 3: Insert multi-modal documents
-        await _insert_multi_modal_docs(new_mm_docs)
+        if len(inserting_chunks) == 0:
+            logger.warning("All chunks are already in the storage")
+            return
+
+        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+        await self.chunks_storage.upsert(inserting_chunks)
+
+        _add_entities_and_relations = await build_kg(
+            llm_client=self.synthesizer_llm_client,
+            kg_instance=self.graph_storage,
+            chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
+            progress_bar=self.progress_bar,
+        )
+        if not _add_entities_and_relations:
+            logger.warning("No entities or relations extracted from text chunks")
+            return
+
+        await self._insert_done()
+        return _add_entities_and_relations
 
     async def _insert_done(self):
         tasks = []
```
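The refactor collapses the two per-modality closures into one linear pipeline: hash and deduplicate documents, chunk them, deduplicate the chunks, upsert, then hand everything to `build_kg`, which performs the modality split internally. The idempotency of `insert` rests on the hash-then-`filter_keys` pattern visible above; here is a self-contained sketch of that pattern, with a hypothetical in-memory store standing in for `full_docs_storage` (`compute_hash` below is an illustrative stand-in for `compute_mm_hash`):

```python
import asyncio
import hashlib


def compute_hash(doc: dict, prefix: str = "doc-") -> str:
    # Stand-in for compute_mm_hash: key documents by a content digest so
    # re-inserting the same document produces the same key.
    digest = hashlib.md5(repr(sorted(doc.items())).encode("utf-8")).hexdigest()
    return f"{prefix}{digest}"


class InMemoryKV:
    """Hypothetical store exposing the filter_keys/upsert calls used above."""

    def __init__(self):
        self._data = {}

    async def filter_keys(self, keys):
        # Return only the keys NOT yet stored, i.e. the ones worth inserting.
        return {k for k in keys if k not in self._data}

    async def upsert(self, items):
        self._data.update(items)


async def insert_new(storage: InMemoryKV, docs: list) -> dict:
    new_docs = {compute_hash(d): d for d in docs}
    fresh = await storage.filter_keys(list(new_docs.keys()))
    new_docs = {k: v for k, v in new_docs.items() if k in fresh}
    await storage.upsert(new_docs)
    return new_docs


store = InMemoryKV()
docs = [{"content": "hello", "type": "text"}]
print(len(asyncio.run(insert_new(store, docs))))  # 1: inserted
print(len(asyncio.run(insert_new(store, docs))))  # 0: already stored, skipped
```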

graphgen/operators/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-from .build_kg import build_mm_kg, build_text_kg
+from .build_kg import build_kg
 from .generate import generate_qas
 from .init import init_llm
 from .judge import judge_statement
```
graphgen/operators/build_kg/__init__.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -1,2 +1 @@
-from .build_mm_kg import build_mm_kg
-from .build_text_kg import build_text_kg
+from .build_kg import build_kg
```
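For callers, these two `__init__.py` changes mean the operators package now re-exports a single KG entry point; imports of the old per-modality builders from the package root stop working, while the new name keeps the public import path stable:

```python
from graphgen.operators import build_kg      # new unified entry point

# The old names are no longer re-exported at the package root:
# from graphgen.operators import build_text_kg  # would now raise ImportError
```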
graphgen/operators/build_kg/build_kg.py

Lines changed: 59 additions & 0 deletions

```diff
@@ -0,0 +1,59 @@
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases import BaseLLMWrapper
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.utils import logger
+
+from .build_mm_kg import build_mm_kg
+from .build_text_kg import build_text_kg
+
+
+async def build_kg(
+    llm_client: BaseLLMWrapper,
+    kg_instance: BaseGraphStorage,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    Build a knowledge graph (KG) from chunks and merge it into kg_instance.
+
+    :param llm_client: synthesizer LLM used to extract entities and relationships
+    :param kg_instance: graph storage that the extracted KG is merged into
+    :param chunks: chunks to process, text and multi-modal alike
+    :param progress_bar: Gradio progress bar showing extraction progress
+    :return: kg_instance with the extracted entities and relations merged in
+    """
+
+    text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
+    mm_chunks = [
+        chunk
+        for chunk in chunks
+        if chunk.type in ("image", "video", "table", "formula")
+    ]
+
+    if len(text_chunks) == 0:
+        logger.info("All text chunks are already in the storage")
+    else:
+        logger.info("[Text Entity and Relation Extraction] processing ...")
+        await build_text_kg(
+            llm_client=llm_client,
+            kg_instance=kg_instance,
+            chunks=text_chunks,
+            progress_bar=progress_bar,
+        )
+
+    if len(mm_chunks) == 0:
+        logger.info("All multi-modal chunks are already in the storage")
+    else:
+        logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
+        await build_mm_kg(
+            llm_client=llm_client,
+            kg_instance=kg_instance,
+            chunks=mm_chunks,
+            progress_bar=progress_bar,
+        )
+
+    return kg_instance
```
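The dispatcher's routing rule is the heart of the new design: text chunks go to `build_text_kg`, the four multi-modal types go to `build_mm_kg`, and any other type falls through both filters and is ignored. A self-contained sketch of that rule with stub extractors in place of the real builders (all names below are stand-ins, not GraphGen APIs):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeChunk:
    """Stand-in for graphgen.bases.datatypes.Chunk, reduced to what routing needs."""
    id: str
    content: str
    type: str = "text"


async def fake_text_kg(chunks):  # stands in for build_text_kg
    return [f"text:{c.id}" for c in chunks]


async def fake_mm_kg(chunks):  # stands in for build_mm_kg
    return [f"mm:{c.id}" for c in chunks]


async def route(chunks):
    # Same split as build_kg above: by chunk.type.
    text = [c for c in chunks if c.type == "text"]
    mm = [c for c in chunks if c.type in ("image", "video", "table", "formula")]
    return await fake_text_kg(text), await fake_mm_kg(mm)


result = asyncio.run(route([
    FakeChunk("c0", "Plain prose."),                # -> text branch
    FakeChunk("c1", "fig1.png", type="image"),      # -> multi-modal branch
    FakeChunk("c2", "clip.wav", type="audio"),      # matches neither filter: dropped
]))
print(result)  # (['text:c0'], ['mm:c1'])
```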

graphgen/operators/generate/generate_qas.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -40,7 +40,7 @@ async def generate_qas(
         generator = MultiHopGenerator(llm_client)
     elif mode == "cot":
         generator = CoTGenerator(llm_client)
-    elif mode == "vqa":
+    elif mode in ["vqa"]:
        generator = VQAGenerator(llm_client)
     else:
         raise ValueError(f"Unsupported generation mode: {mode}")
```
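The switch from `mode == "vqa"` to `mode in ["vqa"]` reads as groundwork for mapping several mode aliases onto one generator: extending a membership list is a one-line edit, while an equality chain grows a branch per alias. A tiny stand-alone illustration (mode names other than "vqa" are hypothetical):

```python
# Hypothetical mirror of the branch above; "mm_qa" shows how the list
# could grow without touching the dispatch logic itself.
VQA_MODES = ["vqa"]  # e.g. later: ["vqa", "mm_qa"]


def pick_generator(mode: str) -> str:
    if mode in VQA_MODES:
        return "VQAGenerator"
    raise ValueError(f"Unsupported generation mode: {mode}")


assert pick_generator("vqa") == "VQAGenerator"
```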
