@@ -17,57 +17,165 @@ def get_pytorch_url(torch_api: str) -> str:
1717 对应API的官方文档URL字符串
1818
1919 Rules:
20- 1. Tensor相关API指向tensors.html
21- 2. 顶层函数(torch.xxx)指向torch.html
22- 3. 模块级函数/常量指向模块名.html(如nn.init.html)
23- 4. 类/独立函数指向generated/[name].html
24- 5. 类方法指向父类页面#锚点
25- 6. 特殊处理torchvision等子库的URL结构
20+ 1. 优先检查特殊映射
21+ 2. 优先检查是否有专门的generated页面
22+ 3. 类方法指向父类页面#锚点
23+ 4. 模块级函数/常量指向模块名.html
24+ 5. Tensor相关API指向tensors.html
25+ 6. 顶层函数(torch.xxx)指向torch.html
26+ 7. 特殊处理torchvision等子库的URL结构
2627 """
2728 base_url = "https://pytorch.org/docs/stable/"
28- api_name = torch_api .replace (r"\_" , "_" )
29+ torch_api = torch_api .replace (r"\_" , "_" )
30+
31+ # 特殊映射:手动指定已知问题API的正确URL
32+ special_mappings = {
33+ "torch.cuda.check_error" : "generated/torch.cuda.cudart.html" ,
34+ "torch.cuda.mem_get_info" : "generated/torch.cuda.memory.mem_get_info.html" ,
35+ "torch.nn.attention.sdpa_kernel" : "generated/torch.nn.attention.sdpa_kernel.html" ,
36+ "torch.torch.int32" : "tensors.html#torch.int32" ,
37+ "torch.nn.attention._cur_sdpa_kernel_backends" : "nn.attention.html#torch.nn.attention.sdpa_kernel" ,
38+ "torch.cuda.memory_reserved" : "generated/torch.cuda.memory.memory_reserved.html" ,
39+ "torch.cuda.memory_allocated" : "generated/torch.cuda.memory.memory_allocated.html" ,
40+ "torch.cuda.empty_cache" : "generated/torch.cuda.memory.empty_cache.html" ,
41+ }
42+
43+ # 检查特殊映射
44+ if torch_api in special_mappings :
45+ return f"{ base_url } { special_mappings [torch_api ]} "
46+
47+ # 优先检查是否有专门的generated页面
48+ generated_apis = {
49+ "torch.pow" : "generated/torch.pow.html" ,
50+ "torch.nn.utils.parameters_to_vector" : "generated/torch.nn.utils.parameters_to_vector.html" ,
51+ "torch.nn.utils.vector_to_parameters" : "generated/torch.nn.utils.vector_to_parameters.html" ,
52+ "torch.nn.Module" : "generated/torch.nn.Module.html" ,
53+ }
54+
55+ if torch_api in generated_apis :
56+ return f"{ base_url } { generated_apis [torch_api ]} "
57+
58+ # 特殊处理:类方法(如torch.nn.Module.to)
59+ if torch_api .startswith ("torch.nn.Module." ):
60+ return f"{ base_url } generated/torch.nn.Module.html#{ torch_api } "
61+
62+ if torch_api .startswith ("torch.linalg." ) or torch_api .startswith (
63+ "torch.cuda."
64+ ):
65+ return f"{ base_url } generated/{ torch_api } .html#{ torch_api } "
2966
3067 # 特殊子库处理(torchvision)
31- if api_name .startswith ("torchvision." ):
68+ if torch_api .startswith ("torchvision." ):
3269 vision_base = "https://pytorch.org/vision/stable/"
33- if api_name == "torchvision.models" :
70+ if torch_api == "torchvision.models" :
3471 return f"{ vision_base } models.html"
35- return f"{ vision_base } generated/{ api_name } .html#{ api_name } "
72+ return f"{ vision_base } generated/{ torch_api } .html#{ torch_api } "
73+
74+ # 特殊处理:torch.__version__相关
75+ if torch_api .startswith ("torch.__version__" ):
76+ return base_url # 版本信息通常在首页
77+
78+ # 特殊处理:torch.distributed.ReduceOp枚举值
79+ if torch_api .startswith ("torch.distributed.ReduceOp." ):
80+ return f"{ base_url } distributed.html#{ torch_api } "
81+
82+ # 特殊处理:torch.autograd.Function
83+ if torch_api == "torch.autograd.Function" :
84+ return f"{ base_url } autograd.html#{ torch_api } "
85+
86+ # 特殊处理:torch.utils.cpp_extension
87+ if torch_api .startswith ("torch.utils.cpp_extension" ):
88+ return f"{ base_url } cpp_extension.html#{ torch_api } "
3689
3790 # 1. 处理Tensor相关API
38- if api_name .startswith ("torch.Tensor" ) or api_name == "torch.Tensor" :
39- return f"{ base_url } tensors.html#{ api_name } "
91+ if torch_api .startswith ("torch.Tensor" ) or torch_api == "torch.Tensor" :
92+ return f"{ base_url } tensors.html#{ torch_api } "
4093
4194 # 2. 处理顶层函数(无子模块)
42- if len (api_name .split ("." )) == 2 and api_name .startswith ("torch." ):
43- return f"{ base_url } torch.html#{ api_name } "
95+ if len (torch_api .split ("." )) == 2 and torch_api .startswith ("torch." ):
96+ # 检查是否有专门的generated页面
97+ generated_check = [
98+ "torch.pow" ,
99+ "torch.abs" ,
100+ "torch.add" ,
101+ "torch.sub" ,
102+ "torch.mul" ,
103+ "torch.div" ,
104+ "torch.exp" ,
105+ "torch.log" ,
106+ "torch.sin" ,
107+ "torch.cos" ,
108+ "torch.tan" ,
109+ "torch.sigmoid" ,
110+ ]
111+
112+ if any (torch_api .startswith (prefix ) for prefix in generated_check ):
113+ return f"{ base_url } generated/{ torch_api } .html"
114+ return f"{ base_url } torch.html#{ torch_api } "
44115
45116 # 分割API路径
46- parts = api_name .split ("." )
117+ parts = torch_api .split ("." )
47118 module_path = "." .join (parts [:- 1 ]) # 模块路径
48119 item_name = parts [- 1 ] # 最后一项名称
49120
121+ # 特殊处理:torch.functional函数
122+ if parts [0 ] == "torch" and parts [1 ] == "functional" :
123+ return f"{ base_url } torch.html#{ torch_api } "
124+
50125 # 3. 处理模块级函数/常量
51126 if parts [0 ] == "torch" and not parts [- 1 ][0 ].isupper ():
52127 # 特殊模块映射(基于官方文档结构)
53128 module_map = {
54- "torch.nn.init" : "nn.init" ,
55- "torch.nn.functional" : "nn.functional" ,
56- "torch.cuda.amp" : "amp" ,
57- "torch.distributions" : "distributions" ,
129+ "torch.nn.init" : "nn.init.html" ,
130+ "torch.nn.functional" : "nn.functional.html" ,
131+ "torch.cuda.amp" : "amp.html" ,
132+ "torch.distributions" : "distributions.html" ,
133+ "torch.nn.utils" : "nn.utils.html" ,
134+ "torch.optim" : "optim.html" ,
135+ "torch.random" : "random.html" ,
136+ "torch.special" : "special.html" ,
137+ "torch.distributed" : "distributed.html" ,
138+ "torch.utils.data" : "data.html" ,
58139 }
59140 module_key = "." .join (parts [:- 1 ])
60- module_slug = module_map .get (
61- module_key , module_key .replace ("torch." , "" )
62- )
63- return f"{ base_url } { module_slug } .html#{ api_name } "
141+ module_slug = module_map .get (module_key , f"generated/{ module_key } .html" )
142+
143+ # 检查是否是应该指向generated目录的API
144+ generated_modules = [
145+ "torch.nn.utils.parameters_to_vector" ,
146+ "torch.nn.utils.vector_to_parameters" ,
147+ ]
148+
149+ if torch_api in generated_modules :
150+ return f"{ base_url } generated/{ torch_api } .html"
151+
152+ return f"{ base_url } { module_slug } #{ torch_api } "
64153
65154 # 4. 处理类/独立函数
66155 if parts [- 1 ][0 ].isupper () or len (parts ) == 1 :
67- return f"{ base_url } generated/{ api_name } .html#{ api_name } "
156+ # 特殊类映射
157+ class_map = {
158+ "torch.autograd.Function" : "autograd.html" ,
159+ "torch.utils.cpp_extension.BuildExtension" : "cpp_extension.html" ,
160+ "torch.nn.Module" : "generated/torch.nn.Module.html" ,
161+ }
162+ if torch_api in class_map :
163+ return f"{ base_url } { class_map [torch_api ]} #{ torch_api } "
164+ return f"{ base_url } generated/{ torch_api } .html#{ torch_api } "
68165
69166 # 5. 默认处理(类方法)
70- return f"{ base_url } generated/{ module_path } .html#{ api_name } "
167+ # 特殊处理类方法
168+ class_method_map = {
169+ "torch.nn.Module" : "generated/torch.nn.Module.html" ,
170+ "torch.utils.cpp_extension.BuildExtension" : "cpp_extension.html" ,
171+ }
172+
173+ for class_name , page_name in class_method_map .items ():
174+ if module_path == class_name :
175+ return f"{ base_url } { page_name } #{ torch_api } "
176+
177+ # 默认情况下,尝试生成到generated目录
178+ return f"{ base_url } generated/{ module_path } .html#{ torch_api } "
71179
72180
73181def escape_underscores_in_api (api_name ):
0 commit comments