Skip to content

Commit ea056bf

Browse files
committed
feat: add some dataset scripts
1 parent 480b889 commit ea056bf

File tree

4 files changed

+530
-1
lines changed

4 files changed

+530
-1
lines changed

examples/basic_usage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ def demo_llm_webkit_with_preprocessed_html_evaluation():
889889

890890
# 1. 从真实数据集加载包含预处理HTML的数据
891891
print("1. 从真实数据集加载预处理HTML数据...")
892-
dataset_path = Path("data/WebMainBench_1827_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl")
892+
dataset_path = Path("data/track_id_diff_result_56.jsonl")
893893
print(f"📂 数据集文件: {dataset_path}")
894894

895895
# 加载数据集

scripts/diff_jsonl.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#!/usr/bin/env python3
2+
"""
3+
比较两个JSONL文件,找出track_id在文件1中存在但在文件2中不存在的数据
4+
"""
5+
import json
6+
import sys
7+
from pathlib import Path
8+
9+
def load_track_ids(jsonl_file):
    """Collect every unique ``track_id`` present in a JSONL file.

    Args:
        jsonl_file: Path (str or Path) to the JSONL file to scan.

    Returns:
        set: All distinct truthy ``track_id`` values. Empty when the file
        is missing or an unrecoverable read error occurs.
    """
    ids = set()
    path = Path(jsonl_file)

    # Missing input: report and return the (still empty) set.
    if not path.exists():
        print(f"❌ 文件不存在: {path}")
        return ids

    print(f"📖 正在读取文件: {path.name}")

    processed = 0
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            for lineno, raw in enumerate(fh, 1):
                raw = raw.strip()
                if not raw:
                    continue

                processed += 1

                # Progress heartbeat every 1000 non-empty lines.
                if processed % 1000 == 0:
                    print(f" 📊 已处理 {processed} 行...")

                try:
                    record = json.loads(raw)
                except json.JSONDecodeError as exc:
                    # Bad line: report and keep scanning.
                    print(f" ⚠️ 第 {lineno} 行JSON解析错误: {exc}")
                    continue

                tid = record.get('track_id')
                if tid:
                    ids.add(tid)

    except Exception as exc:
        # Fatal read error: report and discard any partial result.
        print(f"❌ 读取文件时出错: {exc}")
        return set()

    print(f" ✅ 共找到 {len(ids)} 个唯一track_id")
    return ids
59+
60+
def load_data_with_track_ids(jsonl_file, target_track_ids):
    """Extract the full records whose ``track_id`` is in a target set.

    Args:
        jsonl_file: Path (str or Path) to the JSONL file to scan.
        target_track_ids: Set of track_id values to select.

    Returns:
        list: Parsed records (in file order) whose ``track_id`` is in
        ``target_track_ids``. Empty when the file is missing or a fatal
        read error occurs.
    """
    hits = []
    path = Path(jsonl_file)

    # Missing input: report and return the (still empty) list.
    if not path.exists():
        print(f"❌ 文件不存在: {path}")
        return hits

    print(f"📖 正在从 {path.name} 中提取目标数据...")

    processed = 0
    matched = 0

    try:
        with open(path, 'r', encoding='utf-8') as fh:
            for lineno, raw in enumerate(fh, 1):
                raw = raw.strip()
                if not raw:
                    continue

                processed += 1

                # Progress heartbeat every 1000 non-empty lines.
                if processed % 1000 == 0:
                    print(f" 📊 已处理 {processed} 行,找到 {matched} 条目标数据...")

                try:
                    record = json.loads(raw)
                except json.JSONDecodeError as exc:
                    # Bad line: report and keep scanning.
                    print(f" ⚠️ 第 {lineno} 行JSON解析错误: {exc}")
                    continue

                if record.get('track_id') in target_track_ids:
                    hits.append(record)
                    matched += 1

    except Exception as exc:
        # Fatal read error: report and discard any partial result.
        print(f"❌ 读取文件时出错: {exc}")
        return []

    print(f" ✅ 共找到 {len(hits)} 条目标数据")
    return hits
