Skip to content

Commit 2623b15

Browse files
feat(charts): add graph traversal interface
1 parent 0bed25c commit 2623b15

File tree

6 files changed

+225
-26
lines changed

6 files changed

+225
-26
lines changed

.pylintrc

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,16 @@ ignore=CVS
5252
# ignore-list. The regex matches against paths and can be in Posix or Windows
5353
# format. Because '\\' represents the directory delimiter on Windows systems,
5454
# it can't be used as an escape character.
55-
ignore-paths=
5655

5756
# Files or directories matching the regular expression patterns are skipped.
5857
# The regex matches against base names, not paths. The default value ignores
5958
# Emacs file locks
6059
ignore-patterns=^\.#
6160

62-
# List of module names for which member attributes should not be checked and
63-
# will not be imported (useful for modules/projects where namespaces are
64-
# manipulated during runtime and thus existing member attributes cannot be
65-
# deduced by static analysis). It supports qualified module names, as well as
66-
# Unix pattern matching.
61+
# List of module names for which member attributes should not be checked
62+
# (useful for modules/projects where namespaces are manipulated during runtime
63+
# and thus existing member attributes cannot be deduced by static analysis). It
64+
# supports qualified module names, as well as Unix pattern matching.
6765
ignored-modules=
6866

6967
# Python code to execute, usually for sys.path manipulation such as
@@ -87,13 +85,9 @@ load-plugins=
8785
# Pickle collected data for later comparisons.
8886
persistent=yes
8987

90-
# Resolve imports to .pyi stubs if available. May reduce no-member messages and
91-
# increase not-an-iterable messages.
92-
prefer-stubs=no
93-
9488
# Minimum Python version to use for version dependent checks. Will default to
9589
# the version used to run pylint.
96-
py-version=3.10
90+
py-version=3.11
9791

9892
# Discover python modules and packages in the file system subtree.
9993
recursive=no
@@ -307,9 +301,6 @@ max-locals=15
307301
# Maximum number of parents for a class (see R0901).
308302
max-parents=7
309303

310-
# Maximum number of positional arguments for function / method.
311-
max-positional-arguments=5
312-
313304
# Maximum number of public methods for a class (see R0904).
314305
max-public-methods=20
315306

@@ -345,7 +336,7 @@ indent-after-paren=4
345336
indent-string=' '
346337

347338
# Maximum number of characters on a single line.
348-
max-line-length=100
339+
max-line-length=120
349340

350341
# Maximum number of lines in a module.
351342
max-module-lines=1000
@@ -438,7 +429,28 @@ disable=raw-checker-failed,
438429
deprecated-pragma,
439430
use-symbolic-message-instead,
440431
use-implicit-booleaness-not-comparison-to-string,
441-
use-implicit-booleaness-not-comparison-to-zero
432+
use-implicit-booleaness-not-comparison-to-zero,
433+
missing-module-docstring,
434+
missing-class-docstring,
435+
missing-function-docstring,
436+
W0122, # Use of exec (exec-used)
437+
R0914, # Too many local variables (19/15) (too-many-locals)
438+
R0903, # Too few public methods (1/2)
439+
W0613, # Unused argument
440+
W0511, # TODO
441+
W0719, # Raising too general exception: Exception
442+
R0801, # Similar lines
443+
W0105, # String statement has no effect (pointless-string-statement)
444+
R0913, # Too many arguments (6/5) (too-many-arguments)
445+
C0415, # Import outside toplevel
446+
R0902, # Too many instance attributes (11/7)
447+
R1725, # Consider using Python 3 style super() without arguments (super-with-arguments)
448+
W0622, # Redefining built-in 'id' (redefined-builtin)
449+
R0904, # Too many public methods (27/20) (too-many-public-methods)
450+
E1120, # TODO: unbound-method-call-no-value-for-parameter
451+
R0917, # Too many positional arguments (6/5) (too-many-positional-arguments)
452+
C0103,
453+
E0401
442454

443455
# Enable the message, report, category or checker with the given id(s). You can
444456
# either give multiple identifier separated by comma (,) or put this option
@@ -476,11 +488,6 @@ max-nested-blocks=5
476488
# printed.
477489
never-returning-functions=sys.exit,argparse.parse_error
478490

479-
# Let 'consider-using-join' be raised when the separator to join on would be
480-
# non-empty (resulting in expected fixes of the type: ``"- " + " -
481-
# ".join(items)``)
482-
suggest-join-with-non-empty-separator=yes
483-
484491

485492
[REPORTS]
486493

