Skip to content

Commit 7911ea7

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for backend dialect in serde
Summary: Adding support for backend dialect ops in EXIR serde. Reviewed By: larryliu0820 Differential Revision: D48106247 fbshipit-source-id: 60424599ec59263c3a4e8e52086eb9a6c8e1f660
1 parent 4090d6c commit 7911ea7

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

exir/serde/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ python_library(
1616
"//executorch/exir:memory",
1717
"//executorch/exir/backend:compile_spec_schema",
1818
"//executorch/exir/dialects:lib",
19+
"//executorch/exir/dialects/backend:lib",
1920
"//executorch/exir/dialects/edge:lib",
2021
],
2122
)

exir/serde/serialize.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from executorch.exir.backend.compile_spec_schema import (
2525
CompileSpec as delegate_CompileSpec,
2626
)
27-
from executorch.exir.dialects._ops import ops as exir_ops
27+
from executorch.exir.dialects._ops import _DialectNamespace, ops as exir_ops
28+
from executorch.exir.dialects.backend._ops import BackendOpOverload
2829
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2930
from executorch.exir.lowered_backend_module import (
3031
LoweredBackendModule as ExirLoweredBackendModule,
@@ -50,20 +51,27 @@ def serialize_operator(
5051
target: Union[
5152
str,
5253
EdgeOpOverload,
54+
BackendOpOverload,
5355
torch._ops.OpOverload,
5456
torch._ops.HigherOrderOperator,
5557
],
5658
) -> str:
5759
if isinstance(target, str):
5860
return target
59-
elif target.__module__.startswith("executorch.exir.dialects"):
61+
elif target.__module__.startswith("executorch.exir.dialects.edge"):
6062
# TODO(zhxchen17) Maybe provide a function name helper in FX.
6163
# From torch.fx.node._get_qualified_name
6264
module = target.__module__.replace(
6365
"executorch.exir.dialects.edge._ops",
6466
"executorch.exir.dialects.edge.ops",
6567
)
6668
return f"{module}.{target.__name__}"
69+
elif target.__module__.startswith("executorch.exir.dialects.backend"):
70+
module = target.__module__.replace(
71+
"executorch.exir.dialects.backend._ops",
72+
"executorch.exir.dialects.backend.ops",
73+
)
74+
return f"{module}.{target.__name__}"
6775

6876
return super().serialize_operator(target)
6977

@@ -337,8 +345,7 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
337345
self.state_dict: Dict[str, Any] = state_dict # TODO(T157676982)
338346

339347
def deserialize_operator(self, serialized_target: str) -> str:
340-
if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
341-
module = exir_ops.edge
348+
def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
342349
serialized_target_names = serialized_target.split(".")[5:]
343350

344351
target = module
@@ -349,6 +356,11 @@ def deserialize_operator(self, serialized_target: str) -> str:
349356
target = getattr(target, name)
350357
return target
351358

359+
if serialized_target.startswith("executorch.exir.dialects.edge.ops"):
360+
return find_operator(exir_ops.edge, serialized_target)
361+
elif serialized_target.startswith("executorch.exir.dialects.backend.ops"):
362+
return find_operator(exir_ops.backend, serialized_target)
363+
352364
return super().deserialize_operator(serialized_target)
353365

354366
# pyre-ignore

0 commit comments

Comments
 (0)