@@ -42,6 +42,81 @@ def register_passable_op(op):
4242 passable_ops .append (op )
4343
4444
def insert_rescale_ops_to_int32(
    tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node
) -> tuple[list[TosaSerializerTensor], float]:
    """Rescale all 'inputs' to int32, adding suitable RESCALE ops to 'tosa_graph'.

    The scales are adjusted using the smallest scale of all 'inputs', so every
    per-input adjustment factor is >= 1.

    Returns:
        A tuple (rescaled tensors, min scale); the scale is the factor needed
        later by 'insert_rescale_node_back_to_int8'.

    This function is used in serialization to TOSA for target ops that are
    handled by the DQ/Q folding pass, which stores the quantization parameters
    in the node meta dict, as opposed to 'rescale_nodes_to_int32' which
    searches the graph upstream for DQ nodes.
    """
    # Shallow copy so the caller's list object is not mutated.
    # NOTE(review): the TosaArg elements themselves ARE mutated below
    # (their .shape is permuted in place) — confirm callers expect this.
    tensors = inputs.copy()

    # Reshape each tensor according to its TOSA dim order.
    for tensor in tensors:
        dim_order = tensor.dim_order
        tensor.shape = [tensor.shape[i] for i in dim_order]

    qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())

    # Scale the int8 quantized inputs to a common scale in the integer domain.
    min_scale = min(qarg.scale for qarg in qargs)
    scales = [qarg.scale / min_scale for qarg in qargs]

    rescaled_nodes: list[TosaSerializerTensor] = []
    for tensor, qarg, scale in zip(tensors, qargs, scales):
        rescaled_nodes.append(
            build_rescale_to_int32(
                tosa_graph,
                tensor,
                qarg.zp,
                scale,
            )
        )
    return rescaled_nodes, min_scale
85+
86+
def insert_rescale_node_back_to_int8(
    tosa_graph: ts.TosaSerializer,
    last_tensor: TosaArg,
    scale: float,
    node: Node,
) -> None:
    """Rescale back to int8 by appending a suitable RESCALE op to 'tosa_graph'.

    Parameters:
        tosa_graph: the TOSA graph to manipulate.
        last_tensor: the TOSA tensor to rescale back.
        scale: the scaling factor previously used to rescale to int32
            (returned by 'insert_rescale_ops_to_int32').
        node: the original node being handled by the rescales; its meta dict
            must hold exactly one entry under "output_qparams".

    This function is used in serialization to TOSA for target ops that are
    handled by the DQ/Q folding pass, which stores the quantization parameters
    in the node meta dict, as opposed to 'rescale_node_back_to_int8' which
    searches the graph downstream for Q nodes.
    """
    assert len(node.meta["output_qparams"]) == 1

    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
    out_qargs = output_qparams[0]
    # Undo the int32 working scale and land on the output quantization scale.
    back_scale = scale / out_qargs.scale

    # Rescale back to INT8.
    build_rescale_from_int32(
        tosa_graph,
        last_tensor.name,
        node.name,
        out_qargs.zp,
        back_scale,
    )
118+
119+
45120class QuantArgs (NamedTuple ):
46121 scale : float
47122 zp : int
@@ -61,6 +136,20 @@ def quantize_value(self, x):
61136 def dequantize_value (self , qx : int ) -> float :
62137 return (qx - self .zp ) * self .scale
63138
139+ @classmethod
140+ def from_operator (cls , op , args ):
141+ if op in dq_q_ops :
142+ return cls (
143+ scale = cast (float , args [1 ]),
144+ zp = cast (int , args [2 ]),
145+ qmin = cast (int , args [3 ]),
146+ qmax = cast (int , args [4 ]),
147+ dtype = cast (torch .dtype , args [5 ]),
148+ )
149+ else :
150+ # We're only handling per tensor quantization
151+ raise NotImplementedError
152+
64153
65154def quantize_value (x , qargs : QuantArgs , dtype = np .int8 ):
66155 return np .clip (
@@ -77,13 +166,7 @@ def dequantize_value(qx, qargs: QuantArgs):
def qargs_from_qnode(node: torch.fx.Node):
    """Extract QuantArgs from a quantize/dequantize node.

    Delegates the per-tensor arg unpacking to 'QuantArgs.from_operator'.

    Raises:
        AssertionError: if 'node' is not a quant node.
    """
    assert node.target in dq_q_ops, f"Op {node} is not a quant node."

    return QuantArgs.from_operator(node.target, node.args)
87170
88171
89172def get_neighbour_quant_args (
@@ -214,8 +297,13 @@ def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
214297
215298
216299def get_quantized_node_output_dtype (node : torch .fx .Node ) -> torch .dtype :
217- if isinstance (node .target , Callable ) and "tosa" in node .target .__name__ :
218- return node .meta ["val" ].dtype
300+ if isinstance (node .target , Callable ) and "output_qparams" in node .meta .keys ():
301+ # Check if the node has had it's quantization parameters folded
302+ # and retrieve the dtype from the meta dict in that case.
303+ assert len (node .meta ["output_qparams" ]) == 1
304+ qargs = cast (QuantArgs , node .meta ["output_qparams" ][0 ])
305+ return qargs .dtype
306+
219307 if node .target in dq_q_ops :
220308 return cast (torch .dtype , node .args [5 ])
221309
0 commit comments