11import asyncio
22import os
33import time
4- from dataclasses import dataclass
54from typing import Dict , cast
65
76import gradio as gr
87
8+ from graphgen .bases import BaseLLMWrapper
99from graphgen .bases .base_storage import StorageNameSpace
1010from graphgen .bases .datatypes import Chunk
1111from graphgen .models import (
2020 build_text_kg ,
2121 chunk_documents ,
2222 generate_qas ,
23+ init_llm ,
2324 judge_statement ,
2425 partition_kg ,
2526 quiz ,
# Absolute path of the project root (the parent of this file's directory),
# used below as the base for the on-disk cache location.
_this_dir = os.path.dirname(__file__)
sys_path = os.path.abspath(os.path.join(_this_dir, ".."))
3233
3334
34- @dataclass
3535class GraphGen :
36- unique_id : int = int (time .time ())
37- working_dir : str = os .path .join (sys_path , "cache" )
38-
39- # llm
40- tokenizer_instance : Tokenizer = None
41- synthesizer_llm_client : OpenAIClient = None
42- trainee_llm_client : OpenAIClient = None
43-
44- # webui
45- progress_bar : gr .Progress = None
46-
47- def __post_init__ (self ):
48- self .tokenizer_instance : Tokenizer = self .tokenizer_instance or Tokenizer (
36+ def __init__ (
37+ self ,
38+ unique_id : int = int (time .time ()),
39+ working_dir : str = os .path .join (sys_path , "cache" ),
40+ tokenizer_instance : Tokenizer = None ,
41+ synthesizer_llm_client : OpenAIClient = None ,
42+ trainee_llm_client : OpenAIClient = None ,
43+ progress_bar : gr .Progress = None ,
44+ ):
45+ self .unique_id : int = unique_id
46+ self .working_dir : str = working_dir
47+
48+ # llm
49+ self .tokenizer_instance : Tokenizer = tokenizer_instance or Tokenizer (
4950 model_name = os .getenv ("TOKENIZER_MODEL" )
5051 )
5152
52- self .synthesizer_llm_client : OpenAIClient = (
53- self .synthesizer_llm_client
54- or OpenAIClient (
55- model_name = os .getenv ("SYNTHESIZER_MODEL" ),
56- api_key = os .getenv ("SYNTHESIZER_API_KEY" ),
57- base_url = os .getenv ("SYNTHESIZER_BASE_URL" ),
58- tokenizer = self .tokenizer_instance ,
59- )
60- )
61-
62- self .trainee_llm_client : OpenAIClient = self .trainee_llm_client or OpenAIClient (
63- model_name = os .getenv ("TRAINEE_MODEL" ),
64- api_key = os .getenv ("TRAINEE_API_KEY" ),
65- base_url = os .getenv ("TRAINEE_BASE_URL" ),
66- tokenizer = self .tokenizer_instance ,
53+ self .synthesizer_llm_client : BaseLLMWrapper = (
54+ synthesizer_llm_client or init_llm ("synthesizer" )
6755 )
56+ self .trainee_llm_client : BaseLLMWrapper = trainee_llm_client
6857
6958 self .full_docs_storage : JsonKVStorage = JsonKVStorage (
7059 self .working_dir , namespace = "full_docs"
@@ -86,6 +75,9 @@ def __post_init__(self):
8675 namespace = "qa" ,
8776 )
8877
78+ # webui
79+ self .progress_bar : gr .Progress = progress_bar
80+
8981 @async_to_sync_method
9082 async def insert (self , read_config : Dict , split_config : Dict ):
9183 """
@@ -272,16 +264,29 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
272264 )
273265
274266 # TODO: assert trainee_llm_client is valid before judge
267+ if not self .trainee_llm_client :
268+ # TODO: shutdown existing synthesizer_llm_client properly
269+ logger .info ("No trainee LLM client provided, initializing a new one." )
270+ self .synthesizer_llm_client .shutdown ()
271+ self .trainee_llm_client = init_llm ("trainee" )
272+
275273 re_judge = quiz_and_judge_config ["re_judge" ]
276274 _update_relations = await judge_statement (
277275 self .trainee_llm_client ,
278276 self .graph_storage ,
279277 self .rephrase_storage ,
280278 re_judge ,
281279 )
280+
282281 await self .rephrase_storage .index_done_callback ()
283282 await _update_relations .index_done_callback ()
284283
284+ logger .info ("Shutting down trainee LLM client." )
285+ self .trainee_llm_client .shutdown ()
286+ self .trainee_llm_client = None
287+ logger .info ("Restarting synthesizer LLM client." )
288+ self .synthesizer_llm_client .restart ()
289+
285290 @async_to_sync_method
286291 async def generate (self , partition_config : Dict , generate_config : Dict ):
287292 # Step 1: partition the graph
0 commit comments