
Commit 6447f2a

Merge pull request #52 from open-sciencelab/kg_builder
Refactor KG builder
2 parents: 4ea9ba9 + b30f5a1


44 files changed: +799 -516 lines

baselines/Genie/genie.py

Lines changed: 3 additions & 3 deletions

@@ -10,7 +10,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIModel
+from graphgen.models import OpenAIClient
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """Instruction: Given the next [document], create a [question] and [answer] pair that are grounded \
@@ -59,7 +59,7 @@ def _post_process(content: str) -> tuple:
 
 @dataclass
 class Genie:
-    llm_client: OpenAIModel = None
+    llm_client: OpenAIClient = None
     max_concurrent: int = 1000
 
     def generate(self, docs: List[List[dict]]) -> List[dict]:
@@ -121,7 +121,7 @@ async def process_chunk(content: str):
 
     load_dotenv()
 
-    llm_client = OpenAIModel(
+    llm_client = OpenAIClient(
         model_name=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),

baselines/LongForm/longform.py

Lines changed: 3 additions & 3 deletions

@@ -11,7 +11,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIModel
+from graphgen.models import OpenAIClient
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """Instruction: X
@@ -23,7 +23,7 @@
 
 @dataclass
 class LongForm:
-    llm_client: OpenAIModel = None
+    llm_client: OpenAIClient = None
     max_concurrent: int = 1000
 
     def generate(self, docs: List[List[dict]]) -> List[dict]:
@@ -88,7 +88,7 @@ async def process_chunk(content: str):
 
     load_dotenv()
 
-    llm_client = OpenAIModel(
+    llm_client = OpenAIClient(
         model_name=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),

baselines/SELF-QA/self-qa.py

Lines changed: 3 additions & 3 deletions

@@ -10,7 +10,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIModel
+from graphgen.models import OpenAIClient
 from graphgen.utils import compute_content_hash, create_event_loop
 
 INSTRUCTION_GENERATION_PROMPT = """The background knowledge is:
@@ -58,7 +58,7 @@ def _post_process_answers(content: str) -> tuple:
 
 @dataclass
 class SelfQA:
-    llm_client: OpenAIModel = None
+    llm_client: OpenAIClient = None
     max_concurrent: int = 100
 
     def generate(self, docs: List[List[dict]]) -> List[dict]:
@@ -155,7 +155,7 @@ async def process_chunk(content: str):
 
     load_dotenv()
 
-    llm_client = OpenAIModel(
+    llm_client = OpenAIClient(
         model_name=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),

baselines/Wrap/wrap.py

Lines changed: 3 additions & 3 deletions

@@ -10,7 +10,7 @@
 from dotenv import load_dotenv
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import OpenAIModel
+from graphgen.models import OpenAIClient
 from graphgen.utils import compute_content_hash, create_event_loop
 
 PROMPT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant.
@@ -46,7 +46,7 @@ def _post_process(content: str) -> list:
 
 @dataclass
 class Wrap:
-    llm_client: OpenAIModel = None
+    llm_client: OpenAIClient = None
     max_concurrent: int = 1000
 
    def generate(self, docs: List[List[dict]]) -> List[dict]:
@@ -108,7 +108,7 @@ async def process_chunk(content: str):
 
     load_dotenv()
 
-    llm_client = OpenAIModel(
+    llm_client = OpenAIClient(
         model_name=os.getenv("SYNTHESIZER_MODEL"),
         api_key=os.getenv("SYNTHESIZER_API_KEY"),
         base_url=os.getenv("SYNTHESIZER_BASE_URL"),
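
All four baselines make the same mechanical swap: graphgen.models now exports the client as OpenAIClient instead of OpenAIModel, with the constructor arguments unchanged. A minimal migration sketch, assuming the same SYNTHESIZER_* environment variables these scripts already read:

import os

from dotenv import load_dotenv

from graphgen.models import OpenAIClient

load_dotenv()

# Only the class name changes; configuration is read from the environment as before.
llm_client = OpenAIClient(
    model_name=os.getenv("SYNTHESIZER_MODEL"),
    api_key=os.getenv("SYNTHESIZER_API_KEY"),
    base_url=os.getenv("SYNTHESIZER_BASE_URL"),
)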

graphgen/bases/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+from .base_kg_builder import BaseKGBuilder
+from .base_llm_client import BaseLLMClient
+from .base_reader import BaseReader
+from .base_splitter import BaseSplitter
+from .base_storage import (
+    BaseGraphStorage,
+    BaseKVStorage,
+    BaseListStorage,
+    StorageNameSpace,
+)
+from .base_tokenizer import BaseTokenizer
+from .datatypes import Chunk, QAPair, Token

graphgen/bases/base_kg_builder.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+
+
+@dataclass
+class BaseKGBuilder(ABC):
+    kg_instance: BaseGraphStorage
+    llm_client: BaseLLMClient
+
+    _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
+    _edges: Dict[Tuple[str, str], List[dict]] = field(
+        default_factory=lambda: defaultdict(list)
+    )
+
+    def build(self, chunks: List[Chunk]) -> None:
+        pass
+
+    @abstractmethod
+    async def extract_all(self, chunks: List[Chunk]) -> None:
+        """Extract nodes and edges from all chunks."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """Extract nodes and edges from a single chunk."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def merge_nodes(
+        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
+    ) -> None:
+        """Merge extracted nodes into the knowledge graph."""
+        raise NotImplementedError
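
BaseKGBuilder pins down the builder contract: extract nodes and edges per chunk, accumulate them in the inherited _nodes/_edges defaultdicts, then merge the results into graph storage. A minimal sketch of a conforming subclass, with hypothetical no-op extraction (the concrete builder in this PR lives elsewhere in the tree):

from typing import Dict, List, Tuple

from graphgen.bases import BaseGraphStorage, BaseKGBuilder, Chunk


class NoopKGBuilder(BaseKGBuilder):
    """Hypothetical builder showing where each hook fits; not part of this PR."""

    async def extract_all(self, chunks: List[Chunk]) -> None:
        # Fan out over chunks, accumulating into the inherited defaultdicts.
        for chunk in chunks:
            nodes, edges = await self.extract(chunk)
            for name, records in nodes.items():
                self._nodes[name].extend(records)
            for pair, records in edges.items():
                self._edges[pair].extend(records)

    async def extract(
        self, chunk: Chunk
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        # A real builder would prompt self.llm_client with the chunk text here.
        return {}, {}

    async def merge_nodes(
        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
    ) -> None:
        # A real builder would deduplicate records and upsert them into kg_instance.
        return None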

graphgen/bases/base_llm_client.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import abc
+import re
+from typing import Any, List, Optional
+
+from graphgen.bases.base_tokenizer import BaseTokenizer
+from graphgen.bases.datatypes import Token
+
+
+class BaseLLMClient(abc.ABC):
+    """
+    LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
+    """
+
+    def __init__(
+        self,
+        *,
+        system_prompt: str = "",
+        temperature: float = 0.0,
+        max_tokens: int = 4096,
+        repetition_penalty: float = 1.05,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        tokenizer: Optional[BaseTokenizer] = None,
+        **kwargs: Any,
+    ):
+        self.system_prompt = system_prompt
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.repetition_penalty = repetition_penalty
+        self.top_p = top_p
+        self.top_k = top_k
+        self.tokenizer = tokenizer
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    @abc.abstractmethod
+    async def generate_answer(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> str:
+        """Generate answer from the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate top-k tokens for the next token prediction."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[List[str]] = None, **extra: Any
+    ) -> List[Token]:
+        """Generate probabilities for each token in the input."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        """Count the number of tokens in the text."""
+        if self.tokenizer is None:
+            raise ValueError("Tokenizer is not set. Please provide a tokenizer to use count_tokens.")
+        return len(self.tokenizer.encode(text))
+
+    @staticmethod
+    def filter_think_tags(text: str, think_tag: str = "think") -> str:
+        """
+        Remove <think> tags from the text.
+        If the text contains <think> and </think>, it removes everything between them and the tags themselves.
+        """
+        think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
+        filtered_text = think_pattern.sub("", text).strip()
+        return filtered_text if filtered_text else text.strip()
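
The three generate_* hooks are backend-specific, but count_tokens and filter_think_tags are concrete, and the latter is a static method usable without any backend. A quick sketch of its behavior:

from graphgen.bases import BaseLLMClient

raw = "<think>internal reasoning...</think>The capital of France is Paris."
print(BaseLLMClient.filter_think_tags(raw))  # "The capital of France is Paris."

# If stripping would leave an empty string, the original text is returned,
# so a reply that is nothing but a think block is not silently erased.
print(BaseLLMClient.filter_think_tags("<think>only reasoning</think>"))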

graphgen/bases/base_tokenizer.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class BaseTokenizer(ABC):
+    model_name: str = "cl100k_base"
+
+    @abstractmethod
+    def encode(self, text: str) -> List[int]:
+        """Encode text -> token ids."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def decode(self, token_ids: List[int]) -> str:
+        """Decode token ids -> text."""
+        raise NotImplementedError
+
+    def count_tokens(self, text: str) -> int:
+        return len(self.encode(text))
+
+    def chunk_by_token_size(
+        self,
+        content: str,
+        *,
+        overlap_token_size: int = 128,
+        max_token_size: int = 1024,
+    ) -> List[dict]:
+        tokens = self.encode(content)
+        results = []
+        step = max_token_size - overlap_token_size
+        for index, start in enumerate(range(0, len(tokens), step)):
+            chunk_ids = tokens[start : start + max_token_size]
+            results.append(
+                {
+                    "tokens": len(chunk_ids),
+                    "content": self.decode(chunk_ids).strip(),
+                    "chunk_order_index": index,
+                }
+            )
+        return results
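
Only encode and decode are abstract; count_tokens and the overlapping chunker come for free. A sketch of a concrete subclass, assuming tiktoken as the backend (it matches the "cl100k_base" default name, though this PR does not mandate it):

from dataclasses import dataclass
from typing import List

import tiktoken  # assumed backend, not named by this PR

from graphgen.bases import BaseTokenizer


@dataclass
class TiktokenTokenizer(BaseTokenizer):
    """Hypothetical tokenizer; only encode/decode need supplying."""

    def __post_init__(self):
        self._enc = tiktoken.get_encoding(self.model_name)

    def encode(self, text: str) -> List[int]:
        return self._enc.encode(text)

    def decode(self, token_ids: List[int]) -> str:
        return self._enc.decode(token_ids)


tok = TiktokenTokenizer()
chunks = tok.chunk_by_token_size("some long document ...", max_token_size=1024)
# Each dict carries "tokens", "content" and "chunk_order_index"; consecutive
# chunks overlap by overlap_token_size tokens (128 by default).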

graphgen/bases/datatypes.py

Lines changed: 14 additions & 0 deletions

@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from typing import List, Union
 
 
 @dataclass
@@ -16,3 +18,15 @@ class QAPair:
 
     question: str
     answer: str
+
+
+@dataclass
+class Token:
+    text: str
+    prob: float
+    top_candidates: List = field(default_factory=list)
+    ppl: Union[float, None] = field(default=None)
+
+    @property
+    def logprob(self) -> float:
+        return math.log(self.prob)
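
Token is the unit returned by the new client-level generate_topk_per_token and generate_inputs_prob hooks; logprob is derived from prob on demand. A small illustration:

import math

from graphgen.bases.datatypes import Token

t = Token(text="Paris", prob=0.25)
assert math.isclose(t.logprob, math.log(0.25))  # ~ -1.386

# prob must be positive, since logprob computes math.log(prob).
# top_candidates can hold competing Tokens for the same position;
# ppl stays None until a perplexity value is attached.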
