-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_dataset.py
More file actions
416 lines (330 loc) · 14.7 KB
/
generate_dataset.py
File metadata and controls
416 lines (330 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
#!/usr/bin/env python3
"""
HarmonyOS ETS代码补全数据集生成工具
该脚本用于从HarmonyOS项目中生成代码补全训练数据集,支持跨文件上下文。
基于项目的依赖关系图和拓扑排序结果,为每个文件生成包含上下文的代码补全样本。
作者: Assistant
创建时间: 2024
Python版本: 3.7+
"""
import os
import json
import argparse
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from pathlib import Path
import random
import re
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class CodeBlock:
"""
表示一个代码块,通常是一个文件的内容。
Attributes:
file_path (str): 文件路径
code_content (str): 文件的完整代码内容
language (str): 编程语言类型,默认为 'arkts'
"""
file_path: str
code_content: str
language: str
@dataclass
class CodeData:
"""
表示一个代码补全训练样本。
Attributes:
file_path (str): 目标文件路径
left_context (str): 目标代码左侧的上下文
right_context (str): 目标代码右侧的上下文
related_files (List[CodeBlock]): 相关文件列表,提供跨文件上下文
target_code (str): 需要补全的目标代码片段
language (str): 编程语言类型
task_id (str): 唯一任务标识符
"""
file_path: str
left_context: str
right_context: str
related_files: List[CodeBlock]
target_code: str
language: str
task_id: str
class CodeDataGenerator:
"""
HarmonyOS ETS代码补全数据集生成器。
该类负责从项目的元数据和拓扑排序结果中生成代码补全训练样本。
支持基于依赖关系的跨文件上下文生成。
"""
def __init__(self, metadata_path: str, topo_sort_path: str):
"""
初始化数据集生成器。
Args:
metadata_path (str): 项目元数据文件路径,包含文件依赖关系
topo_sort_path (str): 拓扑排序结果文件路径,包含文件分层信息
Raises:
FileNotFoundError: 当指定的文件不存在时
json.JSONDecodeError: 当JSON文件格式错误时
"""
self.metadata_path = Path(metadata_path)
self.topo_sort_path = Path(topo_sort_path)
# 验证文件存在性
if not self.metadata_path.exists():
raise FileNotFoundError(f"元数据文件不存在: {metadata_path}")
if not self.topo_sort_path.exists():
raise FileNotFoundError(f"拓扑排序文件不存在: {topo_sort_path}")
# 加载配置文件
logger.info(f"加载元数据文件: {metadata_path}")
with open(self.metadata_path, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
logger.info(f"加载拓扑排序文件: {topo_sort_path}")
with open(self.topo_sort_path, 'r', encoding='utf-8') as f:
self.topo_sort = json.load(f)
# 初始化数据存储
self.data: Dict[str, CodeData] = {}
self.task_id_counter = 1 # 简单的递增计数器
logger.info("CodeDataGenerator 初始化完成")
def generate_dependency(self, related_files: List[str]) -> List[CodeBlock]:
"""
根据相关文件路径列表生成CodeBlock对象列表。
Args:
related_files (List[str]): 相关文件路径列表
Returns:
List[CodeBlock]: CodeBlock对象列表,包含文件内容
Note:
如果文件读取失败,会记录警告并跳过该文件
"""
dependency = []
if not related_files:
return []
for file_path in related_files:
try:
with open(file_path, "r", encoding='utf-8') as f:
code_content = f.read()
dependency.append(CodeBlock(file_path, code_content, "arkts"))
logger.debug(f"成功读取依赖文件: {file_path}")
except (FileNotFoundError, UnicodeDecodeError) as e:
logger.warning(f"读取文件失败 {file_path}: {e}")
continue
logger.info(f"生成 {len(dependency)} 个依赖文件")
return dependency
def generate_code_data(self, file_path: str, related_files: List[CodeBlock]) -> Optional[CodeData]:
"""
为指定文件生成代码补全训练样本。
该方法会随机选择文件中的一段代码作为补全目标,并生成相应的上下文。
Args:
file_path (str): 目标文件路径
related_files (List[CodeBlock]): 相关文件列表,提供跨文件上下文
Returns:
Optional[CodeData]: 生成的代码补全样本,如果生成失败则返回None
Note:
- 分割位置在文件的10%-80%之间随机选择
- 目标代码长度在16-96个token之间
- 如果文件太短或读取失败,会返回None
"""
try:
with open(file_path, "r", encoding='utf-8') as f:
code_content = f.read()
# 使用正则表达式切分token,但上下文/目标代码用原始源码切片,避免 ' '.join(token) 引入额外空格
token_matches = list(re.finditer(r'\w+|[^\w\s]', code_content))
# 验证文件长度
min_tokens = 50 # 最少需要50个token
if len(token_matches) < min_tokens:
logger.warning(f"文件太短,跳过: {file_path} (仅有 {len(token_matches)} 个token)")
return None
# 随机选择分割位置 (10%-80%之间)
split_ratio = random.uniform(0.1, 0.8)
end_token_index = int(len(token_matches) * split_ratio)
# 生成左上下文
left_end_char = token_matches[end_token_index].start() if end_token_index < len(token_matches) else len(code_content)
left_context = code_content[:left_end_char]
# 随机生成目标代码长度
max_target_length = len(token_matches) - end_token_index - 10 # 保留一些右上下文
# 确保有足够的 token 可以生成 16-96 长度的目标代码
if max_target_length < 16:
logger.warning(f"剩余token不足,无法生成至少16个token的目标代码: {file_path}")
return None
target_length = random.randint(16, min(96, max_target_length))
if target_length <= 0:
logger.warning(f"无法生成有效目标代码: {file_path}")
return None
# 生成目标代码和右上下文
target_start_char = token_matches[end_token_index].start()
target_end_token_index = end_token_index + target_length - 1
target_end_char = token_matches[target_end_token_index].end()
target = code_content[target_start_char:target_end_char]
right_context = code_content[target_end_char:]
# 生成唯一任务ID
task_id = str(self.task_id_counter)
self.task_id_counter += 1
logger.debug(f"为文件 {file_path} 生成代码样本,任务ID: {task_id}")
return CodeData(
file_path=file_path,
left_context=left_context,
right_context=right_context,
related_files=related_files,
target_code=target,
language="arkts",
task_id=task_id
)
except (FileNotFoundError, UnicodeDecodeError) as e:
logger.error(f"生成代码数据失败 {file_path}: {e}")
return None
def save_data(self, output_path: str) -> None:
"""
将生成的数据集保存为JSON文件。
Args:
output_path (str): 输出文件路径
Note:
- 数据会被转换为字典格式以支持JSON序列化
- 使用UTF-8编码和格式化输出
- 如果保存失败会记录错误日志
"""
try:
# 将CodeData对象转换为字典以便JSON序列化
serializable_data = {}
for file_path, code_data in self.data.items():
serializable_data[file_path] = asdict(code_data)
# 确保输出目录存在
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# 保存数据
with open(output_file, "w", encoding='utf-8') as f:
json.dump(serializable_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据集已保存到: {output_path}")
logger.info(f"总样本数: {len(serializable_data)}")
except Exception as e:
logger.error(f"保存数据集失败: {e}")
raise
def generate_data(self, output_path: str, max_layers: Optional[int] = None) -> None:
"""
生成代码补全数据集。
根据拓扑排序的层级顺序,为每个文件生成代码补全样本。
支持跨文件上下文,基于文件的依赖关系提供相关文件信息。
Args:
output_path (str): 输出文件路径
max_layers (Optional[int]): 最大处理层数,None表示处理所有层
Note:
- 按拓扑排序的层级顺序处理文件
- 跳过无法读取或太短的文件
- 自动保存结果到指定路径
"""
logger.info("开始生成代码补全数据集")
# 获取依赖关系和分层信息
dependencies = self.metadata.get("dependencies", {})
layers = self.topo_sort.get("layers", {})
if not layers:
logger.error("未找到拓扑排序的层级信息")
return
# 统计信息
total_files = 0
processed_files = 0
failed_files = 0
# 按层级处理文件
layer_names = sorted(layers.keys(), key=lambda x: int(x.split('_')[1]))
if max_layers:
layer_names = layer_names[:max_layers]
for layer_name in layer_names:
layer_files = layers[layer_name]
total_files += len(layer_files)
logger.info(f"处理 {layer_name},包含 {len(layer_files)} 个文件")
for file_path in layer_files:
try:
# 获取相关文件(依赖)
related_file_paths = dependencies.get(file_path, [])
# 生成依赖文件的CodeBlock对象
code_dependencies = self.generate_dependency(related_file_paths)
# 生成代码补全样本
code_data = self.generate_code_data(file_path, code_dependencies)
if code_data:
self.data[file_path] = code_data
processed_files += 1
logger.debug(f"成功处理文件: {file_path}")
else:
failed_files += 1
logger.warning(f"跳过文件: {file_path}")
except Exception as e:
failed_files += 1
logger.error(f"处理文件失败 {file_path}: {e}")
# 输出统计信息
logger.info(f"数据生成完成:")
logger.info(f" 总文件数: {total_files}")
logger.info(f" 成功处理: {processed_files}")
logger.info(f" 失败跳过: {failed_files}")
logger.info(f" 成功率: {processed_files/total_files*100:.1f}%" if total_files > 0 else " 成功率: 0%")
# 保存数据集
if self.data:
self.save_data(output_path)
else:
logger.warning("没有生成任何数据,跳过保存")
def main():
"""
主函数:解析命令行参数并生成代码补全数据集。
"""
parser = argparse.ArgumentParser(
description="HarmonyOS ETS代码补全数据集生成工具",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法:
# 使用默认参数
python generate_dataset.py
# 指定自定义路径
python generate_dataset.py --metadata_path ./metadata.json --output_path ./dataset.json
# 只处理前3层文件
python generate_dataset.py --max_layers 3
# 开启详细日志
python generate_dataset.py --verbose
"""
)
parser.add_argument(
"--metadata_path",
type=str,
default="intermedia_result/FinanceTemplate_metadata.json",
help="项目元数据文件路径 (默认: metadata.json)"
)
parser.add_argument(
"--topo_sort_path",
type=str,
default="intermedia_result/FinanceTemplate_topological_sort.json",
help="拓扑排序结果文件路径 (默认: topological_sort_result.json)"
)
parser.add_argument(
"--output_path",
type=str,
default="data/code_data.json",
help="输出数据集文件路径 (默认: code_data.json)"
)
parser.add_argument(
"--max_layers",
type=int,
default=None,
help="最大处理层数,用于限制处理范围 (默认: 处理所有层)"
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="开启详细日志输出"
)
args = parser.parse_args()
# 设置日志级别
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
logger.info("开启详细日志模式")
try:
# 创建数据集生成器
logger.info("初始化代码数据集生成器")
generator = CodeDataGenerator(args.metadata_path, args.topo_sort_path)
# 生成数据集
generator.generate_data(args.output_path, args.max_layers)
logger.info("程序执行完成")
except Exception as e:
logger.error(f"程序执行失败: {e}")
return 1
return 0
if __name__ == "__main__":
exit(main())