
Commit 6e4a142

Merge pull request #74 from open-sciencelab/feature/inference-backend
feat: add inference backends
2 parents: 0cc61c5 + 03e6d23

32 files changed: +1238 -113 lines

.env.example

Lines changed: 24 additions & 2 deletions

@@ -1,7 +1,29 @@
+# Tokenizer
 TOKENIZER_MODEL=
-SYNTHESIZER_MODEL=
+
+# LLM
+# Support different backends: http_api, openai_api, ollama_api, ollama, huggingface, tgi, sglang, tensorrt
+
+# http_api / openai_api
+SYNTHESIZER_BACKEND=openai_api
+SYNTHESIZER_MODEL=gpt-4o-mini
 SYNTHESIZER_BASE_URL=
 SYNTHESIZER_API_KEY=
-TRAINEE_MODEL=
+TRAINEE_BACKEND=openai_api
+TRAINEE_MODEL=gpt-4o-mini
 TRAINEE_BASE_URL=
 TRAINEE_API_KEY=
+
+# # ollama_api
+# SYNTHESIZER_BACKEND=ollama_api
+# SYNTHESIZER_MODEL=gemma3
+# SYNTHESIZER_BASE_URL=http://localhost:11434
+#
+# Note: TRAINEE with ollama_api backend is not supported yet as ollama_api does not support logprobs.
+
+# # huggingface
+# SYNTHESIZER_BACKEND=huggingface
+# SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct
+#
+# TRAINEE_BACKEND=huggingface
+# TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct
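These variables define the selection contract: each role (SYNTHESIZER_*, TRAINEE_*) picks a backend via *_BACKEND and a model via *_MODEL, with *_BASE_URL / *_API_KEY only meaningful for API backends. As a minimal sketch in the spirit of the init_llm() helper this PR wires into graphgen.py: the name make_client and the single-backend dispatch below are illustrative assumptions, and only the OpenAIClient keyword arguments are taken from the diff further down.

import os

from graphgen.models import OpenAIClient


def make_client(role: str):
    """Resolve {ROLE}_BACKEND / {ROLE}_MODEL / ... for 'synthesizer' or 'trainee'."""
    env = role.upper()
    backend = os.getenv(f"{env}_BACKEND", "openai_api")
    if backend == "openai_api":
        return OpenAIClient(
            model_name=os.getenv(f"{env}_MODEL"),
            api_key=os.getenv(f"{env}_API_KEY"),
            base_url=os.getenv(f"{env}_BASE_URL"),
        )
    # http_api, ollama_api, ollama, huggingface, tgi, sglang and tensorrt
    # would dispatch to their own wrappers here; not wired in this sketch.
    raise NotImplementedError(f"backend not wired in this sketch: {backend}")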

graphgen/bases/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
-from .base_llm_client import BaseLLMClient
+from .base_llm_wrapper import BaseLLMWrapper
 from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_splitter import BaseSplitter

graphgen/bases/base_generator.py

Lines changed: 2 additions & 2 deletions

@@ -1,15 +1,15 @@
 from abc import ABC, abstractmethod
 from typing import Any

-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper


 class BaseGenerator(ABC):
     """
     Generate QAs based on given prompts.
     """

-    def __init__(self, llm_client: BaseLLMClient):
+    def __init__(self, llm_client: BaseLLMWrapper):
         self.llm_client = llm_client

     @staticmethod

graphgen/bases/base_kg_builder.py

Lines changed: 2 additions & 2 deletions

@@ -2,13 +2,13 @@
 from collections import defaultdict
 from typing import Dict, List, Tuple

-from graphgen.bases.base_llm_client import BaseLLMClient
+from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Chunk


 class BaseKGBuilder(ABC):
-    def __init__(self, llm_client: BaseLLMClient):
+    def __init__(self, llm_client: BaseLLMWrapper):
         self.llm_client = llm_client
         self._nodes: Dict[str, List[dict]] = defaultdict(list)
         self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
graphgen/bases/{base_llm_client.py → base_llm_wrapper.py}

Lines changed: 7 additions & 1 deletion

@@ -8,7 +8,7 @@
 from graphgen.bases.datatypes import Token


-class BaseLLMClient(abc.ABC):
+class BaseLLMWrapper(abc.ABC):
     """
     LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
     """
@@ -66,3 +66,9 @@ def filter_think_tags(text: str, think_tag: str = "think") -> str:
         think_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
         filtered_text = think_pattern.sub("", text).strip()
         return filtered_text if filtered_text else text.strip()
+
+    def shutdown(self) -> None:
+        """Shutdown the LLM engine if applicable."""
+
+    def restart(self) -> None:
+        """Reinitialize the LLM engine if applicable."""

graphgen/graphgen.py

Lines changed: 35 additions & 30 deletions

@@ -1,11 +1,11 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass
 from typing import Dict, cast

 import gradio as gr

+from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
 from graphgen.models import (
@@ -20,6 +20,7 @@
     build_text_kg,
     chunk_documents,
     generate_qas,
+    init_llm,
     judge_statement,
     partition_kg,
     quiz,
@@ -31,40 +32,28 @@
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


-@dataclass
 class GraphGen:
-    unique_id: int = int(time.time())
-    working_dir: str = os.path.join(sys_path, "cache")
-
-    # llm
-    tokenizer_instance: Tokenizer = None
-    synthesizer_llm_client: OpenAIClient = None
-    trainee_llm_client: OpenAIClient = None
-
-    # webui
-    progress_bar: gr.Progress = None
-
-    def __post_init__(self):
-        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
+    def __init__(
+        self,
+        unique_id: int = int(time.time()),
+        working_dir: str = os.path.join(sys_path, "cache"),
+        tokenizer_instance: Tokenizer = None,
+        synthesizer_llm_client: OpenAIClient = None,
+        trainee_llm_client: OpenAIClient = None,
+        progress_bar: gr.Progress = None,
+    ):
+        self.unique_id: int = unique_id
+        self.working_dir: str = working_dir
+
+        # llm
+        self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
             model_name=os.getenv("TOKENIZER_MODEL")
         )

-        self.synthesizer_llm_client: OpenAIClient = (
-            self.synthesizer_llm_client
-            or OpenAIClient(
-                model_name=os.getenv("SYNTHESIZER_MODEL"),
-                api_key=os.getenv("SYNTHESIZER_API_KEY"),
-                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-                tokenizer=self.tokenizer_instance,
-            )
-        )
-
-        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
-            model_name=os.getenv("TRAINEE_MODEL"),
-            api_key=os.getenv("TRAINEE_API_KEY"),
-            base_url=os.getenv("TRAINEE_BASE_URL"),
-            tokenizer=self.tokenizer_instance,
+        self.synthesizer_llm_client: BaseLLMWrapper = (
+            synthesizer_llm_client or init_llm("synthesizer")
         )
+        self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client

         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
@@ -86,6 +75,9 @@ def __post_init__(self):
             namespace="qa",
         )

+        # webui
+        self.progress_bar: gr.Progress = progress_bar
+
     @async_to_sync_method
     async def insert(self, read_config: Dict, split_config: Dict):
         """
@@ -272,16 +264,29 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
         )

         # TODO: assert trainee_llm_client is valid before judge
+        if not self.trainee_llm_client:
+            # TODO: shutdown existing synthesizer_llm_client properly
+            logger.info("No trainee LLM client provided, initializing a new one.")
+            self.synthesizer_llm_client.shutdown()
+            self.trainee_llm_client = init_llm("trainee")
+
         re_judge = quiz_and_judge_config["re_judge"]
         _update_relations = await judge_statement(
             self.trainee_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             re_judge,
         )
+
         await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()

+        logger.info("Shutting down trainee LLM client.")
+        self.trainee_llm_client.shutdown()
+        self.trainee_llm_client = None
+        logger.info("Restarting synthesizer LLM client.")
+        self.synthesizer_llm_client.restart()
+
     @async_to_sync_method
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
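The net effect of these changes: GraphGen no longer builds a trainee client eagerly, so only one engine is live at a time. quiz_and_judge() shuts the synthesizer down, spins up the trainee via init_llm("trainee"), judges, then tears the trainee down and restart()s the synthesizer. A sketch of the resulting call pattern; the config dict is a placeholder apart from the re_judge key visible in the diff:

from graphgen.graphgen import GraphGen

gg = GraphGen(working_dir="cache")  # trainee_llm_client defaults to None

# Inside quiz_and_judge(): synthesizer.shutdown() -> init_llm("trainee")
# -> judge_statement(...) -> trainee.shutdown() -> synthesizer.restart()
gg.quiz_and_judge(quiz_and_judge_config={"re_judge": False})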

graphgen/models/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -7,8 +7,7 @@
     VQAGenerator,
 )
 from .kg_builder import LightRAGKGBuilder, MMKGBuilder
-from .llm.openai_client import OpenAIClient
-from .llm.topk_token_model import TopkTokenModel
+from .llm import HTTPClient, OllamaClient, OpenAIClient
 from .partitioner import (
     AnchorBFSPartitioner,
     BFSPartitioner,

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 from collections import Counter, defaultdict
 from typing import Dict, List, Tuple

-from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import (
     detect_main_language,
@@ -15,7 +15,7 @@


 class LightRAGKGBuilder(BaseKGBuilder):
-    def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
+    def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
         super().__init__(llm_client)
         self.max_loop = max_loop

graphgen/models/llm/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from .api.http_client import HTTPClient
+from .api.ollama_client import OllamaClient
+from .api.openai_client import OpenAIClient
+from .local.hf_wrapper import HuggingFaceWrapper
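With this new package __init__, callers import the concrete clients from one place; note that HuggingFaceWrapper is exported here but not yet re-exported from graphgen.models (the models/__init__.py hunk above only picks up the three API clients):

from graphgen.models.llm import HTTPClient, OllamaClient, OpenAIClient
from graphgen.models.llm import HuggingFaceWrapper  # local backend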

graphgen/models/llm/api/__init__.py

Whitespace-only changes.
