Skip to content

Commit ea45038

Browse files
committed
完善自动获取url与url合法性校验脚本
1 parent 0604a87 commit ea45038

File tree

2 files changed

+181
-39
lines changed

2 files changed

+181
-39
lines changed

docs/guides/model_convert/convert_from_pytorch/tools/generate_pytorch_api_mapping.py

Lines changed: 133 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,57 +17,165 @@ def get_pytorch_url(torch_api: str) -> str:
1717
对应API的官方文档URL字符串
1818
1919
Rules:
20-
1. Tensor相关API指向tensors.html
21-
2. 顶层函数(torch.xxx)指向torch.html
22-
3. 模块级函数/常量指向模块名.html(如nn.init.html)
23-
4. 类/独立函数指向generated/[name].html
24-
5. 类方法指向父类页面#锚点
25-
6. 特殊处理torchvision等子库的URL结构
20+
1. 优先检查特殊映射
21+
2. 优先检查是否有专门的generated页面
22+
3. 类方法指向父类页面#锚点
23+
4. 模块级函数/常量指向模块名.html
24+
5. Tensor相关API指向tensors.html
25+
6. 顶层函数(torch.xxx)指向torch.html
26+
7. 特殊处理torchvision等子库的URL结构
2627
"""
2728
base_url = "https://pytorch.org/docs/stable/"
28-
api_name = torch_api.replace(r"\_", "_")
29+
torch_api = torch_api.replace(r"\_", "_")
30+
31+
# 特殊映射:手动指定已知问题API的正确URL
32+
special_mappings = {
33+
"torch.cuda.check_error": "generated/torch.cuda.cudart.html",
34+
"torch.cuda.mem_get_info": "generated/torch.cuda.memory.mem_get_info.html",
35+
"torch.nn.attention.sdpa_kernel": "generated/torch.nn.attention.sdpa_kernel.html",
36+
"torch.torch.int32": "tensors.html#torch.int32",
37+
"torch.nn.attention._cur_sdpa_kernel_backends": "nn.attention.html#torch.nn.attention.sdpa_kernel",
38+
"torch.cuda.memory_reserved": "generated/torch.cuda.memory.memory_reserved.html",
39+
"torch.cuda.memory_allocated": "generated/torch.cuda.memory.memory_allocated.html",
40+
"torch.cuda.empty_cache": "generated/torch.cuda.memory.empty_cache.html",
41+
}
42+
43+
# 检查特殊映射
44+
if torch_api in special_mappings:
45+
return f"{base_url}{special_mappings[torch_api]}"
46+
47+
# 优先检查是否有专门的generated页面
48+
generated_apis = {
49+
"torch.pow": "generated/torch.pow.html",
50+
"torch.nn.utils.parameters_to_vector": "generated/torch.nn.utils.parameters_to_vector.html",
51+
"torch.nn.utils.vector_to_parameters": "generated/torch.nn.utils.vector_to_parameters.html",
52+
"torch.nn.Module": "generated/torch.nn.Module.html",
53+
}
54+
55+
if torch_api in generated_apis:
56+
return f"{base_url}{generated_apis[torch_api]}"
57+
58+
# 特殊处理:类方法(如torch.nn.Module.to)
59+
if torch_api.startswith("torch.nn.Module."):
60+
return f"{base_url}generated/torch.nn.Module.html#{torch_api}"
61+
62+
if torch_api.startswith("torch.linalg.") or torch_api.startswith(
63+
"torch.cuda."
64+
):
65+
return f"{base_url}generated/{torch_api}.html#{torch_api}"
2966

3067
# 特殊子库处理(torchvision)
31-
if api_name.startswith("torchvision."):
68+
if torch_api.startswith("torchvision."):
3269
vision_base = "https://pytorch.org/vision/stable/"
33-
if api_name == "torchvision.models":
70+
if torch_api == "torchvision.models":
3471
return f"{vision_base}models.html"
35-
return f"{vision_base}generated/{api_name}.html#{api_name}"
72+
return f"{vision_base}generated/{torch_api}.html#{torch_api}"
73+
74+
# 特殊处理:torch.__version__相关
75+
if torch_api.startswith("torch.__version__"):
76+
return base_url # 版本信息通常在首页
77+
78+
# 特殊处理:torch.distributed.ReduceOp枚举值
79+
if torch_api.startswith("torch.distributed.ReduceOp."):
80+
return f"{base_url}distributed.html#{torch_api}"
81+
82+
# 特殊处理:torch.autograd.Function
83+
if torch_api == "torch.autograd.Function":
84+
return f"{base_url}autograd.html#{torch_api}"
85+
86+
# 特殊处理:torch.utils.cpp_extension
87+
if torch_api.startswith("torch.utils.cpp_extension"):
88+
return f"{base_url}cpp_extension.html#{torch_api}"
3689

3790
# 1. 处理Tensor相关API
38-
if api_name.startswith("torch.Tensor") or api_name == "torch.Tensor":
39-
return f"{base_url}tensors.html#{api_name}"
91+
if torch_api.startswith("torch.Tensor") or torch_api == "torch.Tensor":
92+
return f"{base_url}tensors.html#{torch_api}"
4093

4194
# 2. 处理顶层函数(无子模块)
42-
if len(api_name.split(".")) == 2 and api_name.startswith("torch."):
43-
return f"{base_url}torch.html#{api_name}"
95+
if len(torch_api.split(".")) == 2 and torch_api.startswith("torch."):
96+
# 检查是否有专门的generated页面
97+
generated_check = [
98+
"torch.pow",
99+
"torch.abs",
100+
"torch.add",
101+
"torch.sub",
102+
"torch.mul",
103+
"torch.div",
104+
"torch.exp",
105+
"torch.log",
106+
"torch.sin",
107+
"torch.cos",
108+
"torch.tan",
109+
"torch.sigmoid",
110+
]
111+
112+
if any(torch_api.startswith(prefix) for prefix in generated_check):
113+
return f"{base_url}generated/{torch_api}.html"
114+
return f"{base_url}torch.html#{torch_api}"
44115

45116
# 分割API路径
46-
parts = api_name.split(".")
117+
parts = torch_api.split(".")
47118
module_path = ".".join(parts[:-1]) # 模块路径
48119
item_name = parts[-1] # 最后一项名称
49120

121+
# 特殊处理:torch.functional函数
122+
if parts[0] == "torch" and parts[1] == "functional":
123+
return f"{base_url}torch.html#{torch_api}"
124+
50125
# 3. 处理模块级函数/常量
51126
if parts[0] == "torch" and not parts[-1][0].isupper():
52127
# 特殊模块映射(基于官方文档结构)
53128
module_map = {
54-
"torch.nn.init": "nn.init",
55-
"torch.nn.functional": "nn.functional",
56-
"torch.cuda.amp": "amp",
57-
"torch.distributions": "distributions",
129+
"torch.nn.init": "nn.init.html",
130+
"torch.nn.functional": "nn.functional.html",
131+
"torch.cuda.amp": "amp.html",
132+
"torch.distributions": "distributions.html",
133+
"torch.nn.utils": "nn.utils.html",
134+
"torch.optim": "optim.html",
135+
"torch.random": "random.html",
136+
"torch.special": "special.html",
137+
"torch.distributed": "distributed.html",
138+
"torch.utils.data": "data.html",
58139
}
59140
module_key = ".".join(parts[:-1])
60-
module_slug = module_map.get(
61-
module_key, module_key.replace("torch.", "")
62-
)
63-
return f"{base_url}{module_slug}.html#{api_name}"
141+
module_slug = module_map.get(module_key, f"generated/{module_key}.html")
142+
143+
# 检查是否是应该指向generated目录的API
144+
generated_modules = [
145+
"torch.nn.utils.parameters_to_vector",
146+
"torch.nn.utils.vector_to_parameters",
147+
]
148+
149+
if torch_api in generated_modules:
150+
return f"{base_url}generated/{torch_api}.html"
151+
152+
return f"{base_url}{module_slug}#{torch_api}"
64153

65154
# 4. 处理类/独立函数
66155
if parts[-1][0].isupper() or len(parts) == 1:
67-
return f"{base_url}generated/{api_name}.html#{api_name}"
156+
# 特殊类映射
157+
class_map = {
158+
"torch.autograd.Function": "autograd.html",
159+
"torch.utils.cpp_extension.BuildExtension": "cpp_extension.html",
160+
"torch.nn.Module": "generated/torch.nn.Module.html",
161+
}
162+
if torch_api in class_map:
163+
return f"{base_url}{class_map[torch_api]}#{torch_api}"
164+
return f"{base_url}generated/{torch_api}.html#{torch_api}"
68165

69166
# 5. 默认处理(类方法)
70-
return f"{base_url}generated/{module_path}.html#{api_name}"
167+
# 特殊处理类方法
168+
class_method_map = {
169+
"torch.nn.Module": "generated/torch.nn.Module.html",
170+
"torch.utils.cpp_extension.BuildExtension": "cpp_extension.html",
171+
}
172+
173+
for class_name, page_name in class_method_map.items():
174+
if module_path == class_name:
175+
return f"{base_url}{page_name}#{torch_api}"
176+
177+
# 默认情况下,尝试生成到generated目录
178+
return f"{base_url}generated/{module_path}.html#{torch_api}"
71179

72180

73181
def escape_underscores_in_api(api_name):

docs/guides/model_convert/convert_from_pytorch/tools/validate_api_difference.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,41 @@
11
import argparse
22
import concurrent.futures
33
import os
4+
import random
45
import re
6+
import time
57
from collections import defaultdict
68
from urllib.parse import urlparse
79

810
import requests
11+
from requests.adapters import HTTPAdapter
912
from tqdm import tqdm # 用于显示进度条
13+
from urllib3.util.retry import Retry
1014

1115
# 默认文件路径
1216
DEFAULT_FILE_PATH = "/workspace/paddleDocs/docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md"
1317

1418
# 用户代理头,模拟浏览器访问
1519
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36"
1620

21+
# 重试策略配置
22+
RETRY_STRATEGY = Retry(
23+
total=3,
24+
backoff_factor=0.5,
25+
status_forcelist=[429, 500, 502, 503, 504],
26+
allowed_methods=["HEAD", "GET"],
27+
)
28+
29+
30+
def create_session():
    """Create a requests session with automatic retries.

    A single HTTPAdapter configured with the module-level RETRY_STRATEGY is
    mounted for both the "http://" and "https://" prefixes, and the session's
    default headers are updated so every request carries USER_AGENT.

    Returns:
        The configured ``requests.Session`` instance.
    """
    retrying_adapter = HTTPAdapter(max_retries=RETRY_STRATEGY)

    http_session = requests.Session()
    # Same adapter object for both schemes, mounted in the original order.
    for scheme_prefix in ("http://", "https://"):
        http_session.mount(scheme_prefix, retrying_adapter)

    http_session.headers.update({"User-Agent": USER_AGENT})
    return http_session
38+
1739

1840
def parse_toc(lines):
1941
"""
@@ -324,7 +346,7 @@ def is_valid_url(url):
324346
return False
325347

