Skip to content

Commit e203e63

Browse files
author
ssjia
committed
Update on "[etrecord] Implement generic fallback for GraphModuleSerializer.handle_call_function"
Title says it all! Implement the case where `node.target` is neither `torch._ops.OpOverload` or `torch._ops.HigherOrderOperator`, instead of throwing an exception. Differential Revision: [D88216198](https://our.internmc.facebook.com/intern/diff/D88216198/) [ghstack-poisoned]
1 parent 49e0ce4 commit e203e63

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

exir/serde/serialize.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -450,19 +450,23 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
450450
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
451451
return
452452
elif isinstance(target, str):
453-
# Create a dummy fake op if the target does not exist
454-
# because we cannot create a call_function node w/o a
455-
# callable target
456-
log.warning(
457-
f"Could not find operator {target}. Returning fake operator."
458-
) # noqa: G004
459-
460-
# pyre-ignore
461-
def fake_op(x):
462-
raise NotImplementedError("Fake op is not meant to be run.")
463-
464-
fake_op.__name__ = target
465-
target = fake_op
453+
# Special handling for memory ops, which are not EdgeOpOverload but
454+
# are still somewhat expected in serialized graphs.
455+
if target == "executorch.exir.memory.view":
456+
target = exir.memory.view
457+
else:
458+
# Otherwise, create a dummy fake op if the target does not exist
459+
# because we cannot create a call_function node w/o a callable
460+
# target
461+
log.warning(
462+
f"Could not find operator {target}. Returning fake operator."
463+
) # noqa: G004
464+
# pyre-ignore
465+
def fake_op(x):
466+
raise NotImplementedError("Fake op is not meant to be run.")
467+
468+
fake_op.__name__ = target
469+
target = fake_op
466470

467471
args = self.deserialize_inputs_no_schema(serialized_node)
468472
fx_node = self.graph.create_node("call_function", target, args, None, None)

exir/tests/test_serde.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,18 @@ def forward(self, x):
335335
node.meta.get("from_node"), node_new.meta.get("from_node")
336336
):
337337
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())
338+
339+
def test_memory_ops(self) -> None:
340+
class MemoryOpsModule(nn.Module):
341+
def __init__(self):
342+
super().__init__()
343+
344+
def forward(self, x, y):
345+
x = exir.memory.view(x, (10, 10))
346+
return x + y
347+
348+
inputs = (
349+
torch.randn(100),
350+
torch.randn(10, 10),
351+
)
352+
self.check_serde(MemoryOpsModule(), inputs)

0 commit comments

Comments
 (0)