|
41 | 41 | )
|
42 | 42 | from torch._export.verifier import load_verifier
|
43 | 43 | from torch.fx.experimental import symbolic_shapes
|
| 44 | +from torch.fx.traceback import NodeSource |
44 | 45 |
|
45 | 46 | log: logging.Logger = logging.getLogger(__name__)
|
46 | 47 |
|
@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
|
141 | 142 | debug_handle = node.meta["debug_handle"]
|
142 | 143 | meta["debug_handle"] = str(debug_handle)
|
143 | 144 |
|
| 145 | + if "from_node" in node.meta: |
| 146 | + from_node = node.meta["from_node"] |
| 147 | + # Serialize from_node as JSON since it's a complex nested structure |
| 148 | + meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node)) |
| 149 | + |
144 | 150 | return meta
|
145 | 151 |
|
| 152 | + def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]): |
| 153 | + """ |
| 154 | + Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries. |
| 155 | + """ |
| 156 | + if from_node is None: |
| 157 | + return None |
| 158 | + |
| 159 | + json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)] |
| 160 | + |
| 161 | + return json_acceptable_from_node |
| 162 | + |
146 | 163 | def serialize_alloc_inputs(
|
147 | 164 | self, inputs # pyre-ignore
|
148 | 165 | ) -> List[schema.NamedArgument]:
|
@@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
|
473 | 490 | if debug_handle := metadata.get("debug_handle"):
|
474 | 491 | res["debug_handle"] = int(debug_handle)
|
475 | 492 |
|
| 493 | + if from_node_str := metadata.get("from_node"): |
| 494 | + res["from_node"] = self._deserialize_from_node(json.loads(from_node_str)) |
| 495 | + |
476 | 496 | return res
|
477 | 497 |
|
| 498 | + def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]: |
| 499 | + """ |
| 500 | + Recursively deserialize from_node metadata from JSON data. |
| 501 | + """ |
| 502 | + if from_node_data is None: |
| 503 | + return None |
| 504 | + |
| 505 | + assert isinstance(from_node_data, list) |
| 506 | + |
| 507 | + return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data] |
| 508 | + |
478 | 509 | # pyre-ignore
|
479 | 510 | def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
|
480 | 511 | def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
|
|
0 commit comments