Skip to content

Commit c6075e4

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix delegate node metadata
Summary: The delegate node's metadata was set incorrectly, causing deserialization to fail Reviewed By: mcr229 Differential Revision: D78350040
1 parent a8070ec commit c6075e4

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

exir/backend/backend_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
235235
call_submodule_node.kwargs,
236236
)
237237
call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program)
238-
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
238+
call_delegate_node.meta["val"] = [
239+
out_arg.meta["val"] for out_arg in submodule_output_node.args[0]
240+
]
239241
call_submodule_node.replace_all_uses_with(call_delegate_node)
240242
owning_graph_module.graph.erase_node(call_submodule_node)
241243
if is_submodule:
@@ -476,7 +478,6 @@ def _create_partitions_in_graph_module(
476478
submodule_output_node = submodule.graph.output_node()
477479
# Copy the output node meta from the original output node, because
478480
# create_submodule_from_nodes doesn't cover the meta field
479-
submodule_output_node.meta = tagged_graph_module_output_node.meta
480481
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
481482
(
482483
submodule_program,

exir/backend/test/demos/test_xnnpack_qnnpack.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import tempfile
78
import unittest
89

910
import executorch.exir as exir
@@ -41,6 +42,18 @@
4142
prepare_fx,
4243
)
4344

45+
from typing import Tuple
46+
47+
import torch
48+
from executorch import exir
49+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
50+
XnnpackFloatingPointPartitioner,
51+
)
52+
from executorch.exir import (
53+
EdgeCompileConfig,
54+
EdgeProgramManager,
55+
to_edge_transform_and_lower,
56+
)
4457

4558
class TestXnnQnnBackends(unittest.TestCase):
4659
def test_add_xnnpack_and_dqlinear_qnn(self):
@@ -132,3 +145,42 @@ def forward(self, x, y):
132145
self.assertTrue(
133146
torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
134147
)
148+
149+
def test_serde(self):
150+
# The module with blank_logprobs() function
151+
class BlankLogProbsModule(torch.nn.Module):
152+
def __init__(self) -> None:
153+
super().__init__()
154+
self.linear = torch.nn.Linear(768, 1)
155+
self.log_sigmoid = torch.nn.LogSigmoid()
156+
157+
def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor:
158+
tanh_out = torch.tanh(joint_encodings)
159+
linear_out = self.linear(tanh_out)
160+
blank_output = self.log_sigmoid(linear_out)
161+
return blank_output
162+
163+
def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, ...]:
164+
"""
165+
Get the input to the blank_logprobs() and nonblank_logprobs() functions.
166+
"""
167+
return (torch.randn(1, 1, 1, 768),)
168+
169+
model = BlankLogProbsModule()
170+
# Get the inputs for the logprobs function
171+
logprobs_fake_inputs = get_blank_logprobs_inputs_fn()
172+
173+
# Export and partition
174+
aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True)
175+
partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower(
176+
aten_prog,
177+
partitioner=[XnnpackFloatingPointPartitioner()],
178+
compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True),
179+
)
180+
181+
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
182+
exir.save(partitioned_prog.exported_program(), f.name)
183+
f.seek(0)
184+
loaded_model = exir.load(f.name)
185+
186+
self.assertTrue(torch.allclose(model(*logprobs_fake_inputs), loaded_model.module()(*logprobs_fake_inputs)))

0 commit comments

Comments
 (0)