charts/plot_loss_change.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# 在训练前后的loss变化

charts/plot_rephrase_process.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from models import Tokenizer
66
from utils.log import parse_log
77
import plotly.express as px
8+
import plotly.graph_objects as go
9+
from collections import defaultdict
810

911
def analyse_log(log_info: dict) -> list:
1012
"""
@@ -71,7 +73,66 @@ async def plot_rephrase_process(stats: list[dict]):
7173
fig = px.scatter(df, x="pre_length", y="post_length", size="count", color="count", hover_name="count")
7274
fig.show()
7375

76+
def plot_pre_length_distribution(stats: list[dict]):
77+
"""
78+
Plot the distribution of pre-length.
79+
80+
:return fig
81+
"""
82+
83+
# 使用传入的stats参数而不是全局的data
84+
if not stats:
85+
return go.Figure()
86+
87+
# 计算最大长度并确定区间
88+
max_length = max(item['pre_length'] for item in stats)
89+
bin_size = 50
90+
max_length = ((max_length // bin_size) + 1) * bin_size
91+
92+
# 使用defaultdict避免键不存在的检查
93+
length_distribution = defaultdict(int)
94+
95+
# 一次遍历完成所有统计
96+
for item in stats:
97+
bin_start = (item['pre_length'] // bin_size) * bin_size
98+
bin_key = f"{bin_start}-{bin_start + bin_size}"
99+
length_distribution[bin_key] += 1
100+
101+
# 转换为排序后的列表以保持区间顺序
102+
sorted_bins = sorted(length_distribution.keys(),
103+
key=lambda x: int(x.split('-')[0]))
104+
105+
# 创建图表
106+
fig = go.Figure(data=[
107+
go.Bar(
108+
x=sorted_bins,
109+
y=[length_distribution[bin_] for bin_ in sorted_bins],
110+
text=[length_distribution[bin_] for bin_ in sorted_bins],
111+
textposition='auto',
112+
)
113+
])
114+
115+
# 设置图表布局
116+
fig.update_layout(
117+
title='Distribution of Pre-Length',
118+
xaxis_title='Length Range',
119+
yaxis_title='Count',
120+
bargap=0.2,
121+
showlegend=False
122+
)
123+
124+
# 如果数据点过多,优化x轴标签显示
125+
if len(sorted_bins) > 10:
126+
fig.update_layout(
127+
xaxis={
128+
'tickangle': 45,
129+
'tickmode': 'array',
130+
'ticktext': sorted_bins[::2], # 每隔一个显示标签
131+
'tickvals': list(range(len(sorted_bins)))[::2]
132+
}
133+
)
74134

135+
return fig
75136

76137
if __name__ == "__main__":
77138
log = parse_log('/home/PJLAB/chenzihong/Project/graphgen/cache/logs/graphgen.log')

evaluate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Evaluate the quality of the generated text using various metrics"""
2+
13
import os
24
import json
35
import argparse
@@ -72,7 +74,8 @@ def clean_gpu_cache():
7274
parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
7375

7476
parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
75-
parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2', help='Comma-separated list of reward models')
77+
parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
78+
help='Comma-separated list of reward models')
7679
parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
7780

7881
args = parser.parse_args()
@@ -122,5 +125,4 @@ def clean_gpu_cache():
122125

123126

124127
results = pd.DataFrame(results)
125-
126128
results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)

generate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
import json
33
import argparse
4+
from dotenv import load_dotenv
5+
46
from graphgen.graphgen import GraphGen
57
from models import OpenAIModel, Tokenizer, TraverseStrategy
6-
from dotenv import load_dotenv
78
from utils import set_logger
89

910
sys_path = os.path.abspath(os.path.dirname(__file__))
@@ -35,11 +36,13 @@
3536
input_file = args.input_file
3637

3738
if args.data_type == 'raw':
38-
with open(input_file, "r") as f:
39+
with open(input_file, "r", encoding='utf-8') as f:
3940
data = [json.loads(line) for line in f]
4041
elif args.data_type == 'chunked':
41-
with open(input_file, "r") as f:
42+
with open(input_file, "r", encoding='utf-8') as f:
4243
data = json.load(f)
44+
else:
45+
raise ValueError(f"Invalid data type: {args.data_type}")
4346