114+
115+
def main():
    """Diff two JSONL files by track_id and save records unique to file 1.

    Reads two file paths from ``sys.argv`` (falling back to defaults),
    computes the track_ids present in file 1 but absent from file 2, and
    writes the corresponding file-1 records to
    ``data/track_id_diff_result.jsonl``.
    """
    # Fallback inputs used when no CLI paths are supplied.
    file1_default = "data/filtered_normal_data_1883.jsonl"
    file2_default = "data/WebMainBench_1827_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl"

    use_cli = len(sys.argv) >= 3
    file1 = sys.argv[1] if use_cli else file1_default
    file2 = sys.argv[2] if use_cli else file2_default

    banner = "=" * 80
    print(banner)
    print("🔍 比较JSONL文件中的track_id差异")
    print(banner)
    print(f"📁 文件1 (源文件): {file1}")
    print(f"📁 文件2 (对比文件): {file2}")
    print(f"🎯 目标: 找出在文件1中存在但在文件2中不存在的track_id数据")
    print()

    # Steps 1-2: gather the id sets from both inputs; bail out if either is empty.
    print("🔸 步骤1: 加载文件1的track_id...")
    ids_a = load_track_ids(file1)
    if not ids_a:
        print("❌ 文件1中没有找到有效的track_id")
        return
    print()

    print("🔸 步骤2: 加载文件2的track_id...")
    ids_b = load_track_ids(file2)
    if not ids_b:
        print("❌ 文件2中没有找到有效的track_id")
        return
    print()

    # Step 3: set arithmetic yields the ids unique to file 1.
    print("🔸 步骤3: 计算差集...")
    only_in_a = ids_a - ids_b
    shared = ids_a & ids_b

    print(f" 📊 文件1中的track_id数量: {len(ids_a):,}")
    print(f" 📊 文件2中的track_id数量: {len(ids_b):,}")
    print(f" 📊 共同的track_id数量: {len(shared):,}")
    print(f" ⭐ 差异的track_id数量: {len(only_in_a):,}")

    if not only_in_a:
        print("\n🎉 没有发现差异!文件1中的所有track_id在文件2中都存在。")
        return
    print()

    # Step 4: pull the full records for the unique ids back out of file 1.
    print("🔸 步骤4: 提取差异数据...")
    diff_rows = load_data_with_track_ids(file1, only_in_a)
    if not diff_rows:
        print("❌ 没有找到差异数据")
        return
    print()

    # Step 5: persist the differing records as JSONL.
    print("🔸 步骤5: 保存差异数据...")
    output_file = "data/track_id_diff_result.jsonl"

    try:
        with open(output_file, 'w', encoding='utf-8') as fh:
            for row in diff_rows:
                fh.write(json.dumps(row, ensure_ascii=False) + '\n')

        print(f"✅ 已保存 {len(diff_rows)} 条差异数据到: {output_file}")

        # Show a small sample of the differing ids (set iteration order).
        print(f"\n📋 差异track_id示例 (前10个):")
        for idx, tid in enumerate(list(only_in_a)[:10], 1):
            print(f" {idx}. {tid}")

        if len(only_in_a) > 10:
            print(f" ... 还有 {len(only_in_a) - 10} 个")

    except Exception as exc:
        print(f"❌ 保存文件时出错: {exc}")
        return

    print("\n" + "=" * 80)
    print("🎉 比较完成!")
    print("=" * 80)

if __name__ == "__main__":
    main()

scripts/filter_normal_jsonl.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python3
2+
"""
3+
过滤JSONL文件中marked_type为normal的数据
4+
"""
5+
import json
6+
import sys
7+
from pathlib import Path
8+
9+
def filter_normal_data(input_file):
    """Scan a JSONL file and keep records whose marked_type is 'normal'.

    Args:
        input_file: Path (str or Path) to the input JSONL file.

    Returns:
        tuple: ``(normal_data_list, total_count, normal_count)`` — the
        matching records in file order, the number of non-empty lines
        seen, and how many of them matched. ``([], 0, 0)`` when the file
        is missing or a fatal read error occurs.
    """
    src = Path(input_file)

    # Missing input: report and return the empty result triple.
    if not src.exists():
        print(f"❌ 文件不存在: {src}")
        return [], 0, 0

    kept = []
    seen = 0
    matches = 0

    print(f"📖 正在读取文件: {src}")

    try:
        with open(src, 'r', encoding='utf-8') as fh:
            for lineno, raw in enumerate(fh, 1):
                raw = raw.strip()
                if not raw:
                    continue

                seen += 1

                # Progress heartbeat every 1000 non-empty lines.
                if seen % 1000 == 0:
                    print(f"📊 已处理 {seen} 行...")

                try:
                    record = json.loads(raw)
                except json.JSONDecodeError as exc:
                    # Bad line: report and keep scanning.
                    print(f"⚠️ 第 {lineno} 行JSON解析错误: {exc}")
                    continue

                if record.get('marked_type', '') == 'normal':
                    matches += 1
                    kept.append(record)

    except Exception as exc:
        # Fatal read error: report and discard any partial result.
        print(f"❌ 读取文件时出错: {exc}")
        return [], 0, 0

    return kept, seen, matches
61+
62+
def main():
    """Filter 'normal'-type records from a JSONL file and report statistics.

    Reads the input path from ``sys.argv`` (falling back to a default),
    prints counts and percentages, then interactively offers to save the
    filtered subset to ``filtered_normal_data.jsonl``.
    """
    # Default input used when no path is given on the command line.
    default_input = "data/WebMainBench_dataset_merge_2549_llm_webkit.jsonl"
    input_file = sys.argv[1] if len(sys.argv) > 1 else default_input

    bar = "=" * 60
    print(bar)
    print("🔍 过滤 marked_type 为 'normal' 的数据")
    print(bar)

    # Run the filter pass over the whole file.
    normal_data, total_count, normal_count = filter_normal_data(input_file)

    # Summary statistics.
    print("\n" + "=" * 60)
    print("📊 统计结果")
    print("=" * 60)
    print(f"📁 输入文件: {input_file}")
    print(f"📄 总数据条数: {total_count:,}")
    print(f"✅ normal类型数据: {normal_count:,}")

    if total_count > 0:
        share = (normal_count / total_count) * 100
        print(f"📈 normal类型占比: {share:.2f}%")

        # Breakdown of the remaining (non-normal) records.
        others = total_count - normal_count
        if others > 0:
            print(f"📊 其他类型数据: {others:,} ({(others / total_count) * 100:.2f}%)")

    # Interactive opt-in before writing the filtered subset to disk.
    if normal_count > 0:
        print(f"\n💾 是否保存过滤结果? 将保存到 filtered_normal_data.jsonl")
        answer = input("输入 'y' 保存,其他键跳过: ").strip().lower()

        if answer == 'y':
            output_file = "filtered_normal_data.jsonl"
            try:
                with open(output_file, 'w', encoding='utf-8') as fh:
                    for record in normal_data:
                        fh.write(json.dumps(record, ensure_ascii=False) + '\n')
                print(f"✅ 已保存 {normal_count} 条normal类型数据到: {output_file}")
            except Exception as exc:
                print(f"❌ 保存文件时出错: {exc}")

    print("\n🎉 处理完成!")

if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)