44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7- from typing import List
7+ from typing import Any , List
88
99import torch .fx
1010
11- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
1211from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
1312 get_input_qparams ,
1413 get_output_qparams ,
2120from executorch .backends .arm .tosa_mapping import TosaArg
2221
2322
24- def get_negate_zero_points (node : torch .fx .Node , dtype : ts . DType ) -> tuple [int , int ]:
23+ def get_negate_zero_points (node : torch .fx .Node , is_int8 : bool ) -> tuple [int , int ]:
2524 """
2625 Returns (input1_zp, output_zp) for TOSA NEGATE.
2726 Must be zero for non-int8 types.
2827 """
29- if dtype == ts . DType . INT8 :
28+ if is_int8 :
3029 return (
3130 get_input_qparams (node )[0 ].zp ,
3231 get_output_qparams (node )[0 ].zp ,
@@ -35,38 +34,43 @@ def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, i
3534
3635
3736@register_node_visitor
38- class NegVisitor (NodeVisitor ):
37+ class NegVisitor_0_80 (NodeVisitor ):
3938 target = "aten.neg.default"
4039
41- supported_dtypes = {
42- ts .DType .INT8 ,
43- ts .DType .INT16 ,
44- ts .DType .INT32 ,
45- ts .DType .FP16 ,
46- ts .DType .BF16 ,
47- ts .DType .FP32 ,
48- }
40+ tosa_specs = NodeVisitor .tosa_specs_0_80
4941
5042 def __init__ (self , * args ):
5143 super ().__init__ (* args )
5244
5345 def define_node (
5446 self ,
5547 node : torch .fx .Node ,
56- tosa_graph : ts . TosaSerializer ,
48+ tosa_graph : Any ,
5749 inputs : List [TosaArg ],
5850 output : TosaArg ,
5951 ) -> None :
52+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
53+
54+ supported_dtypes = {
55+ ts .DType .INT8 ,
56+ ts .DType .INT16 ,
57+ ts .DType .INT32 ,
58+ ts .DType .FP16 ,
59+ ts .DType .BF16 ,
60+ ts .DType .FP32 ,
61+ }
6062
61- if inputs [0 ].dtype not in self . supported_dtypes :
63+ if inputs [0 ].dtype not in supported_dtypes :
6264 raise ValueError (f"Unsupported dtype for NEGATE: { inputs [0 ].dtype } " )
6365
6466 if inputs [0 ].dtype != output .dtype :
6567 raise ValueError (
6668 "All inputs and output need same dtype."
6769 f"Got { inputs [0 ].dtype = } , { output .dtype = } "
6870 )
69- input_zp , output_zp = get_negate_zero_points (node , inputs [0 ].dtype )
71+ input_zp , output_zp = get_negate_zero_points (
72+ node , inputs [0 ].dtype == ts .DType .INT8
73+ )
7074
7175 attr = ts .TosaSerializerAttribute ()
7276 attr .NegateAttribute (input1_zp = input_zp , output_zp = output_zp )
@@ -76,3 +80,57 @@ def define_node(
7680 [output .name ],
7781 attributes = attr ,
7882 )
83+
84+
85+ @register_node_visitor
86+ class NegVisitor (NodeVisitor ):
87+ target = "aten.neg.default"
88+
89+ tosa_specs = NodeVisitor .tosa_specs_1_00
90+
91+ def __init__ (self , * args ):
92+ super ().__init__ (* args )
93+
94+ def define_node (
95+ self ,
96+ node : torch .fx .Node ,
97+ tosa_graph : Any ,
98+ inputs : List [TosaArg ],
99+ output : TosaArg ,
100+ ) -> None :
101+ import serializer .tosa_serializer as ts # type: ignore
102+
103+ supported_dtypes = {
104+ ts .DType .INT8 ,
105+ ts .DType .INT16 ,
106+ ts .DType .INT32 ,
107+ ts .DType .FP16 ,
108+ ts .DType .BF16 ,
109+ ts .DType .FP32 ,
110+ }
111+
112+ if inputs [0 ].dtype not in supported_dtypes :
113+ raise ValueError (f"Unsupported dtype for NEGATE: { inputs [0 ].dtype } " )
114+
115+ if inputs [0 ].dtype != output .dtype :
116+ raise ValueError (
117+ "All inputs and output need same dtype."
118+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
119+ )
120+ input_zp , output_zp = get_negate_zero_points (
121+ node , inputs [0 ].dtype == ts .DType .INT8
122+ )
123+
124+ input_zp_tensor = tosa_graph .addConst (
125+ (1 ,), inputs [0 ].dtype , [input_zp ], name = output .name + "_input_zp"
126+ )
127+
128+ output_zp_tensor = tosa_graph .addConst (
129+ (1 ,), output .dtype , [output_zp ], name = output .name + "_output_zp"
130+ )
131+
132+ tosa_graph .addOperator (
133+ ts .TosaOp .Op ().NEGATE ,
134+ [inputs [0 ].name , input_zp_tensor .name , output_zp_tensor .name ],
135+ [output .name ],
136+ )
0 commit comments