|
8 | 8 | class ConverterRegistry: |
9 | 9 |
|
10 | 10 | def __init__(self): |
11 | | - self._module_converters: Dict[torch.nn.Module, Callable[[fx.Node], None]] = {} |
12 | | - self._method_converters: Dict[str, Callable[[fx.Node], None]] = {} |
| 11 | + # { "aten.matmul": { "default": fn, "out": fn, None: fn } } |
| 12 | + self._method_converters: Dict[str, Dict[Optional[str], Callable]] = {} |
13 | 13 |
|
14 | | - def register_module(self, module_class): |
15 | | - """装饰器:注册模块转换器""" |
16 | | - |
17 | | - def decorator(func): |
18 | | - self._module_converters[module_class] = func |
19 | | - return func |
20 | | - |
21 | | - return decorator |
22 | | - |
23 | | - def register_method(self, method_name: str): |
| 14 | + def register(self, op_name: str, overload: Optional[str] = None): |
24 | 15 | """装饰器:注册方法和函数转换器""" |
25 | 16 |
|
26 | 17 | def decorator(func): |
27 | | - self._method_converters[method_name] = func |
| 18 | + self._method_converters.setdefault(op_name, {})[overload] = func |
28 | 19 | return func |
29 | 20 |
|
30 | 21 | return decorator |
31 | 22 |
|
32 | | - def get_module_converter(self, module_class) -> Optional[Callable]: |
33 | | - """获取模块转换器""" |
34 | | - # 也检查父类 |
35 | | - if module_class in self._module_converters: |
36 | | - return self._module_converters[module_class] |
37 | | - return None |
38 | | - |
39 | | - def get_method_converter(self, method_name: str) -> Optional[Callable]: |
| 23 | + def get_method_converter(self, op_name: str, overload: Optional[str] = None) -> Optional[Callable]: |
40 | 24 | """获取方法和函数转换器""" |
41 | | - if method_name in self._method_converters: |
42 | | - return self._method_converters[method_name] |
43 | | - return None |
| 25 | + if op_name in self._method_converters: |
| 26 | + table = self._method_converters[op_name] |
| 27 | + if overload: |
| 28 | + if overload in table: |
| 29 | + return table[overload] |
| 30 | + else: |
| 31 | + raise ValueError(f"Unsupported op.overload : {op_name}_{overload}") |
| 32 | + else: |
| 33 | + if None in table: |
| 34 | + return table[None] |
| 35 | + else: |
| 36 | + raise ValueError(f"Unsupported op.overload : {op_name}") |
| 37 | + else: |
| 38 | + raise ValueError(f"Unsupported op : {op_name}") |
| 39 | + |
44 | 40 |
|
45 | 41 | def update(self, custom_converters: Dict): |
46 | | - """更新转换器""" |
| 42 | + """更新转换器 |
| 43 | + Args: |
| 44 | + custom_converters: |
| 45 | + { |
| 46 | + (op_name, overload): converter |
| 47 | + } |
| 48 | + """ |
47 | 49 | for key, converter in custom_converters.items(): |
48 | | - if inspect.isclass(key) and issubclass(key, nn.Module): |
49 | | - self._module_converters[key] = converter |
50 | | - elif isinstance(key, str): |
| 50 | + if isinstance(key, tuple) and len(key) == 2: |
| 51 | + op_name, overload = key |
| 52 | + self._method_converters.setdefault(op_name, {})[overload] = converter |
| 53 | + if isinstance(key, str): |
51 | 54 | self._method_converters[key] = converter |
| 55 | + else: |
| 56 | + raise TypeError(f"Invalid key type: {type(key)}") |
52 | 57 |
|
53 | 58 | def clear(self): |
54 | 59 | """清空所有转换器""" |
55 | | - self._module_converters.clear() |
56 | 60 | self._method_converters.clear() |
57 | 61 |
|
58 | 62 | def list_all_converters(self): |
59 | 63 | """列出所有转换器""" |
60 | 64 | return { |
61 | | - "modules": list(self._module_converters.keys()), |
62 | 65 | "methods": list(self._method_converters.keys()), |
63 | 66 | } |
64 | 67 |
|
|
0 commit comments