Skip to content

Commit bc669bc

Browse files
committed
Serializing from_node info in et serializer
We need to use the from_node informaton in deserialzied exported graph for operator tracing in et.inspector. this diff updates the serizalier to support serde from_node info. Differential Revision: [D78293986](https://our.internmc.facebook.com/intern/diff/D78293986/) [ghstack-poisoned]
1 parent 501cfb8 commit bc669bc

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
9292
self.assertEqual(
9393
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
9494
)
95+
from_node_a = node_a.meta.get("from_node")
96+
from_node_b = node_b.meta.get("from_node")
97+
98+
if from_node_a is None:
99+
self.assertIsNone(from_node_b)
100+
else:
101+
self.assertIsNotNone(from_node_b)
102+
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
103+
self.assertEqual(
104+
node_source_a.to_dict(), node_source_b.to_dict()
105+
)
95106

96107
def test_etrecord_generation(self):
97108
captured_output, edge_output, et_output = self.get_test_model()

exir/serde/serialize.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from torch._export.verifier import load_verifier
4343
from torch.fx.experimental import symbolic_shapes
44+
from torch.fx.traceback import NodeSource, NodeSourceAction
4445

4546
log: logging.Logger = logging.getLogger(__name__)
4647

@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
141142
debug_handle = node.meta["debug_handle"]
142143
meta["debug_handle"] = str(debug_handle)
143144

145+
if "from_node" in node.meta:
146+
from_node = node.meta["from_node"]
147+
# Serialize from_node as JSON since it's a complex nested structure
148+
meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))
149+
144150
return meta
145151

152+
def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
153+
"""
154+
Recursively serialize from_node metadata which can be a list of NodeSource objects.
155+
"""
156+
if from_node is None:
157+
return None
158+
159+
json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]
160+
161+
return json_acceptable_from_node
162+
146163
def serialize_alloc_inputs(
147164
self, inputs # pyre-ignore
148165
) -> List[schema.NamedArgument]:
@@ -473,8 +490,59 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
473490
if debug_handle := metadata.get("debug_handle"):
474491
res["debug_handle"] = int(debug_handle)
475492

493+
if from_node_str := metadata.get("from_node"):
494+
res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))
495+
476496
return res
477497

498+
def _deserialize_from_node(self, from_node_data):
499+
"""
500+
Recursively deserialize from_node metadata from JSON data.
501+
"""
502+
if from_node_data is None:
503+
return None
504+
505+
if isinstance(from_node_data, list):
506+
return [self._deserialize_from_node(item) for item in from_node_data]
507+
508+
if isinstance(from_node_data, dict):
509+
# Create a NodeSource object directly without going through the constructor
510+
# to avoid issues with graph ID and node creation
511+
node_source = NodeSource.__new__(NodeSource)
512+
513+
# Set the basic attributes
514+
node_source.pass_name = from_node_data.get('pass_name', '')
515+
516+
# Parse action string back to NodeSourceAction enum list
517+
action_str = from_node_data.get('action', '')
518+
actions = []
519+
if action_str:
520+
for action_name in action_str.split('+'):
521+
if action_name.upper() == 'CREATE':
522+
actions.append(NodeSourceAction.CREATE)
523+
elif action_name.upper() == 'REPLACE':
524+
actions.append(NodeSourceAction.REPLACE)
525+
node_source.action = actions
526+
527+
# Create the NodeInfo object directly
528+
if 'name' in from_node_data and 'target' in from_node_data and 'graph_id' in from_node_data:
529+
node_info = NodeSource.NodeInfo(
530+
from_node_data.get('name', ''),
531+
from_node_data.get('target', ''),
532+
from_node_data.get('graph_id', -1)
533+
)
534+
node_source.node_info = node_info
535+
else:
536+
node_source.node_info = None
537+
538+
# Recursively deserialize nested from_node
539+
node_source.from_node = self._deserialize_from_node(from_node_data.get('from_node', []))
540+
541+
return node_source
542+
543+
# Fallback for primitive types
544+
return from_node_data
545+
478546
# pyre-ignore
479547
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
480548
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:

exir/tests/test_serde.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,37 @@ def forward(self, x):
275275
)
276276
self.assertEqual(metadata[0], metadata_serde[0])
277277
self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))
278+
279+
def test_meta_debug_handle_and_from_node(self) -> None:
280+
class Model(nn.Module):
281+
def __init__(self):
282+
super(Model, self).__init__()
283+
self.conv_layer = nn.Conv2d(
284+
in_channels=1, out_channels=64, kernel_size=3, padding=1
285+
)
286+
287+
def forward(self, x):
288+
return self.conv_layer(x)
289+
290+
m = Model()
291+
inputs = (torch.randn(1, 1, 32, 32),)
292+
293+
edge = to_edge(export(m, inputs, strict=True))
294+
edge_new = deserialize(serialize(edge.exported_program()))
295+
for node, node_new in zip(
296+
edge.exported_program().graph_module.graph.nodes,
297+
edge_new.graph_module.graph.nodes,
298+
):
299+
if node.op not in {"placeholder", "output"}:
300+
self.assertIsNotNone(node.meta.get("debug_handle"))
301+
self.assertIsNotNone(node.meta.get("from_node"))
302+
self.assertEqual(
303+
node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
304+
)
305+
self.assertEqual(
306+
len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
307+
)
308+
for node_source, node_source_new in zip(
309+
node.meta.get("from_node"), node_new.meta.get("from_node")
310+
):
311+
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())

0 commit comments

Comments
 (0)