|
41 | 41 | ) |
42 | 42 | from torch._export.verifier import load_verifier |
43 | 43 | from torch.fx.experimental import symbolic_shapes |
| 44 | +from torch.fx.traceback import NodeSource |
44 | 45 |
|
45 | 46 | log: logging.Logger = logging.getLogger(__name__) |
46 | 47 |
|
@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: |
141 | 142 | debug_handle = node.meta["debug_handle"] |
142 | 143 | meta["debug_handle"] = str(debug_handle) |
143 | 144 |
|
| 145 | + if "from_node" in node.meta: |
| 146 | + from_node = node.meta["from_node"] |
| 147 | + # Serialize from_node as JSON since it's a complex nested structure |
| 148 | + meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node)) |
| 149 | + |
144 | 150 | return meta |
145 | 151 |
|
| 152 | + def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]): |
| 153 | + """ |
| 154 | + Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries. |
| 155 | + """ |
| 156 | + if from_node is None: |
| 157 | + return None |
| 158 | + |
| 159 | + json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)] |
| 160 | + |
| 161 | + return json_acceptable_from_node |
| 162 | + |
146 | 163 | def serialize_alloc_inputs( |
147 | 164 | self, inputs # pyre-ignore |
148 | 165 | ) -> List[schema.NamedArgument]: |
@@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: |
473 | 490 | if debug_handle := metadata.get("debug_handle"): |
474 | 491 | res["debug_handle"] = int(debug_handle) |
475 | 492 |
|
| 493 | + if from_node_str := metadata.get("from_node"): |
| 494 | + res["from_node"] = self._deserialize_from_node(json.loads(from_node_str)) |
| 495 | + |
476 | 496 | return res |
477 | 497 |
|
| 498 | + def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]: |
| 499 | + """ |
| 500 | + Recursively deserialize from_node metadata from JSON data. |
| 501 | + """ |
| 502 | + if from_node_data is None: |
| 503 | + return None |
| 504 | + |
| 505 | + assert isinstance(from_node_data, list) |
| 506 | + |
| 507 | + return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data] |
| 508 | + |
478 | 509 | # pyre-ignore |
479 | 510 | def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): |
480 | 511 | def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec: |
|
0 commit comments