24
24
from executorch .exir .backend .compile_spec_schema import (
25
25
CompileSpec as delegate_CompileSpec ,
26
26
)
27
- from executorch .exir .dialects ._ops import ops as exir_ops
27
+ from executorch .exir .dialects ._ops import _DialectNamespace , ops as exir_ops
28
+ from executorch .exir .dialects .backend ._ops import BackendOpOverload
28
29
from executorch .exir .dialects .edge ._ops import EdgeOpOverload
29
30
from executorch .exir .lowered_backend_module import (
30
31
LoweredBackendModule as ExirLoweredBackendModule ,
@@ -50,20 +51,27 @@ def serialize_operator(
50
51
target : Union [
51
52
str ,
52
53
EdgeOpOverload ,
54
+ BackendOpOverload ,
53
55
torch ._ops .OpOverload ,
54
56
torch ._ops .HigherOrderOperator ,
55
57
],
56
58
) -> str :
57
59
if isinstance (target , str ):
58
60
return target
59
- elif target .__module__ .startswith ("executorch.exir.dialects" ):
61
+ elif target .__module__ .startswith ("executorch.exir.dialects.edge " ):
60
62
# TODO(zhxchen17) Maybe provide a function name helper in FX.
61
63
# From torch.fx.node._get_qualified_name
62
64
module = target .__module__ .replace (
63
65
"executorch.exir.dialects.edge._ops" ,
64
66
"executorch.exir.dialects.edge.ops" ,
65
67
)
66
68
return f"{ module } .{ target .__name__ } "
69
+ elif target .__module__ .startswith ("executorch.exir.dialects.backend" ):
70
+ module = target .__module__ .replace (
71
+ "executorch.exir.dialects.backend._ops" ,
72
+ "executorch.exir.dialects.backend.ops" ,
73
+ )
74
+ return f"{ module } .{ target .__name__ } "
67
75
68
76
return super ().serialize_operator (target )
69
77
@@ -337,8 +345,7 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
337
345
self .state_dict : Dict [str , Any ] = state_dict # TODO(T157676982)
338
346
339
347
def deserialize_operator (self , serialized_target : str ) -> str :
340
- if serialized_target .startswith ("executorch.exir.dialects.edge.ops" ):
341
- module = exir_ops .edge
348
+ def find_operator (module : _DialectNamespace , serialized_target : str ) -> str :
342
349
serialized_target_names = serialized_target .split ("." )[5 :]
343
350
344
351
target = module
@@ -349,6 +356,11 @@ def deserialize_operator(self, serialized_target: str) -> str:
349
356
target = getattr (target , name )
350
357
return target
351
358
359
+ if serialized_target .startswith ("executorch.exir.dialects.edge.ops" ):
360
+ return find_operator (exir_ops .edge , serialized_target )
361
+ elif serialized_target .startswith ("executorch.exir.dialects.backend.ops" ):
362
+ return find_operator (exir_ops .backend , serialized_target )
363
+
352
364
return super ().deserialize_operator (serialized_target )
353
365
354
366
# pyre-ignore
0 commit comments