326348

327-
def check_url_exists(url_info):
349+
def check_url_exists(url_info, session=None):
328350
"""
329351
检查URL是否存在(是否返回404)
330352
返回状态码和错误信息
@@ -340,21 +362,21 @@ def check_url_exists(url_info):
340362
"url_info": url_info,
341363
}
342364

343-
# 设置请求头
344-
headers = {"User-Agent": USER_AGENT}
365+
# 添加随机延迟,避免请求过于频繁
366+
time.sleep(random.uniform(0.5, 1.5))
367+
368+
# 创建会话(如果未提供)
369+
if session is None:
370+
session = create_session()
345371

346372
try:
347373
# 发送HEAD请求(更快,节省带宽)
348-
response = requests.head(
349-
url, headers=headers, timeout=10, allow_redirects=True
350-
)
374+
response = session.head(url, timeout=10, allow_redirects=True)
351375
status_code = response.status_code
352376

353377
# 如果HEAD请求不被支持(405错误),则尝试GET请求
354378
if status_code == 405:
355-
response = requests.get(
356-
url, headers=headers, timeout=10, allow_redirects=True
357-
)
379+
response = session.get(url, timeout=10, allow_redirects=True)
358380
status_code = response.status_code
359381

360382
# 根据状态码判断URL是否存在
@@ -409,6 +431,9 @@ def check_urls_exist(urls_with_context, max_workers=10):
409431
返回警告列表
410432
"""
411433
warnings = []
434+
435+
urls_with_context = urls_with_context[-700:]
436+
412437
total_urls = len(urls_with_context)
413438