4447
teacher_llm_client = OpenAIModel(
4548
model_name=os.getenv("TEACHER_MODEL"),

simulate.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Simulate text length distributions using input data distributions when rephrasing."""
2+
3+
import gradio as gr
4+
5+
from models import TraverseStrategy, NetworkXStorage
6+
from charts.plot_rephrase_process import plot_pre_length_distribution
7+
from graphgen.operators.split_graph import get_batches_with_strategy
8+
from utils import create_event_loop
9+
import copy
10+
11+
if __name__ == "__main__":
12+
networkx_storage = NetworkXStorage(
13+
'/home/PJLAB/chenzihong/Project/graphgen/cache', namespace="graph"
14+
)
15+
16+
async def get_batches(traverse_strategy: TraverseStrategy):
17+
nodes = await networkx_storage.get_all_nodes()
18+
edges = await networkx_storage.get_all_edges()
19+
20+
nodes = list(nodes)
21+
edges = list(edges)
22+
23+
# deepcopy
24+
nodes = [(node[0], node[1].copy()) for node in nodes]
25+
edges = [(edge[0], edge[1], edge[2].copy()) for edge in edges]
26+
27+
nodes = copy.deepcopy(nodes)
28+
edges = copy.deepcopy(edges)
29+
assert all('length' in edge[2] for edge in edges)
30+
assert all('length' in node[1] for node in nodes)
31+
32+
return await get_batches_with_strategy(nodes, edges, networkx_storage, traverse_strategy)
33+
34+
def traverse_graph(
35+
bidirectional: bool,
36+
expand_method: str,
37+
max_extra_edges: int,
38+
max_tokens: int,
39+
max_depth: int,
40+
edge_sampling: str,
41+
isolated_node_strategy: str
42+
) -> str:
43+
traverse_strategy = TraverseStrategy(
44+
bidirectional=bidirectional,
45+
expand_method=expand_method,
46+
max_extra_edges=max_extra_edges,
47+
max_tokens=max_tokens,
48+
max_depth=max_depth,
49+
edge_sampling=edge_sampling,
50+
isolated_node_strategy=isolated_node_strategy
51+
)
52+
53+
loop = create_event_loop()
54+
batches = loop.run_until_complete(get_batches(traverse_strategy))
55+
loop.close()
56+
57+
data = []
58+
for _process_batch in batches:
59+
pre_length = sum([node['length'] for node in _process_batch[0]]) + sum(
60+
[edge[2]['length'] for edge in _process_batch[1]])
61+
data.append({
62+
'pre_length': pre_length
63+
})
64+
fig = plot_pre_length_distribution(data)
65+
66+
return fig
67+
68+
69+
def update_sliders(expand_method):
70+
if expand_method == "max_tokens":
71+
return gr.update(visible=True), gr.update(visible=False) # Show max_tokens, hide max_extra_edges
72+
else:
73+
return gr.update(visible=False), gr.update(visible=True) # Hide max_tokens, show max_extra_edges
74+
75+
76+
with gr.Blocks() as iface:
77+
gr.Markdown("# Graph Traversal Interface")
78+
79+
with gr.Row():
80+
with gr.Column():
81+
bidirectional = gr.Checkbox(label="Bidirectional", value=False)
82+
expand_method = gr.Dropdown(
83+
choices=["max_width", "max_tokens"],
84+
value="max_tokens",
85+
label="Expand Method",
86+
interactive=True
87+
)
88+
89+
# Initialize sliders
90+
max_extra_edges = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Max Extra Edges",
91+
visible=False)
92+
max_tokens = gr.Slider(minimum=128, maximum=8 * 1024, value=1024, step=128, label="Max Tokens")
93+
max_depth = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Depth")
94+
edge_sampling = gr.Dropdown(
95+
choices=["max_loss", "random", "min_loss"],
96+
value="max_loss",
97+
label="Edge Sampling Strategy"
98+
)
99+
isolated_node_strategy = gr.Dropdown(
100+
choices=["add", "ignore", "connect"],
101+
value="add",
102+
label="Isolated Node Strategy"
103+
)
104+
submit_btn = gr.Button("Traverse Graph")
105+
106+
with gr.Row():
107+
output_plot = gr.Plot(label="Graph Visualization")
108+
109+
# Set up event listener for expand_method dropdown
110+
expand_method.change(fn=update_sliders, inputs=expand_method, outputs=[max_tokens, max_extra_edges])
111+
112+
submit_btn.click(
113+
fn=traverse_graph,
114+
inputs=[
115+
bidirectional,
116+
expand_method,
117+
max_extra_edges,
118+
max_tokens,
119+
max_depth,
120+
edge_sampling,
121+
isolated_node_strategy
122+
],
123+
outputs=[output_plot]
124+
)
125+
iface.launch()

0 commit comments

Comments
 (0)