
Commit 6af1cda

feat(charts): add plot loss method
1 parent d2ee9cb commit 6af1cda


7 files changed: 128 additions, 34 deletions


baselines/LongForm/longform.py

Lines changed: 10 additions & 9 deletions
@@ -3,15 +3,16 @@
 
 import os
 import json
-from dotenv import load_dotenv
+from dataclasses import dataclass
 import argparse
 import asyncio
+from typing import List
+from tqdm.asyncio import tqdm as tqdm_async
+from dotenv import load_dotenv
 
-from dataclasses import dataclass
 from models import OpenAIModel
-from typing import List
 from utils import create_event_loop, compute_content_hash
-from tqdm.asyncio import tqdm as tqdm_async
+
 
 PROMPT_TEMPLATE = '''Instruction: X
 Output:{doc}
@@ -28,7 +29,7 @@ def generate(self, docs: List[List[dict]]) -> List[dict]:
         loop = create_event_loop()
         return loop.run_until_complete(self.async_generate(docs))
 
-    async def async_generate(self, docs: List[List[dict]]) -> List[dict]:
+    async def async_generate(self, docs: List[List[dict]]) -> dict:
         final_results = {}
         semaphore = asyncio.Semaphore(self.max_concurrent)
 
@@ -51,7 +52,7 @@ async def process_chunk(content: str):
             try:
                 qa = await result
                 final_results.update(qa)
-            except Exception as e:
+            except Exception as e: # pylint: disable=broad-except
                 print(f"Error: {e}")
         return final_results
 
@@ -84,15 +85,15 @@ async def process_chunk(content: str):
     longform = LongForm(llm_client=llm_client)
 
     if args.data_type == 'raw':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = [json.loads(line) for line in f]
         data = [[chunk] for chunk in data]
     elif args.data_type == 'chunked':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = json.load(f)
 
     results = longform.generate(data)
 
     # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
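
A note on the input formats handled above: the 'raw' branch reads a JSONL file, treating every line as one chunk object and wrapping it into a single-element list, while the 'chunked' branch loads a JSON array of pre-grouped chunk lists. A minimal sketch of producing a compatible raw input follows; the "content" field name is an assumption for illustration, since the chunk schema is not shown in this diff.

# Sketch: build a JSONL input for the 'raw' branch.
# The chunk schema (e.g. a "content" key) is assumed, not confirmed by this commit.
import json

chunks = [
    {"content": "First document chunk ..."},
    {"content": "Second document chunk ..."},
]

with open("input.jsonl", "w", encoding="utf-8") as f:
    for chunk in chunks:
        f.write(json.dumps(chunk, ensure_ascii=False) + "\n")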

baselines/SELF-QA/self-qa.py

Lines changed: 26 additions & 18 deletions
@@ -2,14 +2,15 @@
 
 import os
 import json
-from dotenv import load_dotenv
+from dataclasses import dataclass
+from typing import List
 import argparse
 import asyncio
-from dataclasses import dataclass
+from tqdm.asyncio import tqdm as tqdm_async
+from dotenv import load_dotenv
+
 from models import OpenAIModel
-from typing import List
 from utils import create_event_loop, compute_content_hash
-from tqdm.asyncio import tqdm as tqdm_async
 
 INSTRUCTION_GENERATION_PROMPT = '''The background knowledge is:
 {doc}
@@ -49,6 +50,7 @@ def _post_process_answers(content: str) -> tuple:
         question = content.split('Question:')[1].split('Answer:')[0].strip()
         answer = content.split('Answer:')[1].strip()
         return question, answer
+    return None, None
 
 @dataclass
 class SelfQA:
@@ -59,7 +61,7 @@ def generate(self, docs: List[List[dict]]) -> List[dict]:
         loop = create_event_loop()
         return loop.run_until_complete(self.async_generate(docs))
 
-    async def async_generate(self, docs: List[List[dict]]) -> List[dict]:
+    async def async_generate(self, docs: List[List[dict]]) -> dict:
         final_results = {}
         semaphore = asyncio.Semaphore(self.max_concurrent)
 
@@ -71,20 +73,26 @@ async def process_chunk(content: str):
             instruction_questions = _post_process_instructions(response)
 
             qas = []
-            for qa in tqdm_async(asyncio.as_completed([self.llm_client.generate_answer(READING_COMPREHENSION_PROMPT.format(doc=content, question=question)) for question in instruction_questions]), total=len(instruction_questions), desc="Generating QAs"):
+            for qa in tqdm_async(asyncio.as_completed([
+                    self.llm_client.generate_answer(READING_COMPREHENSION_PROMPT.format(
+                        doc=content,
+                        question=question
+                    )) for question in instruction_questions]),
+                    total=len(instruction_questions), desc="Generating QAs"):
                 try:
                     question, answer = _post_process_answers(await qa)
-                    qas.append({
-                        compute_content_hash(question): {
-                            'question': question,
-                            'answer': answer
-                        }
-                    })
-                except Exception as e:
+                    if question and answer:
+                        qas.append({
+                            compute_content_hash(question): {
+                                'question': question,
+                                'answer': answer
+                            }
+                        })
+                except Exception as e: # pylint: disable=broad-except
                     print(f"Error: {e}")
                     continue
             return qas
-        except Exception as e:
+        except Exception as e: # pylint: disable=broad-except
             print(f"Error: {e}")
             return []
 
@@ -98,7 +106,7 @@ async def process_chunk(content: str):
                 qas = await result
                 for qa in qas:
                     final_results.update(qa)
-            except Exception as e:
+            except Exception as e: # pylint: disable=broad-except
                 print(f"Error: {e}")
         return final_results
 
@@ -131,15 +139,15 @@ async def process_chunk(content: str):
     self_qa = SelfQA(llm_client=llm_client)
 
    if args.data_type == 'raw':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = [json.loads(line) for line in f]
         data = [[chunk] for chunk in data]
     elif args.data_type == 'chunked':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = json.load(f)
 
     results = self_qa.generate(data)
 
     # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
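
The `return None, None` fallback and the `if question and answer` guard above mean that model responses missing the 'Question:'/'Answer:' markers are skipped rather than appended or allowed to raise mid-loop. A standalone sketch of that parsing contract follows; it is re-implemented here for illustration, and the guard around the original return statement sits outside the hunk and is assumed.

# Illustrative re-implementation of the parsing behaviour relied on above.
def post_process_answers(content: str) -> tuple:
    if 'Question:' in content and 'Answer:' in content:
        question = content.split('Question:')[1].split('Answer:')[0].strip()
        answer = content.split('Answer:')[1].strip()
        return question, answer
    return None, None

print(post_process_answers("Question: What is X?\nAnswer: A placeholder."))  # ('What is X?', 'A placeholder.')
print(post_process_answers("Sorry, no answer available."))                   # (None, None)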

charts/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .plot_rephrase_process import plot_pre_length_distribution, plot_post_synth_length_distribution
+from .plot_loss_change import plot_loss_distribution

charts/plot_loss_change.py

Lines changed: 54 additions & 1 deletion
@@ -1 +1,54 @@
-# Loss change before and after training
+from collections import defaultdict
+import plotly.graph_objects as go
+
+def plot_loss_distribution(stats: list[dict]):
+    """
+    Plot the distribution of edges' loss.
+
+    :return fig
+    """
+
+    if not stats:
+        return go.Figure()
+
+    max_loss = max(item['average_loss'] for item in stats)
+    bin_numbers = 50
+    bin_size = max_loss / bin_numbers
+
+    length_distribution = defaultdict(int)
+
+    for item in stats:
+        bin_start = (item['average_loss'] // bin_size) * bin_size
+        bin_key = f"{bin_start}-{bin_start + bin_size}"
+        length_distribution[bin_key] += 1
+
+    sorted_bins = sorted(length_distribution.keys(),
+                         key=lambda x: float(x.split('-')[0]))
+
+    fig = go.Figure(data=[
+        go.Bar(
+            x=sorted_bins,
+            y=[length_distribution[bin_] for bin_ in sorted_bins],
+            text=[length_distribution[bin_] for bin_ in sorted_bins],
+            textposition='auto',
+        )
+    ])
+
+    fig.update_layout(
+        title='Distribution of Loss',
+        xaxis_title='Loss Range',
+        yaxis_title='Count',
+        bargap=0.2,
+        showlegend=False
+    )
+
+    if len(sorted_bins) > 10:
+        fig.update_layout(
+            xaxis={
+                'tickangle': 45,
+                'tickmode': 'array',
+                'ticktext': sorted_bins[::2],
+                'tickvals': list(range(len(sorted_bins)))[::2]
+            }
+        )
+    return fig
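
The new helper only needs a list of dicts carrying an 'average_loss' key; binning into 50 buckets, bar labels, and tick thinning are handled internally, and an empty list yields an empty figure. A minimal usage sketch, with made-up loss values and an arbitrary output path:

# Minimal usage sketch for plot_loss_distribution; the loss values are made up.
from charts import plot_loss_distribution

stats = [{"average_loss": loss} for loss in (0.12, 0.35, 0.36, 0.80, 1.42)]

fig = plot_loss_distribution(stats)
fig.write_html("loss_distribution.html")  # or fig.show() in an interactive session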

charts/plot_rephrase_process.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import re
-import plotly.express as px
 from collections import defaultdict
+import plotly.express as px
 import plotly.graph_objects as go
 import pandas as pd
 from tqdm import tqdm

generate.py

Lines changed: 2 additions & 2 deletions
@@ -80,7 +80,7 @@
     graph_gen.traverse()
 
     config_path = os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml")
-    if not os.path.exists(config_path):
-        os.makedirs(config_path)
+    if not os.path.exists(os.path.dirname(config_path)):
+        os.makedirs(os.path.dirname(config_path))
     with open(config_path, "w", encoding='utf-8') as f:
         yaml.dump(traverse_strategy.to_yaml(), f)
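
The previous code passed the YAML file path itself to os.makedirs, so a directory would be created where the config file should live and the following open(config_path, "w") would fail; the fix creates the parent directory instead. If it suits the codebase, the same effect can be had in one line with exist_ok, shown here only as the standard-library shortcut rather than what the commit uses:

# Equivalent standard-library shortcut; the path below is illustrative.
import os

config_path = "cache/configs/graphgen_example.yaml"
os.makedirs(os.path.dirname(config_path), exist_ok=True)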

simulate.py

Lines changed: 33 additions & 3 deletions
@@ -5,11 +5,11 @@
 import json
 import gradio as gr
 
-from models import TraverseStrategy, NetworkXStorage
-from charts.plot_rephrase_process import plot_pre_length_distribution, plot_post_synth_length_distribution
+from models import TraverseStrategy, NetworkXStorage, Tokenizer
+from charts import plot_pre_length_distribution, plot_post_synth_length_distribution, plot_loss_distribution
 from graphgen.operators.split_graph import get_batches_with_strategy
 from utils import create_event_loop
-from models import Tokenizer
+
 
 if __name__ == "__main__":
     networkx_storage = NetworkXStorage(
@@ -153,4 +153,34 @@ def synthesize_text(file):
             outputs=[output_plot]
         )
 
+    with gr.Tab("After Judgement"):
+        with gr.Row():
+            with gr.Column():
+                file_list = os.listdir("cache/data/graphgen")
+                input_file = gr.Dropdown(choices=file_list, label="Input File")
+                file_button = gr.Button("Submit File")
+
+        with gr.Row():
+            output_plot = gr.Plot(label="Graph Visualization")
+
+        def judge_graph(file):
+            with open(f"cache/data/graphgen/{file}", "r", encoding='utf-8') as f:
+                data = json.load(f)
+            stats = []
+            for key in data:
+                item = data[key]
+                item['average_loss'] = sum(loss[2] for loss in item['losses']) / len(item['losses'])
+                stats.append({
+                    'average_loss': item['average_loss']
+                })
+            fig = plot_loss_distribution(stats)
+            return fig
+
+        file_button.click(
+            fn=judge_graph,
+            inputs=[input_file],
+            outputs=[output_plot]
+        )
+
+
     app.launch()
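
judge_graph assumes every record in the selected JSON file carries a 'losses' list whose entries hold the loss value at index 2; it averages those values per record and passes only the averages on to plot_loss_distribution. A sketch of a compatible record and the same aggregation follows; the first two elements of each loss entry are not described in this commit and are assumed here to be edge endpoints.

# Sketch of the record shape judge_graph expects; fields other than index 2 are assumed.
record = {
    "losses": [
        ["node_a", "node_b", 0.41],  # only the last element (the loss) is used
        ["node_a", "node_c", 0.73],
    ]
}

average_loss = sum(loss[2] for loss in record["losses"]) / len(record["losses"])
stats = [{"average_loss": average_loss}]  # exactly what plot_loss_distribution consumes
print(average_loss)  # 0.57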
