Skip to content

Commit fb3fc25

Browse files
fix: fix webui
1 parent 05c827d commit fb3fc25

File tree

3 files changed

+52
-39
lines changed

3 files changed

+52
-39
lines changed

graphgen/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def main():
8585
config["partition"]["method"] == "ece"
8686
and "ece_params" in config["partition"]
8787
), "Only ECE partition with edge sampling is supported."
88-
config["partition"]["ece_params"]["edge_sampling"] = "random"
88+
config["partition"]["method_params"]["edge_sampling"] = "random"
8989
elif mode == "cot":
9090
logger.info("Generation mode set to 'cot'. Start generation.")
9191
else:

graphgen/graphgen.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,21 @@ class GraphGen:
5151
progress_bar: gr.Progress = None
5252

5353
def __post_init__(self):
54-
self.tokenizer_instance: Tokenizer = Tokenizer(
54+
self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
5555
model_name=os.getenv("TOKENIZER_MODEL")
5656
)
5757

58-
self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
59-
model_name=os.getenv("SYNTHESIZER_MODEL"),
60-
api_key=os.getenv("SYNTHESIZER_API_KEY"),
61-
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
62-
tokenizer=self.tokenizer_instance,
58+
self.synthesizer_llm_client: OpenAIClient = (
59+
self.synthesizer_llm_client
60+
or OpenAIClient(
61+
model_name=os.getenv("SYNTHESIZER_MODEL"),
62+
api_key=os.getenv("SYNTHESIZER_API_KEY"),
63+
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
64+
tokenizer=self.tokenizer_instance,
65+
)
6366
)
64-
self.trainee_llm_client: OpenAIClient = OpenAIClient(
67+
68+
self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
6569
model_name=os.getenv("TRAINEE_MODEL"),
6670
api_key=os.getenv("TRAINEE_API_KEY"),
6771
base_url=os.getenv("TRAINEE_BASE_URL"),

webui/app.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
3939
set_logger(log_file, if_stream=True)
4040
os.environ.update({k: str(v) for k, v in env.items()})
4141

42-
graph_gen = GraphGen(working_dir=working_dir, config=config)
43-
# Set up LLM clients
44-
graph_gen.synthesizer_llm_client = OpenAIClient(
42+
tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
43+
synthesizer_llm_client = OpenAIClient(
4544
model_name=env.get("SYNTHESIZER_MODEL", ""),
4645
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
4746
api_key=env.get("SYNTHESIZER_API_KEY", ""),
4847
request_limit=True,
4948
rpm=RPM(env.get("RPM", 1000)),
5049
tpm=TPM(env.get("TPM", 50000)),
50+
tokenizer=tokenizer_instance,
5151
)
52-
53-
graph_gen.trainee_llm_client = OpenAIClient(
52+
trainee_llm_client = OpenAIClient(
5453
model_name=env.get("TRAINEE_MODEL", ""),
5554
base_url=env.get("TRAINEE_BASE_URL", ""),
5655
api_key=env.get("TRAINEE_API_KEY", ""),
5756
request_limit=True,
5857
rpm=RPM(env.get("RPM", 1000)),
5958
tpm=TPM(env.get("TPM", 50000)),
59+
tokenizer=tokenizer_instance,
6060
)
6161

62-
graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
62+
graph_gen = GraphGen(
63+
working_dir=working_dir,
64+
tokenizer_instance=tokenizer_instance,
65+
synthesizer_llm_client=synthesizer_llm_client,
66+
trainee_llm_client=trainee_llm_client,
67+
)
6368

6469
return graph_gen
6570

@@ -78,27 +83,32 @@ def sum_tokens(client):
7883
"chunk_size": params.chunk_size,
7984
"chunk_overlap": params.chunk_overlap,
8085
},
81-
"output_data_type": params.output_data_type,
82-
"output_data_format": params.output_data_format,
83-
"tokenizer": params.tokenizer,
8486
"search": {"enabled": False},
85-
"quiz_and_judge_strategy": {
87+
"quiz_and_judge": {
8688
"enabled": params.if_trainee_model,
8789
"quiz_samples": params.quiz_samples,
8890
},
89-
"traverse_strategy": {
90-
"bidirectional": params.bidirectional,
91-
"expand_method": params.expand_method,
92-
"max_extra_edges": params.max_extra_edges,
93-
"max_tokens": params.max_tokens,
94-
"max_depth": params.max_depth,
95-
"edge_sampling": params.edge_sampling,
96-
"isolated_node_strategy": params.isolated_node_strategy,
97-
"loss_strategy": params.loss_strategy,
91+
"partition": {
92+
"method": "ece",
93+
"method_params": {
94+
"bidirectional": params.bidirectional,
95+
"expand_method": params.expand_method,
96+
"max_extra_edges": params.max_extra_edges,
97+
"max_tokens": params.max_tokens,
98+
"max_depth": params.max_depth,
99+
"edge_sampling": params.edge_sampling,
100+
"isolated_node_strategy": params.isolated_node_strategy,
101+
"loss_strategy": params.loss_strategy,
102+
},
103+
},
104+
"generate": {
105+
"mode": params.output_data_type,
106+
"data_format": params.output_data_format,
98107
},
99108
}
100109

101110
env = {
111+
"TOKENIZER_MODEL": params.tokenizer,
102112
"SYNTHESIZER_BASE_URL": params.synthesizer_url,
103113
"SYNTHESIZER_MODEL": params.synthesizer_model,
104114
"TRAINEE_BASE_URL": params.trainee_url,
@@ -128,19 +138,18 @@ def sum_tokens(client):
128138

129139
try:
130140
# Process the data
131-
graph_gen.insert()
141+
graph_gen.insert(read_config=config["read"], split_config=config["split"])
132142

133143
if config["if_trainee_model"]:
134-
# Generate quiz
135-
graph_gen.quiz()
136-
137-
# Judge statements
138-
graph_gen.judge()
144+
# Quiz and Judge
145+
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
139146
else:
140-
graph_gen.traverse_strategy.edge_sampling = "random"
147+
config["partition"]["method_params"]["edge_sampling"] = "random"
141148

142-
# Traverse graph
143-
graph_gen.traverse()
149+
graph_gen.generate(
150+
partition_config=config["partition"],
151+
generate_config=config["generate"],
152+
)
144153

145154
# Save output
146155
output_data = graph_gen.qa_storage.data
@@ -249,6 +258,9 @@ def sum_tokens(client):
249258
)
250259

251260
with gr.Accordion(label=_("Model Config"), open=False):
261+
tokenizer = gr.Textbox(
262+
label="Tokenizer", value="cl100k_base", interactive=True
263+
)
252264
synthesizer_url = gr.Textbox(
253265
label="Synthesizer URL",
254266
value="https://api.siliconflow.cn/v1",
@@ -300,9 +312,6 @@ def sum_tokens(client):
300312
step=100,
301313
interactive=True,
302314
)
303-
tokenizer = gr.Textbox(
304-
label="Tokenizer", value="cl100k_base", interactive=True
305-
)
306315
output_data_type = gr.Radio(
307316
choices=["atomic", "multi_hop", "aggregated"],
308317
label="Output Data Type",

0 commit comments

Comments (0)