414439
print(
@@ -421,11 +446,16 @@ def check_urls_exist(urls_with_context, max_workers=10):
421446
max_workers=max_workers
422447
) as executor,
423448
):
449+
# 为每个线程创建一个会话
450+
sessions = [create_session() for _ in range(max_workers)]
451+
424452
# 提交所有任务
425-
future_to_url = {
426-
executor.submit(check_url_exists, url_info): url_info
427-
for url_info in urls_with_context
428-
}
453+
future_to_url = {}
454+
for i, url_info in enumerate(urls_with_context):
455+
# 分配会话给任务(轮询方式)
456+
session = sessions[i % max_workers]
457+
future = executor.submit(check_url_exists, url_info, session)
458+
future_to_url[future] = url_info
429459

430460
# 处理完成的任务
431461
for future in concurrent.futures.as_completed(future_to_url):
@@ -445,6 +475,10 @@ def check_urls_exist(urls_with_context, max_workers=10):
445475
warning_msg += f"状态码: {result['status_code']}\n"
446476
warnings.append(warning_msg)
447477

478+
# 关闭所有会话
479+
for session in sessions:
480+
session.close()
481+
448482
print(f"URL检查完成,发现 {len(warnings)} 个问题")
449483
return warnings
450484

@@ -479,7 +513,7 @@ def main():
479513
# 检查文件是否存在
480514
if not os.path.exists(md_file_path):
481515
print(f"错误: 文件 '{md_file_path}' 不存在")
482-
print("请使用 --file 参数指定正确的文件路径")
516+
print("请使用 --file 参数指定文件路径")
483517
return
484518

485519
# 读取文件所有行

0 commit comments

Comments
 (0)