
Commit 8cefaa8

committed
fix: replace front-end dynamo export with torch export
1 parent 308b3a0 commit 8cefaa8

File tree

3 files changed (+153, -164 lines)
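The commit title refers to switching the frontend from Dynamo-style module export to torch.export. A minimal sketch of the consequence for the registry, assuming a toy module (the Tiny class and tensor shapes are illustrative, not from this repo): torch.export lowers every call to ATen operators, so graph nodes carry OpOverload targets such as torch.ops.aten.matmul.default, which is why converters are now keyed by (op_name, overload) instead of by nn.Module class or bare method name.

import torch
from torch.export import export

class Tiny(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

# torch.export produces an ATen-level graph: call_function nodes have
# OpOverload targets (e.g. torch.ops.aten.matmul.default) rather than
# nn.Module call targets, so module-class converters no longer apply.
ep = export(Tiny(), (torch.randn(2, 3), torch.randn(3, 4)))
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(node.target)  # e.g. aten.matmul.default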


python/src/infinitensor/converter/registry.py

Lines changed: 33 additions & 30 deletions
@@ -8,57 +8,60 @@
 class ConverterRegistry:
 
     def __init__(self):
-        self._module_converters: Dict[torch.nn.Module, Callable[[fx.Node], None]] = {}
-        self._method_converters: Dict[str, Callable[[fx.Node], None]] = {}
+        # { "aten.matmul": { "default": fn, "out": fn, None: fn } }
+        self._method_converters: Dict[str, Dict[Optional[str], Callable]] = {}
 
-    def register_module(self, module_class):
-        """Decorator: register a module converter"""
-
-        def decorator(func):
-            self._module_converters[module_class] = func
-            return func
-
-        return decorator
-
-    def register_method(self, method_name: str):
+    def register(self, op_name: str, overload: Optional[str] = None):
         """Decorator: register method and function converters"""
 
         def decorator(func):
-            self._method_converters[method_name] = func
+            self._method_converters.setdefault(op_name, {})[overload] = func
             return func
 
         return decorator
 
-    def get_module_converter(self, module_class) -> Optional[Callable]:
-        """Get a module converter"""
-        # also check parent classes
-        if module_class in self._module_converters:
-            return self._module_converters[module_class]
-        return None
-
-    def get_method_converter(self, method_name: str) -> Optional[Callable]:
+    def get_method_converter(self, op_name: str, overload: Optional[str] = None) -> Optional[Callable]:
         """Get method and function converters"""
-        if method_name in self._method_converters:
-            return self._method_converters[method_name]
-        return None
+        if op_name in self._method_converters:
+            table = self._method_converters[op_name]
+            if overload:
+                if overload in table:
+                    return table[overload]
+                else:
+                    raise ValueError(f"Unsupported op.overload : {op_name}_{overload}")
+            else:
+                if None in table:
+                    return table[None]
+                else:
+                    raise ValueError(f"Unsupported op.overload : {op_name}")
+        else:
+            raise ValueError(f"Unsupported op : {op_name}")
+
 
     def update(self, custom_converters: Dict):
-        """Update converters"""
+        """Update converters
+        Args:
+            custom_converters:
+                {
+                    (op_name, overload): converter
+                }
+        """
         for key, converter in custom_converters.items():
-            if inspect.isclass(key) and issubclass(key, nn.Module):
-                self._module_converters[key] = converter
-            elif isinstance(key, str):
+            if isinstance(key, tuple) and len(key) == 2:
+                op_name, overload = key
+                self._method_converters.setdefault(op_name, {})[overload] = converter
+            if isinstance(key, str):
                 self._method_converters[key] = converter
+            else:
+                raise TypeError(f"Invalid key type: {type(key)}")
 
     def clear(self):
         """Clear all converters"""
-        self._module_converters.clear()
         self._method_converters.clear()
 
     def list_all_converters(self):
         """List all converters"""
         return {
-            "modules": list(self._module_converters.keys()),
             "methods": list(self._method_converters.keys()),
         }
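A hedged usage sketch of the registry API after this change; the module-level registry instance and import path are taken from unified_converters.py below, while the my_matmul/my_relu converters are illustrative placeholders.

from infinitensor.converter.registry import registry

# Register a converter for a specific overload...
@registry.register("matmul", "default")
def my_matmul(translator, node):
    ...

# ...or under the None slot, which get_method_converter falls back to
# when no overload is passed.
@registry.register("relu")
def my_relu(translator, node):
    ...

registry.get_method_converter("matmul", "default")  # -> my_matmul
registry.get_method_converter("relu")               # -> my_relu (None slot)
registry.get_method_converter("softmax")            # raises ValueError

# Bulk registration with (op_name, overload) tuple keys.
registry.update({("matmul", "out"): my_matmul})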

python/src/infinitensor/converter/unified_converters.py

Lines changed: 2 additions & 10 deletions
@@ -1,17 +1,9 @@
 import torch.nn as nn
 from .registry import registry
 
+#https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
 
-@registry.register_module(nn.Linear)
-def convert_linear(translator, node, module):
-    x = translator.tensors[node.args[0]]
-    module = translator.named_modules[node.target]
-    weight = translator.params[module.weight]
-    bias = translator.params.get(module.bias, None)
-    translator.tensors[node] = translator.builder.gemm(x, weight, bias, transB=True)
-
-
-@registry.register_method("matmul")
+@registry.register("matmul","default")
 def convert_matmul(translator, node):
     a = translator.tensors[node.args[0]]
     b = translator.tensors[node.args[1]]
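With the nn.Linear module converter deleted, linear layers have to be handled at operator level, because torch.export flattens modules into ATen calls. Below is a hypothetical sketch of that converter under the new decorator, reusing the gemm lowering from the removed convert_linear; whether the exported graph keeps aten.linear or decomposes it further into addmm/matmul depends on the decomposition table used at export time, so the op name here is an assumption.

# Hypothetical operator-level replacement for the removed nn.Linear converter;
# assumes the exported graph contains aten.linear(input, weight, bias).
@registry.register("linear", "default")
def convert_linear(translator, node):
    x = translator.tensors[node.args[0]]
    weight = translator.tensors[node.args[1]]
    bias = None
    if len(node.args) > 2 and node.args[2] is not None:
        bias = translator.tensors[node.args[2]]
    # Same lowering as the deleted module converter: y = x @ weight.T + bias
    translator.tensors[node] = translator.builder.gemm(x, weight, bias, transB=True)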
