|
41 | 41 | ) |
42 | 42 | from torch._export.verifier import load_verifier |
43 | 43 | from torch.fx.experimental import symbolic_shapes |
| 44 | +from torch.fx.traceback import NodeSource, NodeSourceAction |
44 | 45 |
|
45 | 46 | log: logging.Logger = logging.getLogger(__name__) |
46 | 47 |
|
@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: |
141 | 142 | debug_handle = node.meta["debug_handle"] |
142 | 143 | meta["debug_handle"] = str(debug_handle) |
143 | 144 |
|
| 145 | + if "from_node" in node.meta: |
| 146 | + from_node = node.meta["from_node"] |
| 147 | + # Serialize from_node as JSON since it's a complex nested structure |
| 148 | + meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node)) |
| 149 | + |
144 | 150 | return meta |
145 | 151 |
|
| 152 | + def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]): |
| 153 | + """ |
| 154 | + Recursively serialize from_node metadata which can be a list of NodeSource objects. |
| 155 | + """ |
| 156 | + if from_node is None: |
| 157 | + return None |
| 158 | + |
| 159 | + json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)] |
| 160 | + |
| 161 | + return json_acceptable_from_node |
| 162 | + |
146 | 163 | def serialize_alloc_inputs( |
147 | 164 | self, inputs # pyre-ignore |
148 | 165 | ) -> List[schema.NamedArgument]: |
@@ -473,8 +490,59 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: |
473 | 490 | if debug_handle := metadata.get("debug_handle"): |
474 | 491 | res["debug_handle"] = int(debug_handle) |
475 | 492 |
|
| 493 | + if from_node_str := metadata.get("from_node"): |
| 494 | + res["from_node"] = self._deserialize_from_node(json.loads(from_node_str)) |
| 495 | + |
476 | 496 | return res |
477 | 497 |
|
| 498 | + def _deserialize_from_node(self, from_node_data): |
| 499 | + """ |
| 500 | + Recursively deserialize from_node metadata from JSON data. |
| 501 | + """ |
| 502 | + if from_node_data is None: |
| 503 | + return None |
| 504 | + |
| 505 | + if isinstance(from_node_data, list): |
| 506 | + return [self._deserialize_from_node(item) for item in from_node_data] |
| 507 | + |
| 508 | + if isinstance(from_node_data, dict): |
| 509 | + # Create a NodeSource object directly without going through the constructor |
| 510 | + # to avoid issues with graph ID and node creation |
| 511 | + node_source = NodeSource.__new__(NodeSource) |
| 512 | + |
| 513 | + # Set the basic attributes |
| 514 | + node_source.pass_name = from_node_data.get('pass_name', '') |
| 515 | + |
| 516 | + # Parse action string back to NodeSourceAction enum list |
| 517 | + action_str = from_node_data.get('action', '') |
| 518 | + actions = [] |
| 519 | + if action_str: |
| 520 | + for action_name in action_str.split('+'): |
| 521 | + if action_name.upper() == 'CREATE': |
| 522 | + actions.append(NodeSourceAction.CREATE) |
| 523 | + elif action_name.upper() == 'REPLACE': |
| 524 | + actions.append(NodeSourceAction.REPLACE) |
| 525 | + node_source.action = actions |
| 526 | + |
| 527 | + # Create the NodeInfo object directly |
| 528 | + if 'name' in from_node_data and 'target' in from_node_data and 'graph_id' in from_node_data: |
| 529 | + node_info = NodeSource.NodeInfo( |
| 530 | + from_node_data.get('name', ''), |
| 531 | + from_node_data.get('target', ''), |
| 532 | + from_node_data.get('graph_id', -1) |
| 533 | + ) |
| 534 | + node_source.node_info = node_info |
| 535 | + else: |
| 536 | + node_source.node_info = None |
| 537 | + |
| 538 | + # Recursively deserialize nested from_node |
| 539 | + node_source.from_node = self._deserialize_from_node(from_node_data.get('from_node', [])) |
| 540 | + |
| 541 | + return node_source |
| 542 | + |
| 543 | + # Fallback for primitive types |
| 544 | + return from_node_data |
| 545 | + |
478 | 546 | # pyre-ignore |
479 | 547 | def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): |
480 | 548 | def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec: |
|
0 commit comments