diff --git a/docs/guides/model_convert/convert_from_pytorch/tools/.gitignore b/docs/guides/model_convert/convert_from_pytorch/tools/.gitignore index 4204846c164..cb85c8cbd72 100644 --- a/docs/guides/model_convert/convert_from_pytorch/tools/.gitignore +++ b/docs/guides/model_convert/convert_from_pytorch/tools/.gitignore @@ -26,3 +26,4 @@ unique_warnings.txt link_warnings.txt url_warnings.txt api_difference_error.txt +diff_doc_warnings.txt diff --git a/docs/guides/model_convert/convert_from_pytorch/tools/generate_pytorch_api_mapping.py b/docs/guides/model_convert/convert_from_pytorch/tools/generate_pytorch_api_mapping.py index c4af7f44da1..8d0a2088ba2 100644 --- a/docs/guides/model_convert/convert_from_pytorch/tools/generate_pytorch_api_mapping.py +++ b/docs/guides/model_convert/convert_from_pytorch/tools/generate_pytorch_api_mapping.py @@ -167,20 +167,47 @@ def generate_category2_table( rows = [] # 存储表格行数据的列表 used_apis = set() # 用于记录已处理的API,避免重复 + invok_diff_matchers = { + "ChangeAPIMatcher", + "NumelMatcher", + "Is_InferenceMatcher", + } + + special_matchers = { + "TensorFunc2PaddleFunc", + "Func2Attribute", + "Attribute2Func", + } + # 处理api_mapping中Matcher为"UnchangeMatcher"且不在no_need_convert_list中的API for src_api, mapping_info in api_mapping_data.items(): if src_api in whitelist_skip or src_api in no_need_convert_list: continue matcher = mapping_info.get("Matcher", "") + valid = False # ChangeAPIMatcher、TensorFunc2PaddleFunc、Func2Attribute、Attribute2Func类别 - if matcher in [ - "ChangeAPIMatcher", - "TensorFunc2PaddleFunc", - "Func2Attribute", - "Attribute2Func", - "NumelMatcher", - "Is_InferenceMatcher", - ]: + if matcher in special_matchers: + has_unsupport_args = "unsupport_args" in mapping_info + has_kwargs_change = "kwargs_change" in mapping_info + has_paddle_default_kwargs = "paddle_default_kwargs" in mapping_info + if has_unsupport_args: + print( + f"[torch_more_args] {src_api} -> {mapping_info.get('paddle_api', 'N/A')}" + ) + continue + elif has_kwargs_change: + print( + f"[args_name_diff] {src_api} -> {mapping_info.get('paddle_api', 'N/A')}" + ) + continue + elif has_paddle_default_kwargs: + print( + f"[paddle_more_args_or_default_diff] {src_api} -> {mapping_info.get('paddle_api', 'N/A')}" + ) + continue + valid = True + + if matcher in invok_diff_matchers or valid: # 在docs_mapping中查找当前src_api对应的信息 docs_mapping_info = docs_mapping.get(src_api, {}) src_url = docs_mapping_info.get("src_api_url") diff --git a/docs/guides/model_convert/convert_from_pytorch/tools/validate_api_difference.py b/docs/guides/model_convert/convert_from_pytorch/tools/validate_api_difference_format.py similarity index 100% rename from docs/guides/model_convert/convert_from_pytorch/tools/validate_api_difference.py rename to docs/guides/model_convert/convert_from_pytorch/tools/validate_api_difference_format.py diff --git a/docs/guides/model_convert/convert_from_pytorch/tools/validate_pytorch_api_mapping.py b/docs/guides/model_convert/convert_from_pytorch/tools/validate_pytorch_api_mapping.py index 0066d839abe..551ff78582e 100644 --- a/docs/guides/model_convert/convert_from_pytorch/tools/validate_pytorch_api_mapping.py +++ b/docs/guides/model_convert/convert_from_pytorch/tools/validate_pytorch_api_mapping.py @@ -15,6 +15,22 @@ # 默认文件路径 DEFAULT_FILE_PATH = "/workspace/paddleDocs/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md" +# 类别映射关系 +CATEGORY_MAP = { + "invok_diff_only": "仅 API 调用方式不一致", + "args_name_diff": "仅参数名不一致", + "paddle_more_args": "paddle 参数更多", + "args_default_value_diff": "参数默认值不一致", + "torch_more_args": "torch 参数更多", + "input_args_usage_diff": "输入参数用法不一致", + "input_args_type_diff": "输入参数类型不一致", + "output_args_type_diff": "返回参数类型不一致", + "composite_implement": "组合替代实现", +} + +# 反向映射(中文到英文) +REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()} + USER_AGENT = "" # 重试策略配置 @@ -64,7 +80,7 @@ def parse_toc(lines): continue # 跳过分隔行 if found_header and line.startswith("|"): - # 解析数据行 + # 解析数据行 - 现在表格有5列,但我们只关心前2列(序号和类别名称) columns = [col.strip() for col in line.split("|")[1:-1]] if len(columns) >= 2: toc.append((columns[0], columns[1])) # (序号, 类别名称) @@ -79,7 +95,7 @@ def parse_toc(lines): def parse_categories(lines): """ - 解析类别部分,提取所有类别及其表格数据 + 解析类别部分,提取所有类别及其表格数据(适应五列表格结构) """ categories = [] current_category = None @@ -111,8 +127,13 @@ def parse_categories(lines): in_table = False continue - # 检测表格开始 - if line.startswith("|") and "Pytorch" in line and "Paddle" in line: + # 检测表格开始(五列表格) + if ( + line.startswith("|") + and "Pytorch" in line + and "Paddle" in line + and "映射分类" in line + ): in_table = True continue @@ -122,15 +143,16 @@ def parse_categories(lines): if re.match(r"^\|?[-:\s|]+\|?$", line): continue - # 解析表格行 + # 解析五列表格行 columns = [col.strip() for col in line.split("|")[1:-1]] - if len(columns) >= 4: + if len(columns) >= 5: # 现在有5列 current_table.append( { "index": columns[0], "pytorch": columns[1], "paddle": columns[2], - "note": columns[3] if len(columns) > 3 else "", + "mapping_category": columns[3], # 新增的映射分类列 + "note": columns[4] if len(columns) > 4 else "", } ) continue @@ -159,6 +181,26 @@ def extract_links(text): return re.findall(r"\[([^\]]+)\]\(([^)]+)\)", text) +def extract_torch_api_name(pytorch_column): + """ + 从Pytorch列提取Torch API名称(链接文本),并将转义的下划线恢复为普通下划线 + """ + links = extract_links(pytorch_column) + if links: + api_name = links[0][0] # 返回第一个链接的文本 + # 将转义的下划线 "\_" 替换为普通下划线 "_" + api_name = api_name.replace(r"\_", "_") + return api_name + + # 如果没有链接,尝试直接提取文本内容 + clean_text = re.sub(r"[\[\]\(\)]", "", pytorch_column).strip() + if clean_text: + # 同样处理转义的下划线 + clean_text = clean_text.replace(r"\_", "_") + return clean_text + return None + + def check_toc_consistency(toc, categories): """ 检查目录与类别标题的一致性 @@ -198,9 +240,8 @@ def check_unique_torch_apis(categories): for category in categories: for row in category["table"]: - links = extract_links(row["pytorch"]) - if links: - api_name = links[0][0] # 取第一个链接的文本作为 API 名称 + api_name = extract_torch_api_name(row["pytorch"]) + if api_name: torch_apis[api_name].append(category["id"]) # 检查重复的 API @@ -215,13 +256,14 @@ def check_unique_torch_apis(categories): def check_links_exist(categories): """ - 检查必要的超链接是否存在(根据新规则) + 检查必要的超链接是否存在(适应五列表格结构) 规则: 1. 第二列(Pytorch)必须有超链接 2. 第三列(Paddle): - 对于"组合替代实现"、"可删除"、"功能缺失"类别,如果内容为空或"-"则不检查 - 否则必须有超链接 - 3. 第四列(备注): 除了"API完全一致类别"(类别1)外,都需要有超链接 + 3. 第五列(备注): 除了"API完全一致类别"(类别1)外,都需要有超链接 + 4. 第四列(映射分类)不检查超链接,但需要检查内容一致性(在另一个函数中处理) """ warnings = [] @@ -263,7 +305,7 @@ def check_links_exist(categories): warning_msg = f"类别 {category_id}({category_name}) 第 {row_num} 行第三列缺少超链接: {row['paddle']}" warnings.append(warning_msg) - # 3. 检查第四列 (备注) + # 3. 检查第五列 (备注) note_links = extract_links(row["note"]) note_content = row["note"].strip() @@ -274,8 +316,135 @@ def check_links_exist(categories): and note_content != "-" and not note_links ): - warning_msg = f"类别 {category_id}({category_name}) 第 {row_num} 行第四列缺少超链接: {row['note']}" + warning_msg = f"类别 {category_id}({category_name}) 第 {row_num} 行第五列缺少超链接: {row['note']}" + warnings.append(warning_msg) + + return warnings + + +def check_mapping_category_consistency(categories): + """ + 检查映射分类列内容与类别标题的一致性 + """ + warnings = [] + + for category in categories: + category_id = category["id"] + category_name = category["name"] + + for i, row in enumerate(category["table"]): + row_num = i + 1 + mapping_category = row.get("mapping_category", "").strip() + + # 检查映射分类列是否与类别标题一致 + if mapping_category != category_name: + warning_msg = ( + f"类别 {category_id}({category_name}) 第 {row_num} 行映射分类不一致:\n" + f" 表格中的映射分类: '{mapping_category}'\n" + f" 类别标题: '{category_name}'" + ) + warnings.append(warning_msg) + + return warnings + + +def check_diff_doc_consistency(categories, base_dir): + """ + 检查映射文档和差异文档的一致性 + """ + warnings = [] + diff_doc_base = os.path.join(base_dir, "api_difference") + + # 检查差异文档目录是否存在 + if not os.path.exists(diff_doc_base): + warnings.append(f"差异文档根目录不存在: {diff_doc_base}") + return warnings + + # 构建API到类别的映射,用于反向检查 + expected_apis = defaultdict(set) # category_name -> set of torch_apis + found_apis = defaultdict(set) # category_name -> set of torch_apis + + # 第一步:检查表格中的每个API是否有对应的差异文档 + for category in categories: + category_name = category["name"] + + # 检查这个类别是否需要差异文档 + if category_name not in REVERSE_CATEGORY_MAP: + continue # 跳过不需要差异文档的类别 + + category_en = REVERSE_CATEGORY_MAP[category_name] + diff_category_dir = os.path.join(diff_doc_base, category_en) + + # 检查类别目录是否存在 + if not os.path.exists(diff_category_dir): + warning_msg = ( + f"差异文档目录不存在: {diff_category_dir}\n" + f"对应类别: {category_name}" + ) + warnings.append(warning_msg) + continue + + for row in category["table"]: + torch_api = extract_torch_api_name(row["pytorch"]) + torch_api = torch_api.replace(r"\_", "_") + if not torch_api: + continue + + expected_apis[category_name].add(torch_api) + + # 构建预期的MD文件名 + expected_md_file = f"{torch_api}.md" + expected_md_path = os.path.join(diff_category_dir, expected_md_file) + + if not os.path.exists(expected_md_path): + warning_msg = ( + f"差异文档缺失: {expected_md_file}\n" + f"对应Torch API: {torch_api}\n" + f"类别: {category_name} ({category_en})\n\n" + ) warnings.append(warning_msg) + else: + found_apis[category_name].add(torch_api) + + # 第二步:反向检查差异文档目录中的文件是否在表格中有对应 + for category_en, category_cn in CATEGORY_MAP.items(): + diff_category_dir = os.path.join(diff_doc_base, category_en) + + if not os.path.exists(diff_category_dir): + continue + + # 遍历差异文档目录中的所有.md文件 + try: + for filename in os.listdir(diff_category_dir): + if filename.endswith(".md"): + torch_api = filename[:-3] # 去掉.md后缀 + + # 检查这个API是否在对应类别的表格中 + api_found = False + for category in categories: + if category["name"] == category_cn: + for row in category["table"]: + if ( + extract_torch_api_name(row["pytorch"]) + == torch_api + ): + api_found = True + break + if api_found: + break + + if not api_found: + warning_msg = ( + f"多余的差异文档: {filename}\n" + f"对应Torch API: {torch_api}\n" + f"类别: {category_cn} ({category_en})\n" + f"该API在映射表格中不存在\n\n" + ) + warnings.append(warning_msg) + except FileNotFoundError: + warnings.append(f"差异文档目录不存在: {diff_category_dir}") + except PermissionError: + warnings.append(f"无权限访问差异文档目录: {diff_category_dir}") return warnings @@ -430,6 +599,7 @@ def check_urls_exist(urls_with_context, max_workers=10): """ warnings = [] + # 限制检查的URL数量(避免过多网络请求) urls_with_context = urls_with_context[:600] total_urls = len(urls_with_context) @@ -536,32 +706,33 @@ def main(): print(f"找到 {len(toc)} 个目录条目") print(f"找到 {len(categories)} 个类别") - # 执行三个基本校验 + # 执行基本校验 toc_warnings = check_toc_consistency(toc, categories) unique_warnings = check_unique_torch_apis(categories) link_warnings = check_links_exist(categories) + mapping_category_warnings = check_mapping_category_consistency(categories) + diff_doc_warnings = check_diff_doc_consistency(categories, base_dir) # 输出警告到文件 - if toc_warnings: - output_path = os.path.join(tools_dir, "toc_warnings.txt") - with open(output_path, "w", encoding="utf-8") as f: - f.write("目录一致性校验警告:\n") - f.writelines(warning + "\n" for warning in toc_warnings) - print(f"生成 {output_path},包含 {len(toc_warnings)} 个警告") - - if unique_warnings: - output_path = os.path.join(tools_dir, "unique_warnings.txt") - with open(output_path, "w", encoding="utf-8") as f: - f.write("Torch API 唯一性校验警告:\n") - f.writelines(warning + "\n" for warning in unique_warnings) - print(f"生成 {output_path},包含 {len(unique_warnings)} 个警告") - - if link_warnings: - output_path = os.path.join(tools_dir, "link_warnings.txt") - with open(output_path, "w", encoding="utf-8") as f: - f.write("超链接存在性校验警告:\n") - f.writelines(warning + "\n" for warning in link_warnings) - print(f"生成 {output_path},包含 {len(link_warnings)} 个警告") + warning_files = [ + ("toc_warnings.txt", "目录一致性校验警告:", toc_warnings), + ("unique_warnings.txt", "Torch API 唯一性校验警告:", unique_warnings), + ("link_warnings.txt", "超链接存在性校验警告:", link_warnings), + ( + "mapping_category_warnings.txt", + "映射分类一致性校验警告:", + mapping_category_warnings, + ), + ("diff_doc_warnings.txt", "差异文档一致性校验警告:", diff_doc_warnings), + ] + + for filename, description, warnings in warning_files: + if warnings: + output_path = os.path.join(tools_dir, filename) + with open(output_path, "w", encoding="utf-8") as f: + f.write(f"{description}\n") + f.writelines(warning + "\n" for warning in warnings) + print(f"生成 {output_path},包含 {len(warnings)} 个警告") # 执行URL存在性检查(除非明确跳过) url_warnings = [] @@ -584,14 +755,22 @@ def main(): else: print("跳过URL存在性检查") - # 如果没有警告,输出成功信息 - if ( - not toc_warnings - and not unique_warnings - and not link_warnings - and not url_warnings - ): + # 汇总统计 + total_warnings = ( + len(toc_warnings) + + len(unique_warnings) + + len(link_warnings) + + len(mapping_category_warnings) + + len(diff_doc_warnings) + + len(url_warnings) + ) + + if total_warnings == 0: print("所有校验通过,没有发现警告!") + else: + print( + f"校验完成,共发现 {total_warnings} 个警告,请查看生成的警告文件。" + ) if __name__ == "__main__":