-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_api_dataset.py
More file actions
765 lines (621 loc) · 28.1 KB
/
generate_api_dataset.py
File metadata and controls
765 lines (621 loc) · 28.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
#!/usr/bin/env python3
"""
HarmonyOS ETS代码补全数据集生成工具 - API版本
该脚本专门用于生成以API调用为目标的代码补全数据集。
重点关注依赖其他文件的API调用(如 a.b() 形式),并确保挖空长度在8-16个token之间。
作者: Assistant
创建时间: 2024
Python版本: 3.7+
"""
import os
import json
import argparse
import logging
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from pathlib import Path
import random
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class CodeBlock:
"""
表示一个代码块,通常是一个文件的内容。
Attributes:
file_path (str): 文件路径
code_content (str): 文件的完整代码内容
language (str): 编程语言类型,默认为 'arkts'
"""
file_path: str
code_content: str
language: str
@dataclass
class CodeData:
"""
表示一个代码补全训练样本。
Attributes:
file_path (str): 目标文件路径
left_context (str): 目标代码左侧的上下文
right_context (str): 目标代码右侧的上下文
related_files (List[CodeBlock]): 相关文件列表,提供跨文件上下文
target_code (str): 需要补全的目标代码片段(API调用)
language (str): 编程语言类型
task_id (str): 唯一任务标识符
api_info (Dict): API调用的详细信息
"""
file_path: str
left_context: str
right_context: str
related_files: List[CodeBlock]
target_code: str
language: str
task_id: str
api_info: Dict[str, Any]
class APICodeDataGenerator:
"""
HarmonyOS ETS代码补全数据集生成器 - API版本。
专门用于生成以API调用为目标的代码补全样本。
重点关注依赖其他文件的API调用,确保挖空长度在8-16个token之间。
"""
def __init__(self, metadata_path: str, topo_sort_path: str,
min_hole_length: int = 8, max_hole_length: int = 16,
pure_api: bool = True):
"""
初始化数据集生成器。
Args:
metadata_path (str): 项目元数据文件路径,包含文件依赖关系
topo_sort_path (str): 拓扑排序结果文件路径,包含文件分层信息
min_hole_length (int): 挖空最小长度(token数)
max_hole_length (int): 挖空最大长度(token数)
pure_api (bool): 是否使用纯API补全模式
- True: 只挖空API调用本身
- False: API调用 + 随机前后上下文,总长度在min_hole_length到max_hole_length之间
Raises:
FileNotFoundError: 当指定的文件不存在时
json.JSONDecodeError: 当JSON文件格式错误时
"""
self.metadata_path = Path(metadata_path)
self.topo_sort_path = Path(topo_sort_path)
self.min_hole_length = min_hole_length
self.max_hole_length = max_hole_length
self.pure_api = pure_api
# 验证文件存在性
if not self.metadata_path.exists():
raise FileNotFoundError(f"元数据文件不存在: {metadata_path}")
if not self.topo_sort_path.exists():
raise FileNotFoundError(f"拓扑排序文件不存在: {topo_sort_path}")
# 加载配置文件
logger.info(f"加载元数据文件: {metadata_path}")
with open(self.metadata_path, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
logger.info(f"加载拓扑排序文件: {topo_sort_path}")
with open(self.topo_sort_path, 'r', encoding='utf-8') as f:
self.topo_sort = json.load(f)
# 初始化数据存储
self.data: Dict[str, CodeData] = {}
self.task_id_counter = 1
mode_desc = "纯API补全" if pure_api else "API+上下文扩展"
logger.info(f"APICodeDataGenerator 初始化完成 (模式: {mode_desc}, 挖空长度: {min_hole_length}-{max_hole_length} tokens)")
def tokenize_code(self, code: str) -> List[str]:
"""
将代码切分为token列表。
Args:
code (str): 代码字符串
Returns:
List[str]: token列表
"""
# 使用正则表达式按照代码语义单元切分
tokens = re.findall(r'\w+|[^\w\s]', code)
return tokens
def tokenize_code_with_spans(self, code: str) -> List[tuple]:
"""
将代码切分为token及其在字符串中的(start, end)位置。
Returns:
List[tuple]: [(token, start, end), ...]
"""
return [(m.group(0), m.start(), m.end()) for m in re.finditer(r'\w+|[^\w\s]', code)]
def count_tokens(self, code: str) -> int:
"""
计算代码的token数量。
Args:
code (str): 代码字符串
Returns:
int: token数量
"""
return len(self.tokenize_code(code))
def find_api_calls(self, code_content: str) -> List[Dict[str, Any]]:
"""
查找代码中的API调用,特别是形如 a.b() 或 a.b 的调用模式。
Args:
code_content (str): 完整的代码内容
Returns:
List[Dict]: API调用信息列表,每个元素包含位置、内容等信息
"""
api_calls = []
# 模式1: 匹配 object.method() 或 object.property.method() 等多级调用
# 例如: console.log(), this.state.update(), Utils.formatDate()
pattern1 = r'\b([a-zA-Z_$][\w$]*(?:\.[a-zA-Z_$][\w$]*)+)\s*\('
# 模式2: 匹配 object.property (属性访问,不带括号)
# 例如: this.props, user.name
pattern2 = r'\b([a-zA-Z_$][\w$]*(?:\.[a-zA-Z_$][\w$]*)+)\b(?!\s*\()'
# 查找方法调用(带括号)
for match in re.finditer(pattern1, code_content):
start_pos = match.start()
base_api = match.group(1)
# 查找完整的方法调用,包括参数
# 找到匹配的右括号
paren_count = 0
i = match.end()
while i < len(code_content):
if code_content[i] == '(':
paren_count += 1
elif code_content[i] == ')':
if paren_count == 0:
end_pos = i + 1
break
paren_count -= 1
i += 1
else:
# 没找到匹配的右括号,跳过
continue
api_call = code_content[start_pos:end_pos]
token_count = self.count_tokens(api_call)
api_calls.append({
'api_call': api_call,
'start_char': start_pos,
'end_char': end_pos,
'token_count': token_count,
'has_import': self._check_if_imported(base_api, code_content),
'type': 'method_call'
})
# 查找属性访问(不带括号)
for match in re.finditer(pattern2, code_content):
start_pos = match.start()
end_pos = match.end()
api_call = match.group(1)
token_count = self.count_tokens(api_call)
api_calls.append({
'api_call': api_call,
'start_char': start_pos,
'end_char': end_pos,
'token_count': token_count,
'has_import': self._check_if_imported(api_call, code_content),
'type': 'property_access'
})
# 按照位置排序
api_calls.sort(key=lambda x: x['start_char'])
return api_calls
def _check_if_imported(self, api_call: str, code_content: str) -> bool:
"""
检查API调用的对象是否来自import语句。
Args:
api_call (str): API调用字符串,如 "Utils.formatDate()"
code_content (str): 完整代码内容
Returns:
bool: 如果对象来自导入则返回True
"""
# 提取API调用的第一个标识符(对象名)
first_identifier = api_call.split('.')[0]
# 检查是否在import语句中
# 匹配 import { xxx } from 'xxx' 或 import xxx from 'xxx'
import_patterns = [
rf'import\s+{{[^}}]*\b{first_identifier}\b[^}}]*}}\s+from',
rf'import\s+{first_identifier}\s+from',
rf'import\s+\*\s+as\s+{first_identifier}\s+from'
]
for pattern in import_patterns:
if re.search(pattern, code_content):
return True
return False
def extend_to_valid_hole_length(self, code_content: str, start_pos: int, end_pos: int) -> tuple:
"""
扩展API调用的范围,使其长度在指定范围内。
两种模式:
1. 纯API模式 (pure_api=True): 只返回API调用本身
2. 扩展模式 (pure_api=False): API调用 + 随机前后上下文
- 随机选择需要补充的token数量
- 随机决定在前面补还是后面补
- 总长度在 min_hole_length 到 max_hole_length 之间
Args:
code_content (str): 完整代码内容
start_pos (int): API调用起始位置
end_pos (int): API调用结束位置
Returns:
tuple: (新的start_pos, 新的end_pos, 扩展后的代码)
"""
# 获取API调用的实际token数
api_code = code_content[start_pos:end_pos]
api_token_count = self.count_tokens(api_code)
if self.pure_api:
# 纯API模式:不需要额外的上下文token
extra_tokens_needed = 0
else:
# 扩展模式:计算需要补充的token数量
# 目标总长度在 min_hole_length 到 max_hole_length 之间
target_length = random.randint(self.min_hole_length, self.max_hole_length)
extra_tokens_needed = max(0, target_length - api_token_count)
# 随机决定补充策略:0=前面补,1=后面补,2=前后都补
supplement_strategy = random.randint(0, 2)
if supplement_strategy == 0:
# 只在前面补充
tokens_before = extra_tokens_needed
tokens_after = 0
elif supplement_strategy == 1:
# 只在后面补充
tokens_before = 0
tokens_after = extra_tokens_needed
else:
# 前后都补,随机分配
tokens_before = random.randint(0, extra_tokens_needed)
tokens_after = extra_tokens_needed - tokens_before
# 获取前面的上下文tokens
left_context = code_content[:start_pos]
left_tokens = self.tokenize_code(left_context)
# 获取后面的上下文tokens
right_context = code_content[end_pos:]
right_tokens = self.tokenize_code(right_context)
# 检查是否有足够的上下文
if len(left_tokens) < tokens_before or len(right_tokens) < tokens_after:
# 如果上下文不够,尝试调整策略
available_before = min(len(left_tokens), tokens_before)
available_after = min(len(right_tokens), tokens_after)
# 如果总的可用token不够,返回None
if available_before + available_after < extra_tokens_needed // 2:
return None, None, None
tokens_before = available_before
tokens_after = available_after
# 计算新的起始和结束位置
# 从左边context取需要的tokens
if tokens_before > 0:
left_supplement_tokens = left_tokens[-tokens_before:]
left_supplement_text = ' '.join(left_supplement_tokens)
# 在原始代码中找到这些tokens的起始位置
# 简化处理:从start_pos向前回溯
new_start = start_pos
token_count = 0
while new_start > 0 and token_count < tokens_before:
new_start -= 1
temp_text = code_content[new_start:start_pos]
token_count = len(self.tokenize_code(temp_text))
else:
new_start = start_pos
# 从右边context取需要的tokens
if tokens_after > 0:
right_supplement_tokens = right_tokens[:tokens_after]
right_supplement_text = ' '.join(right_supplement_tokens)
# 在原始代码中找到这些tokens的结束位置
# 简化处理:从end_pos向后扩展
new_end = end_pos
token_count = 0
while new_end < len(code_content) and token_count < tokens_after:
new_end += 1
temp_text = code_content[end_pos:new_end]
token_count = len(self.tokenize_code(temp_text))
else:
new_end = end_pos
# 获取最终的代码片段
final_code = code_content[new_start:new_end]
final_tokens = self.count_tokens(final_code)
# 验证最终长度
if self.pure_api:
# 纯API模式:只要API调用有效即可,不强制长度限制
if final_tokens == 0:
return None, None, None
else:
# 扩展模式:验证长度是否在范围内
if final_tokens < self.min_hole_length or final_tokens > self.max_hole_length:
# 如果不在范围内,尝试微调
if final_tokens > self.max_hole_length:
# 太长了,截断
spans = self.tokenize_code_with_spans(final_code)
if not spans:
return None, None, None
truncate_end = spans[self.max_hole_length - 1][2] if len(spans) >= self.max_hole_length else spans[-1][2]
final_code = final_code[:truncate_end]
return new_start, new_start + truncate_end, final_code
else:
# 太短了,返回None
return None, None, None
return new_start, new_end, final_code
def generate_dependency(self, related_files: List[str]) -> List[CodeBlock]:
"""
根据相关文件路径列表生成CodeBlock对象列表。
Args:
related_files (List[str]): 相关文件路径列表
Returns:
List[CodeBlock]: CodeBlock对象列表,包含文件内容
"""
dependency = []
if not related_files:
return []
for file_path in related_files:
try:
with open(file_path, "r", encoding='utf-8') as f:
code_content = f.read()
dependency.append(CodeBlock(file_path, code_content, "arkts"))
logger.debug(f"成功读取依赖文件: {file_path}")
except (FileNotFoundError, UnicodeDecodeError) as e:
logger.warning(f"读取文件失败 {file_path}: {e}")
continue
logger.info(f"生成 {len(dependency)} 个依赖文件")
return dependency
def generate_code_data(self, file_path: str, related_files: List[CodeBlock]) -> Optional[CodeData]:
"""
为指定文件生成代码补全训练样本。
该方法会查找文件中的API调用(特别是依赖其他文件的API),并将其作为补全目标。
确保挖空长度在8-16个token之间。
Args:
file_path (str): 目标文件路径
related_files (List[CodeBlock]): 相关文件列表,提供跨文件上下文
Returns:
Optional[CodeData]: 生成的代码补全样本,如果生成失败则返回None
"""
try:
with open(file_path, "r", encoding='utf-8') as f:
code_content = f.read()
# 验证文件长度
total_tokens = self.count_tokens(code_content)
min_file_tokens = 50 # 最少需要50个token
if total_tokens < min_file_tokens:
logger.warning(f"文件太短,跳过: {file_path} (仅有 {total_tokens} 个token)")
return None
# 查找所有API调用
api_calls = self.find_api_calls(code_content)
if not api_calls:
logger.warning(f"未找到API调用,跳过: {file_path}")
return None
# 优先选择来自import的API调用
imported_apis = [api for api in api_calls if api['has_import']]
if imported_apis:
candidates = imported_apis
logger.debug(f"找到 {len(imported_apis)} 个导入的API")
else:
candidates = api_calls
logger.debug(f"找到 {len(api_calls)} 个普通API")
# 随机尝试多个候选,直到找到合适长度的
random.shuffle(candidates)
for selected_api in candidates:
original_start = selected_api['start_char']
original_end = selected_api['end_char']
# 扩展或截断到合适的长度
new_start, new_end, target_code = self.extend_to_valid_hole_length(
code_content, original_start, original_end
)
if new_start is None:
continue
# 验证长度
target_tokens = self.count_tokens(target_code)
if not self.pure_api:
# 扩展模式下验证长度范围
if target_tokens < self.min_hole_length or target_tokens > self.max_hole_length:
continue
# 生成左右上下文
left_context = code_content[:new_start].strip()
right_context = code_content[new_end:].strip()
# 验证上下文不为空
if not left_context or not right_context:
continue
# 生成唯一任务ID
task_id = str(self.task_id_counter)
self.task_id_counter += 1
api_info = {
'original_api': selected_api['api_call'],
'type': selected_api['type'],
'has_import': selected_api['has_import'],
'token_count': target_tokens
}
logger.debug(f"为文件 {file_path} 生成代码样本")
logger.debug(f" 任务ID: {task_id}")
logger.debug(f" 目标代码: {target_code[:50]}...")
logger.debug(f" Token数: {target_tokens}")
logger.debug(f" 是否导入: {selected_api['has_import']}")
return CodeData(
file_path=file_path,
left_context=left_context,
right_context=right_context,
related_files=related_files,
target_code=target_code,
language="arkts",
task_id=task_id,
api_info=api_info
)
logger.warning(f"未找到合适长度的API调用,跳过: {file_path}")
return None
except (FileNotFoundError, UnicodeDecodeError) as e:
logger.error(f"生成代码数据失败 {file_path}: {e}")
return None
def save_data(self, output_path: str) -> None:
"""
将生成的数据集保存为JSON文件。
Args:
output_path (str): 输出文件路径
"""
try:
# 将CodeData对象转换为字典以便JSON序列化
serializable_data = {}
for file_path, code_data in self.data.items():
serializable_data[file_path] = asdict(code_data)
# 确保输出目录存在
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# 保存数据
with open(output_file, "w", encoding='utf-8') as f:
json.dump(serializable_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据集已保存到: {output_path}")
logger.info(f"总样本数: {len(serializable_data)}")
# 统计信息
if serializable_data:
total_tokens = sum(data['api_info']['token_count'] for data in serializable_data.values())
avg_tokens = total_tokens / len(serializable_data)
imported_count = sum(1 for data in serializable_data.values() if data['api_info']['has_import'])
logger.info(f"平均挖空长度: {avg_tokens:.1f} tokens")
logger.info(f"导入API比例: {imported_count}/{len(serializable_data)} ({imported_count/len(serializable_data)*100:.1f}%)")
except Exception as e:
logger.error(f"保存数据集失败: {e}")
raise
def generate_data(self, max_layers: Optional[int] = None, output_path: str | None = None) -> None:
"""
生成代码补全数据集。
Args:
max_layers (Optional[int]): 最大处理层数,None表示处理所有层
"""
logger.info("开始生成API代码补全数据集")
# 获取依赖关系和分层信息
dependencies = self.metadata.get("dependencies", {})
layers = self.topo_sort.get("layers", {})
if not layers:
logger.error("未找到拓扑排序的层级信息")
return
# 统计信息
total_files = 0
processed_files = 0
failed_files = 0
# 按层级处理文件
layer_names = sorted(layers.keys(), key=lambda x: int(x.split('_')[1]))
if max_layers:
layer_names = layer_names[:max_layers]
for layer_name in layer_names:
layer_files = layers[layer_name]
total_files += len(layer_files)
logger.info(f"处理 {layer_name},包含 {len(layer_files)} 个文件")
for file_path in layer_files:
try:
# 获取相关文件(依赖)
related_file_paths = dependencies.get(file_path, [])
# 生成依赖文件的CodeBlock对象
code_dependencies = self.generate_dependency(related_file_paths)
# 生成代码补全样本
code_data = self.generate_code_data(file_path, code_dependencies)
if code_data:
self.data[file_path] = code_data
processed_files += 1
logger.debug(f"成功处理文件: {file_path}")
else:
failed_files += 1
logger.warning(f"跳过文件: {file_path}")
except Exception as e:
failed_files += 1
logger.error(f"处理文件失败 {file_path}: {e}")
# 输出统计信息
logger.info(f"数据生成完成:")
logger.info(f" 总文件数: {total_files}")
logger.info(f" 成功处理: {processed_files}")
logger.info(f" 失败跳过: {failed_files}")
logger.info(f" 成功率: {processed_files/total_files*100:.1f}%" if total_files > 0 else " 成功率: 0%")
if output_path is None:
output_path = f"./pure_api_data_2/{str(self.metadata_path).split('/')[-1].split('_')[0]}_api_code_data.json"
# 保存数据集
if self.data:
self.save_data(output_path)
else:
logger.warning("没有生成任何数据,跳过保存")
def main():
"""
主函数:解析命令行参数并生成代码补全数据集。
"""
parser = argparse.ArgumentParser(
description="HarmonyOS ETS代码补全数据集生成工具 - API版本",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法:
# 使用默认参数(纯API补全模式)
python generate_api_dataset.py
# 使用扩展上下文模式(API + 随机前后上下文)
python generate_api_dataset.py --with_context
# 指定自定义路径
python generate_api_dataset.py --metadata_path ./metadata.json
# 自定义挖空长度范围(扩展模式下生效)
python generate_api_dataset.py --with_context --min_hole_length 10 --max_hole_length 20
# 只处理前3层文件
python generate_api_dataset.py --max_layers 3
# 开启详细日志
python generate_api_dataset.py --verbose
模式说明:
--pure_api (默认): 只挖空API调用本身,如 Utils.formatDate()
--with_context: API调用 + 随机前后上下文,总长度在 min_hole_length 到 max_hole_length 之间
"""
)
parser.add_argument(
"--metadata_path",
type=str,
default="intermedia_result/FinanceTemplate_metadata.json",
help="项目元数据文件路径 (默认: intermedia_result/FinanceTemplate_metadata.json)"
)
parser.add_argument(
"--topo_sort_path",
type=str,
default="intermedia_result/FinanceTemplate_topological_sort.json",
help="拓扑排序结果文件路径 (默认: intermedia_result/FinanceTemplate_topological_sort.json)"
)
parser.add_argument(
"--min_hole_length",
type=int,
default=8,
help="挖空最小长度(token数,默认: 8)"
)
parser.add_argument(
"--max_hole_length",
type=int,
default=16,
help="挖空最大长度(token数,默认: 16)"
)
parser.add_argument(
"--max_layers",
type=int,
default=None,
help="最大处理层数,用于限制处理范围 (默认: 处理所有层)"
)
parser.add_argument(
"--output_path",
type=str,
default=None,
help="输出数据集文件路径 (默认: ./pure_api_data_2/<name>_api_code_data.json)",
)
parser.add_argument(
"--pure_api",
action="store_true",
default=True,
help="使用纯API补全模式,只挖空API调用本身 (默认: True)"
)
parser.add_argument(
"--with_context",
action="store_true",
help="使用扩展上下文模式,API调用 + 随机前后上下文 (与 --pure_api 互斥)"
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="开启详细日志输出"
)
args = parser.parse_args()
# 设置日志级别
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
logger.info("开启详细日志模式")
# 处理模式选择:--with_context 会覆盖 --pure_api
pure_api = not args.with_context
try:
# 创建数据集生成器
logger.info("初始化API代码数据集生成器")
generator = APICodeDataGenerator(
args.metadata_path,
args.topo_sort_path,
args.min_hole_length,
args.max_hole_length,
pure_api=pure_api
)
# 生成数据集
generator.generate_data(args.max_layers, args.output_path)
logger.info("程序执行完成")
except Exception as e:
logger.error(f"程序执行失败: {e}")
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())