@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
504504 assert len (node .kwargs ) == 0
505505 meta_val = node .meta ["val" ]
506506 ex_node = Node (
507+ name = node .name ,
507508 target = self .serialize_operator (node .target ),
508509 inputs = self .serialize_sym_op_inputs (node .target , node .args ),
509510 outputs = [
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
517518 assert len (node .kwargs ) == 0
518519 meta_val = node .meta ["val" ]
519520 ex_node = Node (
521+ name = node .name ,
520522 target = self .serialize_operator (node .target ),
521523 inputs = self .serialize_sym_op_inputs (node .target , node .args ),
522524 outputs = [
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
528530 )
529531 elif isinstance (node .target , torch ._ops .OpOverload ):
530532 ex_node = Node (
533+ name = node .name ,
531534 target = self .serialize_operator (node .target ),
532535 inputs = self .serialize_inputs (node .target , node .args , node .kwargs ),
533536 outputs = self .serialize_outputs (node ),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
536539 )
537540 elif isinstance (node .target , torch ._ops .HigherOrderOperator ):
538541 ex_node = Node (
542+ name = node .name ,
539543 target = self .serialize_operator (node .target ),
540544 inputs = self .serialize_hoo_inputs (node .args , node .kwargs ),
541545 outputs = self .serialize_hoo_outputs (node ),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
16581662
16591663 def deserialize_node (self , serialized_node : Node , target : Callable ) -> None :
16601664 if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS :
1661- name = serialized_node .outputs [ 0 ]. value . as_name
1665+ name = serialized_node .name
16621666 args = self .deserialize_sym_op_inputs (serialized_node .inputs )
16631667
16641668 fx_node = self .graph .create_node ("call_function" , target , args , {}, name )
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
16711675 # have names that are consistent with serialized.
16721676 #
16731677 # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
1674- name = (
1675- serialized_node .outputs [0 ].as_tensor .name
1676- if len (serialized_node .outputs ) == 1
1677- and hasattr (serialized_node .outputs [0 ], "as_tensor" )
1678- else None
1679- )
1678+ name = serialized_node .name
16801679 fx_node = self .graph .create_node (
16811680 "call_function" , target , args , kwargs , name
16821681 )
@@ -1687,16 +1686,30 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
16871686 # For convenience: if this node returns a single tensor, name the
16881687 # newly-created node after it. This ensures that these tensor values
16891688 # have names that are consistent with serialized.
1690- name = (
1691- serialized_node .outputs [0 ].as_tensor .name
1692- if _is_single_tensor_return (target )
1693- else None # FX will generate a name for us.
1694- )
1689+
1690+ print (target )
1691+ print (target .__name__ )
1692+ print (target .name )
1693+
1694+ name = serialized_node .name
1695+
1696+ print (name )
1697+
1698+ if name == "split_tensor" :
1699+ print (serialized_node )
1700+ print (serialized_node .inputs )
1701+ print (serialized_node .outputs )
1702+
16951703 args , kwargs = self .deserialize_inputs (target , serialized_node )
16961704 fx_node = self .graph .create_node (
16971705 "call_function" , target , args , kwargs , name
16981706 )
16991707 self .deserialize_outputs (serialized_node , fx_node )
1708+
1709+ if name == "split_tensor" :
1710+ print (fx_node )
1711+ print (fx_node .args )
1712+ print (fx_node .kwargs )
17001713 else :
17011714 raise SerializeError (
17021715 f"Unsupported target type for node { serialized_node } : { target } "
0 commit comments