# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING, Union

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

+from tico.passes import ops
+
from tico.serialize.circle_mapping import (
    extract_circle_dtype,
    extract_torch_dtype,
    to_circle_dtype,
)
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.errors import NotYetSupportedError
-from tico.utils.validate_args_kwargs import ToCopyArgs
+from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs


@register_node_visitor
class ToCopyVisitor(NodeVisitor):
-    target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default]
+    target: List[torch._ops.OpOverload] = ops.aten.to_copy
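+    # ops.aten.to_copy groups the _to_copy/to overloads this visitor accepts;
+    # parse_args below dispatches on the exact overload.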

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)
@@ -60,42 +62,55 @@ def define_cast_node(

        return operator

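+    # The supported overloads share copy/cast semantics but differ in
+    # signature, so each one is validated by its own args dataclass.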
+    def parse_args(self, op: torch._ops.OpOverload, args, kwargs):
+        ret: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
+        if op is torch.ops.aten._to_copy.default:
+            ret = ToCopyArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype:
+            ret = ToDtypeArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype_layout:
+            ret = ToDtypeLayoutArgs(*args, **kwargs)
+        else:
+            raise NotImplementedError(f"Unsupported to_copy/to operator: {op}")
+
+        return ret
+
    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
-        supported_kwargs = ["dtype", "device", "layout"]
-        if not all(k in supported_kwargs for k in node.kwargs):
-            unsupported_node_kargs = list(node.kwargs.keys())
-            for supported_key in supported_kwargs:
-                if supported_key in node.kwargs:
-                    unsupported_node_kargs.remove(supported_key)
-            raise NotYetSupportedError(
-                f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}"
-            )
-
-        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
+        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dtype = args.dtype
+        layout = args.layout
+        # device is meaningless in circle
+
+        pin_memory = args.pin_memory
+        non_blocking = args.non_blocking
+        memory_format = args.memory_format
+
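+        # These tensor options have no Circle counterpart, so reject them
+        # explicitly instead of silently dropping them.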
+        if pin_memory is not None:
+            raise NotYetSupportedError("Do not support pin_memory yet")
+        if non_blocking is True:
+            raise NotYetSupportedError("Do not support non_blocking yet")
+        if memory_format is not None:
+            raise NotYetSupportedError("Do not support memory_format yet")

        input_meta = input.meta["val"]
        # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
        # layout has two types: torch.strided (dense tensors) and torch.sparse_coo (sparse COO tensors)
        if "layout" in input.kwargs and input.kwargs["layout"] != input_meta:
            raise NotYetSupportedError(
-                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {node.kwargs['layout']})."
+                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {layout})."
            )

-        if dtype is not None:
-            target_type = node.kwargs["dtype"]
-        else:
-            # device and layout are meaningless
-            target_type = extract_torch_dtype(node)
-        assert isinstance(target_type, torch.dtype), type(target_type)
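+        # No explicit dtype was given: fall back to the dtype recorded in the
+        # node's meta.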
+        if dtype is None:
+            dtype = extract_torch_dtype(node)
+        assert isinstance(dtype, torch.dtype), type(dtype)

        # define cast node
        in_type: int = extract_circle_dtype(input)
-        out_type: int = to_circle_dtype(target_type)
+        out_type: int = to_circle_dtype(dtype)
        inputs = [input]
        outputs = [node]
        operator = self.define_cast_node(inputs, outputs, in_type, out_type)